nexaai 1.0.4rc10__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. See the package's registry page for more details.

Files changed (519):
  1. nexaai/__init__.py +71 -0
  2. nexaai/_version.py +4 -0
  3. nexaai/asr.py +60 -0
  4. nexaai/asr_impl/__init__.py +0 -0
  5. nexaai/asr_impl/mlx_asr_impl.py +91 -0
  6. nexaai/asr_impl/pybind_asr_impl.py +43 -0
  7. nexaai/base.py +39 -0
  8. nexaai/binds/__init__.py +3 -0
  9. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  10. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/libnexa_bridge.dylib +0 -0
  12. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  13. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  17. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  20. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/py-lib/ml.py +842 -0
  22. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  172. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  173. nexaai/common.py +61 -0
  174. nexaai/cv.py +87 -0
  175. nexaai/cv_impl/__init__.py +0 -0
  176. nexaai/cv_impl/mlx_cv_impl.py +88 -0
  177. nexaai/cv_impl/pybind_cv_impl.py +31 -0
  178. nexaai/embedder.py +68 -0
  179. nexaai/embedder_impl/__init__.py +0 -0
  180. nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
  181. nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
  182. nexaai/image_gen.py +136 -0
  183. nexaai/image_gen_impl/__init__.py +0 -0
  184. nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
  185. nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
  186. nexaai/llm.py +89 -0
  187. nexaai/llm_impl/__init__.py +0 -0
  188. nexaai/llm_impl/mlx_llm_impl.py +249 -0
  189. nexaai/llm_impl/pybind_llm_impl.py +207 -0
  190. nexaai/mlx_backend/asr/__init__.py +12 -0
  191. nexaai/mlx_backend/asr/interface.py +122 -0
  192. nexaai/mlx_backend/common/__init__.py +0 -0
  193. nexaai/mlx_backend/common/utils.py +25 -0
  194. nexaai/mlx_backend/cv/__init__.py +0 -0
  195. nexaai/mlx_backend/cv/generate.py +195 -0
  196. nexaai/mlx_backend/cv/interface.py +151 -0
  197. nexaai/mlx_backend/cv/main.py +81 -0
  198. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  199. nexaai/mlx_backend/embedding/__init__.py +0 -0
  200. nexaai/mlx_backend/embedding/generate.py +130 -0
  201. nexaai/mlx_backend/embedding/interface.py +312 -0
  202. nexaai/mlx_backend/embedding/main.py +82 -0
  203. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  204. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  205. nexaai/mlx_backend/llm/__init__.py +0 -0
  206. nexaai/mlx_backend/llm/generate.py +149 -0
  207. nexaai/mlx_backend/llm/interface.py +764 -0
  208. nexaai/mlx_backend/llm/main.py +68 -0
  209. nexaai/mlx_backend/ml.py +842 -0
  210. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  211. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  212. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  213. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  214. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  215. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  216. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  217. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  218. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  219. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  220. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  221. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  222. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  223. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  224. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  225. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  226. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  227. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  228. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  229. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  230. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  231. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  232. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  233. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  234. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  235. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  236. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  247. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  249. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  250. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  251. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  252. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  253. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  254. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  255. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  256. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  257. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  258. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  259. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  260. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  261. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  262. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  264. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  265. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  266. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  267. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  268. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  269. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  270. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  271. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  272. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  273. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  274. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  275. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  276. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  277. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  278. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  279. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  280. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  281. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  282. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  283. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  284. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  285. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  286. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  288. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  289. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  290. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  291. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  292. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  293. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  294. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  295. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  296. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  297. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  298. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  299. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  300. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  301. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  302. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  303. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  304. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  305. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  306. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  307. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  308. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  309. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  310. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  353. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  354. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  355. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  356. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  357. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  358. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  359. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  360. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  361. nexaai/mlx_backend/profiling.py +239 -0
  362. nexaai/mlx_backend/rerank/__init__.py +0 -0
  363. nexaai/mlx_backend/rerank/generate.py +174 -0
  364. nexaai/mlx_backend/rerank/interface.py +287 -0
  365. nexaai/mlx_backend/rerank/main.py +127 -0
  366. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  367. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  368. nexaai/mlx_backend/sd/__init__.py +1 -0
  369. nexaai/mlx_backend/sd/interface.py +362 -0
  370. nexaai/mlx_backend/sd/main.py +286 -0
  371. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  372. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  373. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  374. nexaai/mlx_backend/sd/modeling/model_io.py +330 -0
  375. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  376. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  377. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  378. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  379. nexaai/mlx_backend/tts/__init__.py +12 -0
  380. nexaai/mlx_backend/tts/interface.py +276 -0
  381. nexaai/mlx_backend/vlm/__init__.py +3 -0
  382. nexaai/mlx_backend/vlm/generate.py +572 -0
  383. nexaai/mlx_backend/vlm/interface.py +406 -0
  384. nexaai/mlx_backend/vlm/main.py +157 -0
  385. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  386. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  387. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  388. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  389. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  390. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  391. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  392. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  393. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  394. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  395. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  396. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  397. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  398. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  399. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  400. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  401. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  402. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  403. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  404. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  405. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  406. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  407. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  408. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  409. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  410. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  411. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  412. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  413. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  414. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  415. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  416. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  417. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  418. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  419. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  420. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  421. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  422. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  423. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  424. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  425. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  426. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  427. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  428. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  429. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  430. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  431. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  432. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  433. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  434. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  435. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  436. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  437. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  438. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  439. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  440. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  441. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  443. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  444. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  445. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  447. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  448. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  449. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  450. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  451. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  452. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  453. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  454. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  455. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  457. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  458. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  459. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  460. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  461. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  462. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  463. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  464. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  465. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  466. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  467. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  468. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  470. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  471. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  472. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  473. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  474. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  475. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  476. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  477. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  478. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  479. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  480. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  481. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  482. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  483. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  484. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  485. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  486. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  487. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  488. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  489. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  490. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  491. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  492. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  493. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  494. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  495. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  496. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  497. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  498. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  499. nexaai/rerank.py +51 -0
  500. nexaai/rerank_impl/__init__.py +0 -0
  501. nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
  502. nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
  503. nexaai/runtime.py +64 -0
  504. nexaai/tts.py +70 -0
  505. nexaai/tts_impl/__init__.py +0 -0
  506. nexaai/tts_impl/mlx_tts_impl.py +93 -0
  507. nexaai/tts_impl/pybind_tts_impl.py +42 -0
  508. nexaai/utils/avatar_fetcher.py +104 -0
  509. nexaai/utils/decode.py +18 -0
  510. nexaai/utils/model_manager.py +1195 -0
  511. nexaai/utils/progress_tracker.py +372 -0
  512. nexaai/vlm.py +120 -0
  513. nexaai/vlm_impl/__init__.py +0 -0
  514. nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
  515. nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
  516. nexaai-1.0.4rc10.dist-info/METADATA +26 -0
  517. nexaai-1.0.4rc10.dist-info/RECORD +519 -0
  518. nexaai-1.0.4rc10.dist-info/WHEEL +5 -0
  519. nexaai-1.0.4rc10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,777 @@
