nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
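
The hunk below is from the bundled MLX Whisper port, whisper.py (vendored twice above as items 94 and 315, +862 lines each); a usage sketch follows the hunk.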
@@ -0,0 +1,862 @@
1
+ # modified
2
+ # Copyright © 2023 Apple Inc.
3
+
4
+ import base64
5
+ import gzip
6
+ import json
7
+ import math
8
+ import sys
9
+ import warnings
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import mlx.core as mx
15
+ import mlx.nn as nn
16
+ import numpy as np
17
+ import tqdm
18
+ from mlx.utils import tree_unflatten
19
+
20
+ from .audio import (
21
+ FRAMES_PER_SECOND,
22
+ HOP_LENGTH,
23
+ N_FRAMES,
24
+ N_SAMPLES,
25
+ SAMPLE_RATE,
26
+ log_mel_spectrogram,
27
+ pad_or_trim,
28
+ )
29
+ from .decoding import DecodingOptions, DecodingResult
30
+ from .decoding import decode as decode_function
31
+ from .decoding import detect_language as detect_language_function
32
+ from .timing import add_word_timestamps
33
+ from .tokenizer import LANGUAGES, get_tokenizer
34
+
35
+
36
+ def _format_timestamp(seconds: float):
37
+ assert seconds >= 0, "non-negative timestamp expected"
38
+ milliseconds = round(seconds * 1000.0)
39
+
40
+ hours = milliseconds // 3_600_000
41
+ milliseconds -= hours * 3_600_000
42
+
43
+ minutes = milliseconds // 60_000
44
+ milliseconds -= minutes * 60_000
45
+
46
+ seconds = milliseconds // 1_000
47
+ milliseconds -= seconds * 1_000
48
+
49
+ hours_marker = f"{hours:02d}:" if hours > 0 else ""
50
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
51
+
52
+
53
+ def _get_end(segments: List[dict]) -> Optional[float]:
54
+ return next(
55
+ (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
56
+ segments[-1]["end"] if segments else None,
57
+ )
58
+
59
+
60
+ @dataclass
61
+ class STTOutput:
62
+ text: str
63
+ segments: List[dict] = None
64
+ language: str = None
65
+
66
+
67
+ @dataclass
68
+ class ModelDimensions:
69
+ n_mels: int
70
+ n_audio_ctx: int
71
+ n_audio_state: int
72
+ n_audio_head: int
73
+ n_audio_layer: int
74
+ n_vocab: int
75
+ n_text_ctx: int
76
+ n_text_state: int
77
+ n_text_head: int
78
+ n_text_layer: int
79
+
80
+
81
+ def sinusoids(length, channels, max_timescale=10000):
82
+ """Returns sinusoids for positional embedding"""
83
+ assert channels % 2 == 0
84
+ log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
85
+ inv_timescales = mx.exp(-log_timescale_increment * mx.arange(channels // 2))
86
+ scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
87
+ return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
88
+
89
+
90
+ class MultiHeadAttention(nn.Module):
91
+ def __init__(self, n_state: int, n_head: int):
92
+ super().__init__()
93
+ self.n_head = n_head
94
+ self.query = nn.Linear(n_state, n_state)
95
+ self.key = nn.Linear(n_state, n_state, bias=False)
96
+ self.value = nn.Linear(n_state, n_state)
97
+ self.out = nn.Linear(n_state, n_state)
98
+
99
+ def __call__(
100
+ self,
101
+ x,
102
+ xa=None,
103
+ mask=None,
104
+ kv_cache=None,
105
+ ):
106
+ q = self.query(x)
107
+
108
+ if xa is None:
109
+ k = self.key(x)
110
+ v = self.value(x)
111
+ if kv_cache is not None:
112
+ k = mx.concatenate([kv_cache[0], k], axis=1)
113
+ v = mx.concatenate([kv_cache[1], v], axis=1)
114
+ elif kv_cache is None:
115
+ k = self.key(xa)
116
+ v = self.value(xa)
117
+ else:
118
+ k, v = kv_cache
119
+
120
+ wv, qk = self.qkv_attention(q, k, v, mask)
121
+ return self.out(wv), (k, v), qk
122
+
123
+ def qkv_attention(self, q, k, v, mask=None):
124
+ n_batch, n_ctx, n_state = q.shape
125
+ scale = (n_state // self.n_head) ** -0.25
126
+ q = q.reshape(*q.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) * scale
127
+ k = k.reshape(*k.shape[:2], self.n_head, -1).transpose(0, 2, 3, 1) * scale
128
+ v = v.reshape(*v.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3)
129
+
130
+ qk = q @ k
131
+ if mask is not None:
132
+ qk = qk + mask[:n_ctx, :n_ctx]
133
+
134
+ w = mx.softmax(qk, axis=-1, precise=True)
135
+ out = (w @ v).transpose(0, 2, 1, 3)
136
+ out = out.reshape(n_batch, n_ctx, n_state)
137
+ return out, qk
138
+
139
+
140
+ class ResidualAttentionBlock(nn.Module):
141
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
142
+ super().__init__()
143
+
144
+ self.attn = MultiHeadAttention(n_state, n_head)
145
+ self.attn_ln = nn.LayerNorm(n_state)
146
+
147
+ self.cross_attn = (
148
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
149
+ )
150
+ self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
151
+
152
+ n_mlp = n_state * 4
153
+ self.mlp1 = nn.Linear(n_state, n_mlp)
154
+ self.mlp2 = nn.Linear(n_mlp, n_state)
155
+ self.mlp_ln = nn.LayerNorm(n_state)
156
+
157
+ def __call__(self, x, xa=None, mask=None, kv_cache=None):
158
+ kv, cross_kv = kv_cache if kv_cache else (None, None)
159
+ y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
160
+ x += y
161
+ cross_qk = None
162
+ if self.cross_attn:
163
+ y, cross_kv, cross_qk = self.cross_attn(
164
+ self.cross_attn_ln(x), xa, kv_cache=cross_kv
165
+ )
166
+ x += y
167
+ x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
168
+ return x, (kv, cross_kv), cross_qk
169
+
170
+
171
+ class AudioEncoder(nn.Module):
172
+ def __init__(
173
+ self,
174
+ n_mels: int,
175
+ n_ctx: int,
176
+ n_state: int,
177
+ n_head: int,
178
+ n_layer: int,
179
+ dtype: mx.Dtype = mx.float16,
180
+ ):
181
+ super().__init__()
182
+ self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
183
+ self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
184
+ self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
185
+
186
+ self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
187
+ self.ln_post = nn.LayerNorm(n_state)
188
+
189
+ def __call__(self, x):
190
+ x = nn.gelu(self.conv1(x))
191
+ x = nn.gelu(self.conv2(x))
192
+ assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
193
+ x = x + self._positional_embedding
194
+
195
+ for block in self.blocks:
196
+ x, _, _ = block(x)
197
+
198
+ x = self.ln_post(x)
199
+ return x
200
+
201
+
202
+ class TextDecoder(nn.Module):
203
+ def __init__(
204
+ self,
205
+ n_vocab: int,
206
+ n_ctx: int,
207
+ n_state: int,
208
+ n_head: int,
209
+ n_layer: int,
210
+ dtype: mx.Dtype = mx.float16,
211
+ ):
212
+ super().__init__()
213
+
214
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
215
+ self.positional_embedding = mx.zeros((n_ctx, n_state))
216
+
217
+ self.blocks = [
218
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
219
+ for _ in range(n_layer)
220
+ ]
221
+ self.ln = nn.LayerNorm(n_state)
222
+ self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
223
+ dtype
224
+ )
225
+
226
+ def __call__(self, x, xa, kv_cache=None):
227
+ """
228
+ x : mx.array, shape = (batch_size, <= n_ctx)
229
+ the text tokens
230
+ xa : mx.array, shape = (batch_size, n_audio_ctx, n_audio_state)
231
+ the encoded audio features to be attended on
232
+ """
233
+ offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
234
+ x = (
235
+ self.token_embedding(x)
236
+ + self.positional_embedding[offset : offset + x.shape[-1]]
237
+ )
238
+
239
+ if kv_cache is None:
240
+ kv_cache = [None] * len(self.blocks)
241
+ cross_qk = [None] * len(self.blocks)
242
+ for e, block in enumerate(self.blocks):
243
+ x, kv_cache[e], cross_qk[e] = block(
244
+ x, xa, mask=self._mask, kv_cache=kv_cache[e]
245
+ )
246
+
247
+ x = self.ln(x)
248
+ return self.token_embedding.as_linear(x), kv_cache, cross_qk
249
+
250
+
251
+ class Model(nn.Module):
252
+ def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
253
+ super().__init__()
254
+ self.dims = dims
255
+ self.dtype = dtype
256
+ self.encoder = AudioEncoder(
257
+ self.dims.n_mels,
258
+ self.dims.n_audio_ctx,
259
+ self.dims.n_audio_state,
260
+ self.dims.n_audio_head,
261
+ self.dims.n_audio_layer,
262
+ dtype,
263
+ )
264
+ self.decoder = TextDecoder(
265
+ self.dims.n_vocab,
266
+ self.dims.n_text_ctx,
267
+ self.dims.n_text_state,
268
+ self.dims.n_text_head,
269
+ self.dims.n_text_layer,
270
+ dtype,
271
+ )
272
+ # use the last half among the decoder layers for time alignment by default;
273
+ # to use a specific set of heads, see `set_alignment_heads()` below.
274
+ all_heads = np.zeros(
275
+ (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
276
+ )
277
+ all_heads[self.dims.n_text_layer // 2 :] = True
278
+ self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
279
+
280
+ def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
281
+ if isinstance(dump, np.ndarray):
282
+ self.alignment_heads = mx.array(dump)
283
+ elif isinstance(dump, bytes):
284
+ array = np.frombuffer(
285
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
286
+ ).copy()
287
+ mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
288
+ self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
289
+ else:
290
+ raise ValueError(
291
+ f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
292
+ " alignment_head information"
293
+ )
294
+
295
+ def embed_audio(self, mel):
296
+ return self.encoder(mel)
297
+
298
+ def logits(self, tokens, audio_features):
299
+ return self.decoder(tokens, audio_features)[0]
300
+
301
+ def forward_with_cross_qk(self, mel, tokens):
302
+ logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
303
+ return logits, cross_qk
304
+
305
+ def __call__(self, mel, tokens):
306
+ return self.decoder(tokens, self.encoder(mel))[0]
307
+
308
+ @property
309
+ def is_multilingual(self):
310
+ return self.dims.n_vocab >= 51865
311
+
312
+ @property
313
+ def num_languages(self):
314
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
315
+
316
+ detect_language = detect_language_function
317
+ decode = decode_function
318
+
319
+ @classmethod
320
+ def from_pretrained(
321
+ cls,
322
+ model_path: str,
323
+ dtype: mx.Dtype = mx.float16,
324
+ ) -> "Whisper":
325
+ model_path = Path(model_path)
326
+ if not model_path.exists():
327
+ raise FileNotFoundError(f"Model directory not found: {model_path}")
328
+
329
+ config_path = model_path / "config.json"
330
+ if not config_path.exists():
331
+ raise FileNotFoundError(f"config.json not found in {model_path}")
332
+
333
+ with open(str(config_path), "r") as f:
334
+ config = json.loads(f.read())
335
+ config.pop("model_type", None)
336
+ quantization = config.pop("quantization", None)
337
+
338
+ model_args = ModelDimensions(**config)
339
+
340
+ wf = model_path / "weights.safetensors"
341
+ if not wf.exists():
342
+ wf = model_path / "weights.npz"
343
+
344
+ if not wf.exists():
345
+ raise FileNotFoundError(f"Neither weights.safetensors nor weights.npz found in {model_path}")
346
+
347
+ weights = mx.load(str(wf))
348
+
349
+ model = Model(model_args, dtype)
350
+
351
+ if quantization is not None:
352
+ class_predicate = (
353
+ lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
354
+ and f"{p}.scales" in weights
355
+ )
356
+ nn.quantize(model, **quantization, class_predicate=class_predicate)
357
+
358
+ weights = tree_unflatten(list(weights.items()))
359
+ model.update(weights)
360
+ mx.eval(model.parameters())
361
+ return model
362
+
363
+ def generate(
364
+ self,
365
+ audio: Union[str, np.ndarray, mx.array],
366
+ *,
367
+ verbose: Optional[bool] = None,
368
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
369
+ compression_ratio_threshold: Optional[float] = 2.4,
370
+ logprob_threshold: Optional[float] = -1.0,
371
+ no_speech_threshold: Optional[float] = 0.6,
372
+ condition_on_previous_text: bool = True,
373
+ initial_prompt: Optional[str] = None,
374
+ word_timestamps: bool = False,
375
+ prepend_punctuations: str = "\"'“¿([{-",
376
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
377
+ clip_timestamps: Union[str, List[float]] = "0",
378
+ hallucination_silence_threshold: Optional[float] = None,
379
+ **decode_options,
380
+ ):
381
+ """
382
+ Transcribe an audio file using Whisper
383
+
384
+ Parameters
385
+ ----------
386
+ audio: Union[str, np.ndarray, mx.array]
387
+ The path to the audio file to open, or the audio waveform
388
+
389
+ verbose: bool
390
+ Whether to display the text being decoded to the console. If True, displays all the details,
391
+ If False, displays minimal details. If None, does not display anything
392
+
393
+ temperature: Union[float, Tuple[float, ...]]
394
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
395
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
396
+
397
+ compression_ratio_threshold: float
398
+ If the gzip compression ratio is above this value, treat as failed
399
+
400
+ logprob_threshold: float
401
+ If the average log probability over sampled tokens is below this value, treat as failed
402
+
403
+ no_speech_threshold: float
404
+ If the no_speech probability is higher than this value AND the average log probability
405
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
406
+
407
+ condition_on_previous_text: bool
408
+ if True, the previous output of the model is provided as a prompt for the next window;
409
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
410
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
411
+
412
+ word_timestamps: bool
413
+ Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
414
+ and include the timestamps for each word in each segment.
415
+
416
+ prepend_punctuations: str
417
+ If word_timestamps is True, merge these punctuation symbols with the next word
418
+
419
+ append_punctuations: str
420
+ If word_timestamps is True, merge these punctuation symbols with the previous word
421
+
422
+ initial_prompt: Optional[str]
423
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
424
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
425
+ to make it more likely to predict those word correctly.
426
+
427
+ decode_options: dict
428
+ Keyword arguments to construct `DecodingOptions` instances
429
+
430
+ clip_timestamps: Union[str, List[float]]
431
+ Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
432
+ The last end timestamp defaults to the end of the file.
433
+
434
+ hallucination_silence_threshold: Optional[float]
435
+ When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
436
+ when a possible hallucination is detected
437
+
438
+ Returns
439
+ -------
440
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
441
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
442
+ """
443
+
444
+ # Pad 30-seconds of silence to the input audio, for slicing
445
+ mel = log_mel_spectrogram(audio, n_mels=self.dims.n_mels, padding=N_SAMPLES)
446
+ content_frames = mel.shape[-2] - N_FRAMES
447
+ content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
448
+
449
+ if verbose:
450
+ system_encoding = sys.getdefaultencoding()
451
+ if system_encoding != "utf-8":
452
+ make_safe = lambda x: x.encode(
453
+ system_encoding, errors="replace"
454
+ ).decode(system_encoding)
455
+ else:
456
+ make_safe = lambda x: x
457
+
458
+ if decode_options.get("language", None) is None:
459
+ if not self.is_multilingual:
460
+ decode_options["language"] = "en"
461
+ else:
462
+ if verbose:
463
+ print(
464
+ "Detecting language using up to the first 30 seconds. "
465
+ "Use the `language` decoding option to specify the language"
466
+ )
+                mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(self.dtype)
+                _, probs = self.detect_language(mel_segment)
+                decode_options["language"] = max(probs, key=probs.get)
+                if verbose is not None:
+                    print(
+                        f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                    )
+
+        language: str = decode_options["language"]
+        task: str = decode_options.get("task", "transcribe")
+        tokenizer = get_tokenizer(
+            self.is_multilingual,
+            num_languages=self.num_languages,
+            language=language,
+            task=task,
+        )
+
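+        # clip_timestamps may arrive as a "start,end,start,end,..." string;
+        # convert it to frame indices and pair them into (start, end) clips,
+        # closing an unpaired trailing start at the end of the audio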
+        if isinstance(clip_timestamps, str):
+            clip_timestamps = [
+                float(ts)
+                for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+            ]
+        seek_points: List[int] = [
+            round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps
+        ]
+        if len(seek_points) == 0:
+            seek_points.append(0)
+        if len(seek_points) % 2 == 1:
+            seek_points.append(content_frames)
+        else:
+            seek_points[-1] = min(content_frames, seek_points[-1])
+        seek_clips: List[Tuple[int, int]] = list(
+            zip(seek_points[::2], seek_points[1::2])
+        )
+
+        punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
+        if word_timestamps and task == "translate":
+            warnings.warn("Word-level timestamps on translations may not be reliable.")
+
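+        # Fallback ladder: retry decoding at successively higher temperatures
+        # whenever the result looks degenerate (too repetitive or too
+        # improbable), unless the window is classified as silence.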
+        def decode_with_fallback(segment: mx.array) -> DecodingResult:
+            temperatures = (
+                [temperature] if isinstance(temperature, (int, float)) else temperature
+            )
+            decode_result = None
+
+            for t in temperatures:
+                kwargs = {**decode_options}
+                if t > 0:
+                    # disable beam_size and patience when t > 0
+                    kwargs.pop("beam_size", None)
+                    kwargs.pop("patience", None)
+                else:
+                    # disable best_of when t == 0
+                    kwargs.pop("best_of", None)
+
+                options = DecodingOptions(**kwargs, temperature=t)
+                decode_result = self.decode(segment, options)
+
+                needs_fallback = False
+                if (
+                    compression_ratio_threshold is not None
+                    and decode_result.compression_ratio > compression_ratio_threshold
+                ):
+                    needs_fallback = True  # too repetitive
+                if (
+                    logprob_threshold is not None
+                    and decode_result.avg_logprob < logprob_threshold
+                ):
+                    needs_fallback = True  # average log probability is too low
+                if (
+                    no_speech_threshold is not None
+                    and decode_result.no_speech_prob > no_speech_threshold
+                ):
+                    needs_fallback = False  # silence
+                if not needs_fallback:
+                    break
+
+            return decode_result
+
+        clip_idx = 0
+        seek = seek_clips[clip_idx][0]
+        input_stride = (
+            N_FRAMES // self.dims.n_audio_ctx
+        )  # mel frames per output token: 2
+        time_precision = (
+            input_stride * HOP_LENGTH / SAMPLE_RATE
+        )  # time per output token: 0.02 (seconds)
+        all_tokens = []
+        all_segments = []
+        prompt_reset_since = 0
+
+        if initial_prompt is not None:
+            initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+            all_tokens.extend(initial_prompt_tokens)
+        else:
+            initial_prompt_tokens = []
+
+        def new_segment(
+            *, start: float, end: float, tokens: mx.array, result: DecodingResult
+        ):
+            tokens = tokens.tolist()
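+            # tokens >= eot are special (end-of-text, timestamps, ...) and are
+            # excluded from the decoded text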
+            text_tokens = [token for token in tokens if token < tokenizer.eot]
+            return {
+                "seek": seek,
+                "start": start,
+                "end": end,
+                "text": tokenizer.decode(text_tokens),
+                "tokens": tokens,
+                "temperature": result.temperature,
+                "avg_logprob": result.avg_logprob,
+                "compression_ratio": result.compression_ratio,
+                "no_speech_prob": result.no_speech_prob,
+            }
+
+        # show the progress bar when verbose is False (if True, transcribed text will be printed)
+        with tqdm.tqdm(
+            total=content_frames, unit="frames", disable=verbose is not False
+        ) as pbar:
+            last_speech_timestamp = 0.0
+            for seek_clip_start, seek_clip_end in seek_clips:
+                while seek < seek_clip_end:
+                    time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+                    window_end_time = float(
+                        (seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE
+                    )
+                    segment_size = min(
+                        N_FRAMES, content_frames - seek, seek_clip_end - seek
+                    )
+                    mel_segment = mel[seek : seek + segment_size]
+                    segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+                    mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(
+                        self.dtype
+                    )
+
+                    decode_options["prompt"] = all_tokens[prompt_reset_since:]
+                    result: DecodingResult = decode_with_fallback(mel_segment)
+
+                    tokens = np.array(result.tokens)
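+                    # a NumPy copy of the tokens, used for the timestamp-token
+                    # masking and slicing below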
+
+                    if no_speech_threshold is not None:
+                        # no voice activity check
+                        should_skip = result.no_speech_prob > no_speech_threshold
+                        if (
+                            logprob_threshold is not None
+                            and result.avg_logprob > logprob_threshold
+                        ):
+                            # don't skip if the logprob is high enough, despite the no_speech_prob
+                            should_skip = False
+
+                        if should_skip:
+                            seek += segment_size  # fast-forward to the next segment boundary
+                            continue
+
+                    previous_seek = seek
+                    current_segments = []
+
+                    # anomalous words are very long/short/improbable
+                    def word_anomaly_score(word: dict) -> float:
+                        probability = word.get("probability", 0.0)
+                        duration = word["end"] - word["start"]
+                        score = 0.0
+                        if probability < 0.15:
+                            score += 1.0
+                        if duration < 0.133:
+                            score += (0.133 - duration) * 15
+                        if duration > 2.0:
+                            score += duration - 2.0
+                        return score
+
+                    def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                        if segment is None or not segment["words"]:
+                            return False
+                        words = [
+                            w for w in segment["words"] if w["word"] not in punctuation
+                        ]
+                        words = words[:8]
+                        score = sum(word_anomaly_score(w) for w in words)
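+                        # anomalous when the summed score is large in absolute
+                        # terms or relative to the number of words inspected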
+                        return score >= 3 or score + 0.01 >= len(words)
+
+                    def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                        return next((s for s in segments if s["words"]), None)
+
+                    timestamp_tokens = tokens >= tokenizer.timestamp_begin
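+                    # True when the window ends with exactly one trailing
+                    # timestamp token, i.e. no speech after the last timestamp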
+                    single_timestamp_ending = timestamp_tokens[-2:].tolist() == [
+                        False,
+                        True,
+                    ]
+
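+                    # positions where two timestamp tokens appear back to back
+                    # mark boundaries between complete segments; the +1 points
+                    # at the second token of each pair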
+                    consecutive = np.where(
+                        np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
+                    )[0]
+                    consecutive += 1
+                    if len(consecutive) > 0:
+                        # if the output contains two consecutive timestamp tokens
+                        slices = consecutive.tolist()
+                        if single_timestamp_ending:
+                            slices.append(len(tokens))
+
+                        last_slice = 0
+                        for current_slice in slices:
+                            sliced_tokens = tokens[last_slice:current_slice]
+                            start_timestamp_pos = (
+                                sliced_tokens[0].item() - tokenizer.timestamp_begin
+                            )
+                            end_timestamp_pos = (
+                                sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                            )
+                            current_segments.append(
+                                new_segment(
+                                    start=time_offset
+                                    + start_timestamp_pos * time_precision,
+                                    end=time_offset
+                                    + end_timestamp_pos * time_precision,
+                                    tokens=sliced_tokens,
+                                    result=result,
+                                )
+                            )
+                            last_slice = current_slice
+
+                        if single_timestamp_ending:
+                            # single timestamp at the end means no speech after the last timestamp.
+                            seek += segment_size
+                        else:
+                            # otherwise, ignore the unfinished segment and seek to the last timestamp
+                            last_timestamp_pos = (
+                                tokens[last_slice - 1].item()
+                                - tokenizer.timestamp_begin
+                            )
+                            seek += last_timestamp_pos * input_stride
+                    else:
+                        duration = segment_duration
+                        timestamps = tokens[timestamp_tokens.nonzero()[0]]
+                        if (
+                            len(timestamps) > 0
+                            and timestamps[-1].item() != tokenizer.timestamp_begin
+                        ):
+                            # no consecutive timestamps but it has a timestamp; use the last one.
+                            last_timestamp_pos = (
+                                timestamps[-1].item() - tokenizer.timestamp_begin
+                            )
+                            duration = last_timestamp_pos * time_precision
+
+                        current_segments.append(
+                            new_segment(
+                                start=time_offset,
+                                end=time_offset + duration,
+                                tokens=tokens,
+                                result=result,
+                            )
+                        )
+                        seek += segment_size
+
+                    if word_timestamps:
+                        add_word_timestamps(
+                            segments=current_segments,
+                            model=self,
+                            tokenizer=tokenizer,
+                            mel=mel_segment,
+                            num_frames=segment_size,
+                            prepend_punctuations=prepend_punctuations,
+                            append_punctuations=append_punctuations,
+                            last_speech_timestamp=last_speech_timestamp,
+                        )
+
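+                        # when word timings are available, prefer the end of the
+                        # last aligned word over the raw token timestamp for seeking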
+                        if not single_timestamp_ending:
+                            last_word_end = _get_end(current_segments)
+                            if (
+                                last_word_end is not None
+                                and last_word_end > time_offset
+                            ):
+                                seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                        # skip silence before possible hallucinations
+                        if hallucination_silence_threshold is not None:
+                            threshold = hallucination_silence_threshold
+                            if not single_timestamp_ending:
+                                last_word_end = _get_end(current_segments)
+                                if (
+                                    last_word_end is not None
+                                    and last_word_end > time_offset
+                                ):
+                                    remaining_duration = window_end_time - last_word_end
+                                    if remaining_duration > threshold:
+                                        seek = round(last_word_end * FRAMES_PER_SECOND)
+                                    else:
+                                        seek = previous_seek + segment_size
+
+                            # if first segment might be a hallucination, skip leading silence
+                            first_segment = next_words_segment(current_segments)
+                            if first_segment is not None and is_segment_anomaly(
+                                first_segment
+                            ):
+                                gap = first_segment["start"] - time_offset
+                                if gap > threshold:
+                                    seek = previous_seek + round(
+                                        gap * FRAMES_PER_SECOND
+                                    )
+                                    continue
+
+                            # skip silence before any possible hallucination that is surrounded
+                            # by silence or more hallucinations
+                            hal_last_end = last_speech_timestamp
+                            for si in range(len(current_segments)):
+                                segment = current_segments[si]
+                                if not segment["words"]:
+                                    continue
+                                if is_segment_anomaly(segment):
+                                    next_segment = next_words_segment(
+                                        current_segments[si + 1 :]
+                                    )
+                                    if next_segment is not None:
+                                        hal_next_start = next_segment["words"][0][
+                                            "start"
+                                        ]
+                                    else:
+                                        hal_next_start = time_offset + segment_duration
+                                    silence_before = (
+                                        segment["start"] - hal_last_end > threshold
+                                        or segment["start"] < threshold
+                                        or segment["start"] - time_offset < 2.0
+                                    )
+                                    silence_after = (
+                                        hal_next_start - segment["end"] > threshold
+                                        or is_segment_anomaly(next_segment)
+                                        or window_end_time - segment["end"] < 2.0
+                                    )
+                                    if silence_before and silence_after:
+                                        seek = round(
+                                            max(time_offset + 1, segment["start"])
+                                            * FRAMES_PER_SECOND
+                                        )
+                                        if (
+                                            content_duration - segment["end"]
+                                            < threshold
+                                        ):
+                                            seek = content_frames
+                                        current_segments[si:] = []
+                                        break
+                                hal_last_end = segment["end"]
+
+                        last_word_end = _get_end(current_segments)
+                        if last_word_end is not None:
+                            last_speech_timestamp = last_word_end
+
+                    if verbose:
+                        for segment in current_segments:
+                            start, end, text = (
+                                segment["start"],
+                                segment["end"],
+                                segment["text"],
+                            )
+                            line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
+                            print(make_safe(line))
+
+                    # if a segment is instantaneous or does not contain text, clear it
+                    for i, segment in enumerate(current_segments):
+                        if (
+                            segment["start"] == segment["end"]
+                            or segment["text"].strip() == ""
+                        ):
+                            segment["text"] = ""
+                            segment["tokens"] = []
+                            segment["words"] = []
+
+                    all_segments.extend(
+                        [
+                            {"id": i, **segment}
+                            for i, segment in enumerate(
+                                current_segments, start=len(all_segments)
+                            )
+                        ]
+                    )
+                    all_tokens.extend(
+                        [
+                            token
+                            for segment in current_segments
+                            for token in segment["tokens"]
+                        ]
+                    )
+
+                    if not condition_on_previous_text or result.temperature > 0.5:
+                        # do not feed the prompt tokens if a high temperature was used
+                        prompt_reset_since = len(all_tokens)
+
+                    # update progress bar
+                    pbar.update(min(content_frames, seek) - previous_seek)
+
+                    # Clear cache after each segment to avoid memory leaks
+                    mx.clear_cache()
+
+        return STTOutput(
+            text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+            segments=all_segments,
+            language=language,
+        )