1
+ import functools
2
+ import json
3
+ import math
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from types import SimpleNamespace
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import mlx.core as mx
10
+ import mlx.nn as nn
11
+ import numpy as np
12
+ from huggingface_hub import snapshot_download
13
+
14
+
15
def filter_dataclass_fields(data_dict, dataclass_type):
    """Return a new dict holding only the entries of *data_dict* whose keys
    name a field declared on *dataclass_type*; all other keys are dropped."""
    field_names = dataclass_type.__dataclass_fields__.keys()
    kept = {}
    for key, value in data_dict.items():
        if key in field_names:
            kept[key] = value
    return kept
19
+
20
+
21
@dataclass
class EncodecConfig:
    """Hyper-parameters of the EnCodec model.

    ``Encodec.from_pretrained`` builds this from ``config.json`` after
    discarding every key that is not a field here, so any option the model
    reads from its config must be declared as a field.
    """

    model_type: str = "encodec"
    audio_channels: int = 1
    num_filters: int = 32
    kernel_size: int = 7
    num_residual_layers: int = 1
    dilation_growth_rate: int = 2
    codebook_size: int = 1024
    codebook_dim: int = 128
    hidden_size: int = 128
    num_lstm_layers: int = 2
    residual_kernel_size: int = 3
    use_causal_conv: bool = True
    normalize: bool = False
    pad_mode: str = "reflect"
    norm_type: str = "weight_norm"
    last_kernel_size: int = 7
    trim_right_ratio: float = 1.0
    compress: int = 2
    upsampling_ratios: Optional[List[int]] = None
    target_bandwidths: Optional[List[float]] = None
    sampling_rate: int = 24000
    chunk_length_s: Optional[float] = None
    overlap: Optional[float] = None
    architectures: Optional[List[str]] = None
    # Residual blocks use a 1x1 conv shortcut when True, identity when False.
    # Declared as a field so a ``use_conv_shortcut`` entry in config.json is
    # kept instead of being filtered out (which silently forced True).
    use_conv_shortcut: bool = True
47
+
48
+
49
def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Batch and pad raw waveforms for the EnCodec model.

    Args:
        raw_audio (mx.array or List[mx.array]): One waveform or a list of
            waveforms. 1-D waveforms are treated as single-channel.
        sampling_rate (int): Sampling rate of the audio (accepted for API
            compatibility; not used in the computation).
        chunk_length (int, optional): The model's chunk length in samples.
            When given, the padded length is extended so that chunked
            encoding lines up with ``chunk_stride``.
        chunk_stride (int, optional): The model's chunk stride in samples.

    Returns:
        ``(inputs, masks)``: ``inputs`` stacks the zero-padded waveforms as
        ``(batch, length, channels)``; ``masks`` is a ``(batch, length)``
        boolean array that is True on real samples and False on padding.
    """
    clips = raw_audio if isinstance(raw_audio, list) else [raw_audio]
    # Promote 1-D waveforms to (length, 1) so every clip is (length, channels).
    clips = [clip if clip.ndim != 1 else clip[..., None] for clip in clips]

    target_len = max(clip.shape[0] for clip in clips)
    if chunk_length is not None:
        target_len += chunk_length - (target_len % chunk_stride)

    padded_clips, padding_masks = [], []
    for clip in clips:
        n_samples = clip.shape[0]
        keep = mx.ones((n_samples,), dtype=mx.bool_)
        shortfall = target_len - n_samples
        if shortfall > 0:
            keep = mx.pad(keep, (0, shortfall))
            clip = mx.pad(clip, ((0, shortfall), (0, 0)))
        padded_clips.append(clip)
        padding_masks.append(keep)
    return mx.stack(padded_clips), mx.stack(padding_masks)
87
+
88
+
89
# Fused LSTM cell for a single timestep, compiled as a custom Metal kernel.
#
# The kernel indexes its buffers as flat row-major arrays:
#   x      (B, T, 4H)  input projections for every timestep
#   h_in   (B, 4H)     recurrent projection for the current step
#   cell   (B, H)      previous cell state
# Along each 4H-wide row the gates are laid out as [i, f, g, o], each H wide.
# One thread handles one (batch item, hidden unit) pair: grid = (B, H, 1).
# (Previously the per-thread offsets were only correct for B == 1: for B > 1
# the cell/hidden index ran past the (B, H) outputs and batch rows were mixed.)
_lstm_kernel = mx.fast.metal_kernel(
    name="lstm",
    input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
    output_names=["hidden_state", "cell_state"],
    header="""
    template <typename T>
    T sigmoid(T x) {
        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
        return (x < 0) ? 1 - y : y;
    }
    """,
    source="""
        uint b = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint d = hidden_size * 4;

        // Position in the (B, H) cell / hidden state.
        uint elem = b * hidden_size + j;
        // Gate-i position in the (B, 4H) recurrent projection ...
        uint index = b * d + j;
        // ... and in the (B, T, 4H) input projection at this timestep.
        uint x_index = b * num_time_steps * d + time_step * d + j;

        auto i = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto f = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto g = metal::precise::tanh(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto o = sigmoid(h_in[index] + x[x_index]);

        cell_state[elem] = f * cell[elem] + i * g;
        hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
    """,
)


def lstm_custom(x, h_in, cell, time_step):
    """Advance an LSTM by one timestep using the fused Metal kernel.

    Args:
        x: Input projections for all timesteps, ``(batch, time, 4 * hidden)``.
        h_in: Recurrent projection for this step, ``(batch, 4 * hidden)``.
        cell: Previous cell state, ``(batch, hidden)``.
        time_step: Index along the time axis of ``x`` to consume.

    Returns:
        ``(hidden_state, cell_state)``, both shaped like ``cell`` and with
        ``h_in``'s dtype.
    """
    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
    out_shape = cell.shape
    batch_size, hidden_size = x.shape[0], out_shape[-1]
    return _lstm_kernel(
        inputs=[x, h_in, cell, hidden_size, time_step, x.shape[-2]],
        output_shapes=[out_shape, out_shape],
        output_dtypes=[h_in.dtype, h_in.dtype],
        # One thread per (batch item, hidden unit); the kernel derives every
        # buffer offset from these two coordinates.
        grid=(batch_size, hidden_size, 1),
        threadgroup=(256, 1, 1),
    )
135
+
136
+
137
class LSTM(nn.Module):
    """Single-layer LSTM over a 3-D ``(batch, time, input_size)`` sequence.

    The input projection for every timestep is computed up front; each step
    then applies the recurrent projection and the fused Metal cell kernel
    (``lstm_custom``).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        # Stacked weights for the 4 gates (4 * hidden_size rows); the kernel
        # reads the projected rows in [i, f, g, o] order.
        self.Wx = mx.zeros((4 * hidden_size, input_size))
        self.Wh = mx.zeros((4 * hidden_size, hidden_size))
        self.bias = mx.zeros((4 * hidden_size,)) if bias else None

    def __call__(self, x, hidden=None, cell=None):
        """Run the LSTM over the whole sequence.

        Args:
            x: Input sequence with time on axis -2 (must be 3-D).
            hidden: Optional initial hidden state, ``(batch, hidden_size)``.
            cell: Optional initial cell state, ``(batch, hidden_size)``.

        Returns:
            Hidden states for every timestep stacked on axis -2.
        """
        if self.bias is not None:
            x = mx.addmm(self.bias, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        all_hidden = []

        B = x.shape[0]
        # Explicit None check: ``cell or ...`` would call bool() on an
        # mx.array, which is ambiguous for a multi-element state (and would
        # silently replace an all-zero single-element state).
        if cell is None:
            cell = mx.zeros((B, self.hidden_size), x.dtype)
        for t in range(x.shape[-2]):
            if hidden is None:
                # No previous hidden state: recurrent contribution is zero.
                hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
            else:
                hidden = hidden @ self.Wh.T
            hidden, cell = lstm_custom(x, hidden, cell, t)
            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)
170
+
171
+
172
class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and optional normalization.

    The time axis is axis 1 throughout (inputs are ``(batch, length,
    channels)``).
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, dilation=dilation
        )
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        self.stride = stride

        # Effective kernel size with dilations.
        self.kernel_size = (kernel_size - 1) * dilation + 1

        # Padding must be derived from the *effective* (dilated) kernel size.
        # Using the raw kernel size shrinks a stride-1 conv's output by
        # (kernel_size - 1) * (dilation - 1) frames, which breaks the residual
        # add in the resnet blocks whenever dilation > 1.
        self.padding_total = self.kernel_size - stride

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> int:
        """Return the extra right padding needed so the last window is full."""
        length = hidden_states.shape[1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = int(math.ceil(n_frames)) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    def _pad1d(
        self,
        hidden_states: mx.array,
        paddings: Tuple[int, int],
        mode: str = "zero",
        value: float = 0.0,
    ):
        """Pad ``hidden_states`` along the time axis (axis 1).

        Args:
            hidden_states: 3-D input ``(batch, length, channels)``.
            paddings: ``(left, right)`` number of frames to add.
            mode: ``"reflect"`` for reflection padding; any other value pads
                with the constant ``value``.
            value: Fill value for constant padding.
        """
        padding_left, padding_right = paddings
        if mode != "reflect":
            # Pad only the time axis; a bare (left, right) pair would make
            # mx.pad extend every axis, including batch and channels.
            return mx.pad(
                hidden_states,
                ((0, 0), (padding_left, padding_right), (0, 0)),
                mode="constant",
                constant_values=value,
            )

        # Reflection needs more samples than the pad amount. For very short
        # inputs, zero-pad on the right first and trim that again afterwards.
        length = hidden_states.shape[1]
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            hidden_states = mx.pad(hidden_states, ((0, 0), (0, extra_pad), (0, 0)))
            length += extra_pad

        # Reflection excludes the edge sample itself: x[p..1] on the left,
        # x[L-2..L-1-p] on the right.
        prefix = hidden_states[:, 1 : padding_left + 1][:, ::-1]
        suffix = hidden_states[:, length - padding_right - 1 : -1][:, ::-1]
        padded = mx.concatenate([prefix, hidden_states, suffix], axis=1)
        return padded[:, : padded.shape[1] - extra_pad]

    def __call__(self, hidden_states):
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(
                hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
            )
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states,
                (padding_left, padding_right + extra_padding),
                mode=self.pad_mode,
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states
253
+
254
+
255
class EncodecConvTranspose1d(nn.Module):
    """Transposed 1-D convolution that trims its fixed padding afterwards.

    With ``use_causal_conv`` the right side is trimmed by
    ``ceil(padding_total * trim_right_ratio)`` frames; otherwise the padding
    is split, with the larger half taken from the left.
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
        self.padding_total = kernel_size - stride

    def __call__(self, hidden_states):
        out = self.conv(hidden_states)
        if self.norm_type == "time_group_norm":
            out = self.norm(out)

        trim_right = (
            math.ceil(self.padding_total * self.trim_right_ratio)
            if self.causal
            else self.padding_total // 2
        )
        trim_left = self.padding_total - trim_right
        stop = out.shape[1] - trim_right
        return out[:, trim_left:stop, :]
291
+
292
+
293
class EncodecLSTM(nn.Module):
    """Stack of ``config.num_lstm_layers`` LSTMs wrapped in a residual connection."""

    def __init__(self, config, dimension):
        super().__init__()
        self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]

    def __call__(self, hidden_states):
        # Run the LSTM layers back to back, then add the block input back in.
        out = functools.reduce(lambda acc, layer: layer(acc), self.lstm, hidden_states)
        return out + hidden_states
303
+
304
+
305
class EncodecResnetBlock(nn.Module):
    """SEANet residual block used by EnCodec.

    The main branch is ``ELU -> conv(residual_kernel_size) -> ELU -> conv(1)``,
    narrowing to ``dim // config.compress`` channels in between. The skip
    branch is a 1x1 conv, or identity when ``use_conv_shortcut`` is false.
    """

    def __init__(self, config, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(dilations) != len(kernel_sizes):
            raise ValueError("Number of kernel sizes should match number of dilations")

        bottleneck = dim // config.compress
        last = len(kernel_sizes) - 1
        layers = []
        for idx, (ksize, dil) in enumerate(zip(kernel_sizes, dilations)):
            in_width = dim if idx == 0 else bottleneck
            out_width = dim if idx == last else bottleneck
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(config, in_width, out_width, ksize, dilation=dil)
            )
        self.block = layers

        use_conv = getattr(config, "use_conv_shortcut", True)
        self.shortcut = (
            EncodecConv1d(config, dim, dim, kernel_size=1) if use_conv else nn.Identity()
        )

    def __call__(self, hidden_states):
        out = hidden_states
        for layer in self.block:
            out = layer(out)
        return self.shortcut(hidden_states) + out
338
+
339
+
340
class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec.

    Layers, in order:
    - an input conv;
    - for each ratio in reversed ``upsampling_ratios``: the residual blocks,
      an ELU, and a strided conv that doubles the width;
    - an LSTM, an ELU, and a projection to ``hidden_size``.

    The position of each entry in ``self.layers`` presumably forms part of
    the weight names, so the order must stay stable.
    """

    def __init__(self, config):
        super().__init__()
        nf = config.num_filters
        layers = [
            EncodecConv1d(config, config.audio_channels, nf, config.kernel_size)
        ]

        mult = 1
        for ratio in reversed(config.upsampling_ratios):
            width = mult * nf
            layers.extend(
                EncodecResnetBlock(config, width, [config.dilation_growth_rate**j, 1])
                for j in range(config.num_residual_layers)
            )
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(config, width, width * 2, kernel_size=ratio * 2, stride=ratio)
            )
            mult *= 2

        top_width = mult * nf
        layers.extend(
            [
                EncodecLSTM(config, top_width),
                nn.ELU(),
                EncodecConv1d(config, top_width, config.hidden_size, config.last_kernel_size),
            ]
        )
        self.layers = layers

    def __call__(self, hidden_states):
        out = hidden_states
        for stage in self.layers:
            out = stage(out)
        return out
389
+
390
+
391
class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec; mirrors the encoder.

    Layers, in order:
    - a conv from ``hidden_size`` to the widest width, then an LSTM;
    - for each ratio in ``upsampling_ratios``: an ELU, a transposed conv that
      halves the width, and the residual blocks;
    - a final ELU and a conv to ``audio_channels``.
    """

    def __init__(self, config):
        super().__init__()
        nf = config.num_filters
        mult = int(2 ** len(config.upsampling_ratios))
        layers = [
            EncodecConv1d(config, config.hidden_size, mult * nf, config.kernel_size),
            EncodecLSTM(config, mult * nf),
        ]

        for ratio in config.upsampling_ratios:
            width = mult * nf
            half = width // 2
            layers.append(nn.ELU())
            layers.append(
                EncodecConvTranspose1d(config, width, half, kernel_size=ratio * 2, stride=ratio)
            )
            layers.extend(
                EncodecResnetBlock(config, half, (config.dilation_growth_rate**j, 1))
                for j in range(config.num_residual_layers)
            )
            mult //= 2

        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(config, nf, config.audio_channels, config.last_kernel_size)
        )
        self.layers = layers

    def __call__(self, hidden_states):
        return functools.reduce(lambda acc, stage: stage(acc), self.layers, hidden_states)
443
+
444
+
445
class EncodecEuclideanCodebook(nn.Module):
    """Codebook that maps each vector to its nearest codeword (Euclidean)."""

    def __init__(self, config):
        super().__init__()
        # (codebook_size, codebook_dim) table; zeros until weights are loaded.
        self.embed = mx.zeros((config.codebook_size, config.codebook_dim))

    def quantize(self, hidden_states):
        """Return the nearest-codeword index for each row of a 2-D input."""
        codebook_t = self.embed.T
        # Negated squared distance, expanded as ||x||^2 - 2 x.e + ||e||^2.
        row_sq = hidden_states.square().sum(axis=1, keepdims=True)
        code_sq = codebook_t.square().sum(axis=0, keepdims=True)
        neg_dist = -(row_sq - 2 * hidden_states @ codebook_t + code_sq)
        return neg_dist.argmax(axis=-1)

    def encode(self, hidden_states):
        """Quantize vectors on the last axis; returns indices shaped like the leading axes."""
        lead_shape = hidden_states.shape[:-1]
        flat = hidden_states.reshape((-1, hidden_states.shape[-1]))
        return self.quantize(flat).reshape(*lead_shape)

    def decode(self, embed_ind):
        """Look up the codewords for an array of indices."""
        codewords = self.embed
        return codewords[embed_ind]
472
+
473
+
474
class EncodecVectorQuantization(nn.Module):
    """One vector-quantization stage.

    Thin wrapper that delegates to a Euclidean-distance codebook.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        """Map vectors (last axis) to codeword indices."""
        codebook = self.codebook
        return codebook.encode(hidden_states)

    def decode(self, embed_ind):
        """Map codeword indices back to codeword vectors."""
        codebook = self.codebook
        return codebook.decode(embed_ind)
488
+
489
+
490
class EncodecResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer.

    A cascade of codebooks; each stage quantizes the residual left over by
    the stages before it.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook_size = config.codebook_size

        # Latent frames per second: one frame per hop of input samples.
        samples_per_frame = np.prod(config.upsampling_ratios)
        self.frame_rate = math.ceil(config.sampling_rate / samples_per_frame)
        # Number of stages sized from the largest target bandwidth (kbps).
        # NOTE(review): the constant 10 looks like log2 of a 1024-entry
        # codebook rather than log2(codebook_size) — confirm.
        max_kbps = config.target_bandwidths[-1]
        self.num_quantizers = int(1000 * max_kbps // (self.frame_rate * 10))
        self.layers = [
            EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
        ]

    def get_num_quantizers_for_bandwidth(
        self, bandwidth: Optional[float] = None
    ) -> int:
        """Return how many stages fit in ``bandwidth`` kbps (all stages if unset or non-positive)."""
        bits_per_stage = math.log2(self.codebook_size) * self.frame_rate
        if bandwidth is None or not bandwidth > 0.0:
            return self.num_quantizers
        return int(max(1, math.floor(bandwidth * 1000 / bits_per_stage)))

    def encode(
        self, embeddings: mx.array, bandwidth: Optional[float] = None
    ) -> mx.array:
        """
        Quantize ``embeddings`` with as many stages as ``bandwidth`` allows.

        Returns the per-stage indices stacked on axis 1.
        """
        n_stages = self.get_num_quantizers_for_bandwidth(bandwidth)
        remainder = embeddings
        codes = []
        for stage in self.layers[:n_stages]:
            idx = stage.encode(remainder)
            remainder = remainder - stage.decode(idx)
            codes.append(idx)
        return mx.stack(codes, axis=1)

    def decode(self, codes: mx.array) -> mx.array:
        """Sum the codewords selected by each stage (stages along axis 1 of ``codes``)."""
        total = None
        per_stage = codes.split(codes.shape[1], axis=1)
        for i, idx in enumerate(per_stage):
            contribution = self.layers[i].decode(idx.squeeze(1))
            total = contribution if total is None else contribution + total
        return total
546
+
547
+
548
class Encodec(nn.Module):
    """EnCodec neural audio codec.

    Wraps a SEANet encoder/decoder around a residual vector quantizer.
    Waveforms are channels-last: ``(batch_size, sequence_length, channels)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = EncodecEncoder(self.config)
        self.decoder = EncodecDecoder(self.config)
        self.quantizer = EncodecResidualVectorQuantizer(self.config)

    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encode one chunk of audio with the underlying VQ-VAE.

        Args:
            input_values: Audio chunk, ``(batch_size, length, channels)``.
            bandwidth: Target bandwidth in kbps, forwarded to the quantizer.
            padding_mask: ``(batch_size, length)`` mask. Only used when
                ``config.normalize`` is set, to zero out padded samples.

        Returns:
            ``(codes, scale)``. ``codes`` has shape
            ``(batch_size, num_codebooks, frames)``. ``scale`` is ``None``
            unless ``config.normalize`` is set.

        Raises:
            RuntimeError: If the chunk is longer than ``config.chunk_length_s``.
        """
        length = input_values.shape[1]
        duration = length / self.config.sampling_rate

        if (
            self.config.chunk_length_s is not None
            and duration > 1e-5 + self.config.chunk_length_s
        ):
            raise RuntimeError(
                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
            )

        scale = None
        if self.config.normalize:
            # Zero out padded samples so they do not contribute to the RMS.
            input_values = input_values * padding_mask[..., None]
            # Mix down to mono, then take the per-item RMS over time.
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        return codes, scale

    def encode(
        self,
        input_values: mx.array,
        padding_mask: mx.array = None,
        bandwidth: Optional[float] = None,
    ) -> Tuple[mx.array, List[Optional[mx.array]]]:
        """
        Encode the input audio waveform into discrete codes.

        Args:
            input_values (mx.array): The input audio waveform, channels-last,
                with shape ``(batch_size, sequence_length, channels)`` (as
                produced by ``preprocess_audio``). ``channels`` must be 1 or 2.
            padding_mask (mx.array, optional): ``(batch_size, sequence_length)``
                mask of real (non-padded) samples. Defaults to all ones.
            bandwidth (float, optional): Target bandwidth in kbps; must be one
                of ``config.target_bandwidths`` (e.g. 6 kbps is ``6.0``). If
                ``None``, ``config.target_bandwidths[0]`` is used.

        Returns:
            A tuple ``(codes, scales)``.
            - ``codes`` stacks the per-chunk codes into shape
              ``(num_chunks, batch_size, num_codebooks, frames)``.
            - ``scales`` is a list with one entry per chunk. Each entry is
              ``None`` unless ``config.normalize`` is set.

        Raises:
            ValueError: If the bandwidth is unsupported, the channel count is
                not 1 or 2, or the input length is not padded correctly for
                chunked encoding.
        """

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
            )

        _, input_length, channels = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(
                f"Number of audio channels must be 1 or 2, but got {channels}"
            )

        chunk_length = self.chunk_length
        if chunk_length is None:
            # No chunking: the whole input is a single chunk.
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.chunk_stride

        if padding_mask is None:
            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
        encoded_frames = []
        scales = []

        # `step` is the overlap between consecutive chunks.
        step = chunk_length - stride
        if (input_length % stride) != step:
            raise ValueError(
                "The input length is not properly padded for batched chunked encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
            frame = input_values[:, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = mx.stack(encoded_frames)

        return (encoded_frames, scales)

    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        """
        Stitch decoded chunks spaced ``stride`` samples apart.

        Overlapping regions are blended with a triangular weight
        (``0.5 - |t - 0.5|`` over each frame), then normalized by the summed
        weights.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        dtype = frames[0].dtype
        N, frame_length, C = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]

        # Interior points of [0, 1] so no sample gets zero weight.
        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        weight = weight[:, None]
        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
        out = mx.zeros((N, total_size, C), dtype=dtype)
        offset = 0

        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        return out / sum_weight

    def _decode_frame(
        self, codes: mx.array, scale: Optional[mx.array] = None
    ) -> mx.array:
        """Decode one chunk of codes to audio, undoing normalization if ``scale`` is given."""
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale
        return outputs

    @property
    def channels(self):
        """Number of audio channels (``config.audio_channels``)."""
        return self.config.audio_channels

    @property
    def sampling_rate(self):
        """Audio sampling rate in Hz (``config.sampling_rate``)."""
        return self.config.sampling_rate

    @property
    def chunk_length(self):
        """Chunk length in samples, or ``None`` when chunking is disabled."""
        if self.config.chunk_length_s is None:
            return None
        else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)

    @property
    def chunk_stride(self):
        """Hop between chunks in samples (at least 1).

        ``None`` unless both ``chunk_length_s`` and ``overlap`` are set.
        """
        if self.config.chunk_length_s is None or self.config.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

    @classmethod
    def from_pretrained(cls, path_or_repo: str):
        """
        Load the model and audio preprocessor.

        Args:
            path_or_repo: Local directory, or a Hugging Face Hub repo id to
                download when no such path exists.

        Returns:
            ``(model, processor)``, where ``processor`` is ``preprocess_audio``
            bound to this model's sampling rate, chunk length and stride.
        """
        path = Path(path_or_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_repo,
                    allow_patterns=["*.json", "*.safetensors", "*.model"],
                )
            )

        with open(path / "config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        # Unknown keys in config.json are ignored.
        filtered_config = filter_dataclass_fields(config, EncodecConfig)
        config = EncodecConfig(**filtered_config)
        model = cls(config)
        model.load_weights(str(path / "model.safetensors"))
        processor = functools.partial(
            preprocess_audio,
            sampling_rate=config.sampling_rate,
            chunk_length=model.chunk_length,
            chunk_stride=model.chunk_stride,
        )
        mx.eval(model)
        return model, processor

    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decode the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete codes.
                - When ``config.chunk_length_s`` is ``None``, axis 1 is the
                  frame axis. Exactly one frame is expected there, and
                  ``audio_codes[:, 0]`` is decoded.
                - Otherwise the array is iterated over axis 0, one chunk at
                  a time.
            audio_scales: One scale per chunk. ``None`` entries skip
                rescaling.
            padding_mask (mx.array, optional): If shorter than the decoded
                audio, the output is truncated to ``padding_mask.shape[1]``
                samples.

        Returns:
            The decoded waveform, channels-last.

        NOTE(review): ``encode`` stacks chunks on axis 0. So for an
        unchunked model, ``encode`` output with batch_size > 1 has
        ``shape[1] == batch_size`` and is rejected here. Confirm the intended
        layout against callers before changing either side.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            if audio_codes.shape[1] != 1:
                # Report the size of the axis that was actually checked.
                raise ValueError(f"Expected one frame, got {audio_codes.shape[1]}")
            audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
        else:
            decoded_frames = []

            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values