nexaai 1.0.29 (cp310-cp310-macosx_14_0_universal2.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,777 @@
+ import functools
+ import json
+ import math
+ from dataclasses import dataclass
+ from pathlib import Path
+ from types import SimpleNamespace
+ from typing import List, Optional, Tuple, Union
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from huggingface_hub import snapshot_download
+
+
+ def filter_dataclass_fields(data_dict, dataclass_type):
+     """Filter a dictionary to only include keys that are fields in the dataclass."""
+     valid_fields = {f.name for f in dataclass_type.__dataclass_fields__.values()}
+     return {k: v for k, v in data_dict.items() if k in valid_fields}
+
+
+ @dataclass
+ class EncodecConfig:
+     model_type: str = "encodec"
+     audio_channels: int = 1
+     num_filters: int = 32
+     kernel_size: int = 7
+     num_residual_layers: int = 1
+     dilation_growth_rate: int = 2
+     codebook_size: int = 1024
+     codebook_dim: int = 128
+     hidden_size: int = 128
+     num_lstm_layers: int = 2
+     residual_kernel_size: int = 3
+     use_causal_conv: bool = True
+     normalize: bool = False
+     pad_mode: str = "reflect"
+     norm_type: str = "weight_norm"
+     last_kernel_size: int = 7
+     trim_right_ratio: float = 1.0
+     compress: int = 2
+     upsampling_ratios: List[int] = None
+     target_bandwidths: List[float] = None
+     sampling_rate: int = 24000
+     chunk_length_s: Optional[float] = None
+     overlap: Optional[float] = None
+     architectures: List[str] = None
+
+
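Note: EncodecConfig is presumably populated from a model's config.json, with filter_dataclass_fields dropping any keys the dataclass does not declare. A minimal sketch, assuming the hunk's definitions and imports are in scope (the config.json path is hypothetical):

    # Hypothetical path; json and Path are imported at the top of the file.
    raw = json.loads(Path("config.json").read_text())
    config = EncodecConfig(**filter_dataclass_fields(raw, EncodecConfig))
    print(config.sampling_rate, config.codebook_size)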
+ def preprocess_audio(
+     raw_audio: Union[mx.array, List[mx.array]],
+     sampling_rate: int = 24000,
+     chunk_length: Optional[int] = None,
+     chunk_stride: Optional[int] = None,
+ ):
+     r"""
+     Prepare inputs for the EnCodec model.
+
+     Args:
+         raw_audio (mx.array or List[mx.array]): The sequence or batch of
+             sequences to be processed.
+         sampling_rate (int): The sampling rate at which the audio waveform
+             should be digitalized.
+         chunk_length (int, optional): The model's chunk length.
+         chunk_stride (int, optional): The model's chunk stride.
+     """
+     if not isinstance(raw_audio, list):
+         raw_audio = [raw_audio]
+
+     raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
+
+     max_length = max(array.shape[0] for array in raw_audio)
+     if chunk_length is not None:
+         max_length += chunk_length - (max_length % chunk_stride)
+
+     inputs = []
+     masks = []
+     for x in raw_audio:
+         length = x.shape[0]
+         mask = mx.ones((length,), dtype=mx.bool_)
+         difference = max_length - length
+         if difference > 0:
+             mask = mx.pad(mask, (0, difference))
+             x = mx.pad(x, ((0, difference), (0, 0)))
+         inputs.append(x)
+         masks.append(mask)
+     return mx.stack(inputs), mx.stack(masks)
+
+
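preprocess_audio right-pads every clip in the batch to a common length and returns boolean masks marking the real samples. A quick usage sketch (illustrative shapes only; definitions above in scope):

    a = mx.zeros((24000,))   # 1 s of mono audio at 24 kHz
    b = mx.zeros((36000,))   # 1.5 s
    batch, masks = preprocess_audio([a, b], sampling_rate=24000)
    # batch.shape == (2, 36000, 1); masks[0] is False over the padded tail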
+ _lstm_kernel = mx.fast.metal_kernel(
+     name="lstm",
+     input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
+     output_names=["hidden_state", "cell_state"],
+     header="""
+     template <typename T>
+     T sigmoid(T x) {
+         auto y = 1 / (1 + metal::exp(-metal::abs(x)));
+         return (x < 0) ? 1 - y : y;
+     }
+     """,
+     source="""
+         uint b = thread_position_in_grid.x;
+         uint d = hidden_size * 4;
+
+         uint elem = b * d + thread_position_in_grid.y;
+         uint index = elem;
+         uint x_index = b * num_time_steps * d + time_step * d + index;
+
+         auto i = sigmoid(h_in[index] + x[x_index]);
+         index += hidden_size;
+         x_index += hidden_size;
+         auto f = sigmoid(h_in[index] + x[x_index]);
+         index += hidden_size;
+         x_index += hidden_size;
+         auto g = metal::precise::tanh(h_in[index] + x[x_index]);
+         index += hidden_size;
+         x_index += hidden_size;
+         auto o = sigmoid(h_in[index] + x[x_index]);
+
+         cell_state[elem] = f * cell[elem] + i * g;
+         hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
+     """,
+ )
+
+
+ def lstm_custom(x, h_in, cell, time_step):
+     assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
+     out_shape = cell.shape
+     return _lstm_kernel(
+         inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
+         output_shapes=[out_shape, out_shape],
+         output_dtypes=[h_in.dtype, h_in.dtype],
+         grid=(x.shape[0], h_in.size // 4, 1),
+         threadgroup=(256, 1, 1),
+     )
+
+
+ class LSTM(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.Wx = mx.zeros((4 * hidden_size, input_size))
+         self.Wh = mx.zeros((4 * hidden_size, hidden_size))
+         self.bias = mx.zeros((4 * hidden_size,)) if bias else None
+
+     def __call__(self, x, hidden=None, cell=None):
+         if self.bias is not None:
+             x = mx.addmm(self.bias, x, self.Wx.T)
+         else:
+             x = x @ self.Wx.T
+
+         all_hidden = []
+
+         B = x.shape[0]
+         cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
+         for t in range(x.shape[-2]):
+             if hidden is None:
+                 hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
+             else:
+                 hidden = hidden @ self.Wh.T
+             hidden, cell = lstm_custom(x, hidden, cell, t)
+             all_hidden.append(hidden)
+
+         return mx.stack(all_hidden, axis=-2)
+
+
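The Metal kernel above fuses the four LSTM gate activations for one time step. For reference, the same update written in plain mlx ops; this is a sketch of the kernel's semantics, not code from the package:

    def lstm_step_reference(x_t, h_proj, cell, hidden_size):
        # x_t and h_proj each hold the four gate pre-activations,
        # concatenated in (input, forget, cell, output) order, as in the kernel.
        z = h_proj + x_t
        i, f, g, o = (z[..., k * hidden_size:(k + 1) * hidden_size] for k in range(4))
        new_cell = mx.sigmoid(f) * cell + mx.sigmoid(i) * mx.tanh(g)
        new_hidden = mx.sigmoid(o) * mx.tanh(new_cell)
        return new_hidden, new_cell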
+ class EncodecConv1d(nn.Module):
+     """Conv1d with asymmetric or causal padding and normalization."""
+
+     def __init__(
+         self,
+         config,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         dilation: int = 1,
+     ):
+         super().__init__()
+         self.causal = config.use_causal_conv
+         self.pad_mode = config.pad_mode
+         self.norm_type = config.norm_type
+
+         self.conv = nn.Conv1d(
+             in_channels, out_channels, kernel_size, stride, dilation=dilation
+         )
+         if self.norm_type == "time_group_norm":
+             self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
+
+         self.stride = stride
+
+         # Effective kernel size with dilations.
+         self.kernel_size = (kernel_size - 1) * dilation + 1
+
+         self.padding_total = kernel_size - stride
+
+     def _get_extra_padding_for_conv1d(
+         self,
+         hidden_states: mx.array,
+     ) -> mx.array:
+         length = hidden_states.shape[1]
+         n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
+         n_frames = int(math.ceil(n_frames)) - 1
+         ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+         return ideal_length - length
+
+     def _pad1d(
+         self,
+         hidden_states: mx.array,
+         paddings: Tuple[int, int],
+         mode: str = "zero",
+         value: float = 0.0,
+     ):
+         if mode != "reflect":
+             return mx.pad(
+                 hidden_states, paddings, mode="constant", constant_values=value
+             )
+
+         length = hidden_states.shape[1]
+         prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
+         suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
+         return mx.concatenate([prefix, hidden_states, suffix], axis=1)
+
+     def __call__(self, hidden_states):
+         extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
+
+         if self.causal:
+             # Left padding for causal
+             hidden_states = self._pad1d(
+                 hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
+             )
+         else:
+             # Asymmetric padding required for odd strides
+             padding_right = self.padding_total // 2
+             padding_left = self.padding_total - padding_right
+             hidden_states = self._pad1d(
+                 hidden_states,
+                 (padding_left, padding_right + extra_padding),
+                 mode=self.pad_mode,
+             )
+
+         hidden_states = self.conv(hidden_states)
+
+         if self.norm_type == "time_group_norm":
+             hidden_states = self.norm(hidden_states)
+
+         return hidden_states
+
+
+ class EncodecConvTranspose1d(nn.Module):
+     """ConvTranspose1d with asymmetric or causal padding and normalization."""
+
+     def __init__(
+         self,
+         config,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+     ):
+         super().__init__()
+         self.causal = config.use_causal_conv
+         self.trim_right_ratio = config.trim_right_ratio
+         self.norm_type = config.norm_type
+         self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
+         if config.norm_type == "time_group_norm":
+             self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
+         self.padding_total = kernel_size - stride
+
+     def __call__(self, hidden_states):
+         hidden_states = self.conv(hidden_states)
+
+         if self.norm_type == "time_group_norm":
+             hidden_states = self.norm(hidden_states)
+
+         if self.causal:
+             padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
+         else:
+             padding_right = self.padding_total // 2
+
+         padding_left = self.padding_total - padding_right
+
+         end = hidden_states.shape[1] - padding_right
+         hidden_states = hidden_states[:, padding_left:end, :]
+         return hidden_states
+
+
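The padding arithmetic above keeps strided convolutions sample-exact. A worked example in the causal case, using the shape of one of the downsampling convs (kernel_size=8, stride=4, dilation=1; illustrative numbers, with the hunk's imports in scope):

    kernel_size, stride = 8, 4
    padding_total = kernel_size - stride                                           # 4
    length = 100
    n_frames = math.ceil((length - kernel_size + padding_total) / stride + 1) - 1  # 24
    ideal_length = n_frames * stride + kernel_size - padding_total                 # 100
    extra_padding = ideal_length - length                                          # 0
    # Causal mode then left-pads by padding_total, so the conv emits exactly
    # length // stride == 25 frames and no samples are dropped.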
+ class EncodecLSTM(nn.Module):
+     def __init__(self, config, dimension):
+         super().__init__()
+         self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
+
+     def __call__(self, hidden_states):
+         h = hidden_states
+         for lstm in self.lstm:
+             h = lstm(h)
+         return h + hidden_states
+
+
+ class EncodecResnetBlock(nn.Module):
+     """
+     Residual block from SEANet model as used by EnCodec.
+     """
+
+     def __init__(self, config, dim: int, dilations: List[int]):
+         super().__init__()
+         kernel_sizes = (config.residual_kernel_size, 1)
+         if len(kernel_sizes) != len(dilations):
+             raise ValueError("Number of kernel sizes should match number of dilations")
+
+         hidden = dim // config.compress
+         block = []
+         for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+             in_chs = dim if i == 0 else hidden
+             out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+             block += [nn.ELU()]
+             block += [
+                 EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
+             ]
+         self.block = block
+
+         if getattr(config, "use_conv_shortcut", True):
+             self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
+         else:
+             self.shortcut = nn.Identity()
+
+     def __call__(self, hidden_states):
+         residual = hidden_states
+         for layer in self.block:
+             hidden_states = layer(hidden_states)
+
+         return self.shortcut(residual) + hidden_states
+
+
+ class EncodecEncoder(nn.Module):
+     """SEANet encoder as used by EnCodec."""
+
+     def __init__(self, config):
+         super().__init__()
+         model = [
+             EncodecConv1d(
+                 config, config.audio_channels, config.num_filters, config.kernel_size
+             )
+         ]
+         scaling = 1
+
+         for ratio in reversed(config.upsampling_ratios):
+             current_scale = scaling * config.num_filters
+             for j in range(config.num_residual_layers):
+                 model += [
+                     EncodecResnetBlock(
+                         config, current_scale, [config.dilation_growth_rate**j, 1]
+                     )
+                 ]
+             model += [nn.ELU()]
+             model += [
+                 EncodecConv1d(
+                     config,
+                     current_scale,
+                     current_scale * 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                 )
+             ]
+             scaling *= 2
+
+         model += [EncodecLSTM(config, scaling * config.num_filters)]
+         model += [nn.ELU()]
+         model += [
+             EncodecConv1d(
+                 config,
+                 scaling * config.num_filters,
+                 config.hidden_size,
+                 config.last_kernel_size,
+             )
+         ]
+
+         self.layers = model
+
+     def __call__(self, hidden_states):
+         for layer in self.layers:
+             hidden_states = layer(hidden_states)
+         return hidden_states
+
+
+ class EncodecDecoder(nn.Module):
+     """SEANet decoder as used by EnCodec."""
+
+     def __init__(self, config):
+         super().__init__()
+         scaling = int(2 ** len(config.upsampling_ratios))
+         model = [
+             EncodecConv1d(
+                 config,
+                 config.hidden_size,
+                 scaling * config.num_filters,
+                 config.kernel_size,
+             )
+         ]
+
+         model += [EncodecLSTM(config, scaling * config.num_filters)]
+
+         for ratio in config.upsampling_ratios:
+             current_scale = scaling * config.num_filters
+             model += [nn.ELU()]
+             model += [
+                 EncodecConvTranspose1d(
+                     config,
+                     current_scale,
+                     current_scale // 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                 )
+             ]
+             for j in range(config.num_residual_layers):
+                 model += [
+                     EncodecResnetBlock(
+                         config, current_scale // 2, (config.dilation_growth_rate**j, 1)
+                     )
+                 ]
+             scaling //= 2
+
+         model += [nn.ELU()]
+         model += [
+             EncodecConv1d(
+                 config,
+                 config.num_filters,
+                 config.audio_channels,
+                 config.last_kernel_size,
+             )
+         ]
+         self.layers = model
+
+     def __call__(self, hidden_states):
+         for layer in self.layers:
+             hidden_states = layer(hidden_states)
+         return hidden_states
+
+
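Each encoder stage strides time down by its ratio and doubles the channel count, so the latent frame rate is the input rate divided by the product of the ratios. For example, assuming the standard 24 kHz EnCodec configuration (these values are not pinned in this file):

    upsampling_ratios = [8, 5, 4, 2]
    hop_length = int(np.prod(upsampling_ratios))   # 320 samples per latent frame
    frame_rate = 24000 // hop_length               # 75 frames per second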
445
+ class EncodecEuclideanCodebook(nn.Module):
446
+ """Codebook with Euclidean distance."""
447
+
448
+ def __init__(self, config):
449
+ super().__init__()
450
+ self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
451
+
452
+ def quantize(self, hidden_states):
453
+ embed = self.embed.T
454
+ scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
455
+ dist = -(
456
+ scaled_states
457
+ - 2 * hidden_states @ embed
458
+ + embed.square().sum(axis=0, keepdims=True)
459
+ )
460
+ embed_ind = dist.argmax(axis=-1)
461
+ return embed_ind
462
+
463
+ def encode(self, hidden_states):
464
+ shape = hidden_states.shape
465
+ hidden_states = hidden_states.reshape((-1, shape[-1]))
466
+ embed_ind = self.quantize(hidden_states)
467
+ embed_ind = embed_ind.reshape(*shape[:-1])
468
+ return embed_ind
469
+
470
+ def decode(self, embed_ind):
471
+ return self.embed[embed_ind]
472
+
473
+
474
+ class EncodecVectorQuantization(nn.Module):
475
+ """
476
+ Vector quantization implementation. Currently supports only euclidean distance.
477
+ """
478
+
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.codebook = EncodecEuclideanCodebook(config)
482
+
483
+ def encode(self, hidden_states):
484
+ return self.codebook.encode(hidden_states)
485
+
486
+ def decode(self, embed_ind):
487
+ return self.codebook.decode(embed_ind)
488
+
489
+
490
+ class EncodecResidualVectorQuantizer(nn.Module):
491
+ """Residual Vector Quantizer."""
492
+
493
+ def __init__(self, config):
494
+ super().__init__()
495
+ self.codebook_size = config.codebook_size
496
+
497
+ hop_length = np.prod(config.upsampling_ratios)
498
+ self.frame_rate = math.ceil(config.sampling_rate / hop_length)
499
+ self.num_quantizers = int(
500
+ 1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
501
+ )
502
+ self.layers = [
503
+ EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
504
+ ]
505
+
506
+ def get_num_quantizers_for_bandwidth(
507
+ self, bandwidth: Optional[float] = None
508
+ ) -> int:
509
+ """Return num_quantizers based on specified target bandwidth."""
510
+ bw_per_q = math.log2(self.codebook_size) * self.frame_rate
511
+ num_quantizers = self.num_quantizers
512
+ if bandwidth is not None and bandwidth > 0.0:
513
+ num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
514
+ return num_quantizers
515
+
516
+     def encode(
+         self, embeddings: mx.array, bandwidth: Optional[float] = None
+     ) -> mx.array:
+         """
+         Encode a given input array with the specified frame rate at the given
+         bandwidth. The RVQ encode method sets the appropriate number of
+         quantizers to use and returns indices for each quantizer.
+         """
+         num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
+         residual = embeddings
+         all_indices = []
+         for layer in self.layers[:num_quantizers]:
+             indices = layer.encode(residual)
+             quantized = layer.decode(indices)
+             residual = residual - quantized
+             all_indices.append(indices)
+         out_indices = mx.stack(all_indices, axis=1)
+         return out_indices
+
+     def decode(self, codes: mx.array) -> mx.array:
+         """Decode the given codes to the quantized representation."""
+         quantized_out = None
+         for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
+             layer = self.layers[i]
+             quantized = layer.decode(indices.squeeze(1))
+             if quantized_out is None:
+                 quantized_out = quantized
+             else:
+                 quantized_out = quantized + quantized_out
+         return quantized_out
+
+
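The residual loop in encode is the heart of RVQ: each stage quantizes what the previous stages missed, and decode simply sums the per-stage codewords. A toy NumPy sketch of the same loop, with random codebooks and invented shapes, showing the reconstruction error shrinking stage by stage:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4,))                 # one toy embedding
codebooks = rng.normal(size=(8, 256, 4))  # 8 stages, 256 codewords each

residual, decoded = x, np.zeros_like(x)
for embed in codebooks:
    idx = np.linalg.norm(residual - embed, axis=-1).argmin()  # nearest codeword
    residual = residual - embed[idx]                          # pass the leftover on
    decoded = decoded + embed[idx]                            # running reconstruction
    print(np.linalg.norm(x - decoded))    # typically decreases with each stage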
548
+ class Encodec(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.encoder = EncodecEncoder(self.config)
+         self.decoder = EncodecDecoder(self.config)
+         self.quantizer = EncodecResidualVectorQuantizer(self.config)
+
+     def _encode_frame(
+         self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+         """
+         Encodes the given input using the underlying VQVAE.
+         """
+         length = input_values.shape[1]
+         duration = length / self.config.sampling_rate
+
+         if (
+             self.config.chunk_length_s is not None
+             and duration > 1e-5 + self.config.chunk_length_s
+         ):
+             raise RuntimeError(
+                 f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
+             )
+
+         scale = None
+         if self.config.normalize:
+             # Zero out padded samples before computing the normalization scale
+             input_values = input_values * padding_mask[..., None]
+             mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
+             scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
+             input_values = input_values / scale
+
+         embeddings = self.encoder(input_values)
+         codes = self.quantizer.encode(embeddings, bandwidth)
+         return codes, scale
+
585
+     def encode(
+         self,
+         input_values: mx.array,
+         padding_mask: Optional[mx.array] = None,
+         bandwidth: Optional[float] = None,
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+         """
+         Encodes the input audio waveform into discrete codes.
+
+         Args:
+             input_values (mx.array): The input audio waveform with shape
+                 ``(batch_size, sequence_length, channels)``.
+             padding_mask (mx.array): Padding mask used to pad the ``input_values``.
+             bandwidth (float, optional): The target bandwidth. Must be one of
+                 ``config.target_bandwidths``. If ``None``, uses the smallest
+                 possible bandwidth. The bandwidth is expressed in kbps, so a
+                 6 kbps (6000 bps) target is passed as ``bandwidth == 6.0``.
+
+         Returns:
+             A list of frames containing the discrete encoded codes for the
+             input audio waveform, along with rescaling factors for each chunk
+             when ``config.normalize == True``. Each frame is a tuple ``(codebook,
+             scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
+             frames)``.
+         """
+         if bandwidth is None:
+             bandwidth = self.config.target_bandwidths[0]
+         if bandwidth not in self.config.target_bandwidths:
+             raise ValueError(
+                 f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
+             )
+
+         _, input_length, channels = input_values.shape
+
+         if channels < 1 or channels > 2:
+             raise ValueError(
+                 f"Number of audio channels must be 1 or 2, but got {channels}"
+             )
+
+         chunk_length = self.chunk_length
+         if chunk_length is None:
+             chunk_length = input_length
+             stride = input_length
+         else:
+             stride = self.chunk_stride
+
+         if padding_mask is None:
+             padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
+         encoded_frames = []
+         scales = []
+
+         step = chunk_length - stride
+         if (input_length % stride) != step:
+             raise ValueError(
+                 "The input length is not properly padded for batched chunked encoding. Make sure to pad the input correctly."
+             )
+
+         for offset in range(0, input_length - step, stride):
+             mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
+             frame = input_values[:, offset : offset + chunk_length]
+             encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
+             encoded_frames.append(encoded_frame)
+             scales.append(scale)
+
+         encoded_frames = mx.stack(encoded_frames)
+
+         return (encoded_frames, scales)
+
654
+     @staticmethod
+     def _linear_overlap_add(frames: List[mx.array], stride: int):
+         if len(frames) == 0:
+             raise ValueError("`frames` cannot be an empty list.")
+
+         dtype = frames[0].dtype
+         N, frame_length, C = frames[0].shape
+         total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
+
+         time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+         weight = 0.5 - (time_vec - 0.5).abs()
+
+         weight = weight[:, None]
+         sum_weight = mx.zeros((total_size, 1), dtype=dtype)
+         out = mx.zeros((N, total_size, C), dtype=dtype)
+         offset = 0
+
+         for frame in frames:
+             frame_length = frame.shape[1]
+             out[:, offset : offset + frame_length] += weight[:frame_length] * frame
+             sum_weight[offset : offset + frame_length] += weight[:frame_length]
+             offset += stride
+
+         return out / sum_weight
+
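The overlap-add above weights each frame with a triangular window that peaks mid-frame, then divides by the summed weights so overlapping regions cross-fade instead of doubling in amplitude. A quick standalone look at the weight vector, with a made-up frame length:

import numpy as np

frame_length = 8   # made-up value
t = np.linspace(0, 1, frame_length + 2)[1:-1]
weight = 0.5 - np.abs(t - 0.5)
print(weight.round(3))  # ramps up toward 0.5 mid-frame, then back down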
679
+     def _decode_frame(
+         self, codes: mx.array, scale: Optional[mx.array] = None
+     ) -> mx.array:
+         embeddings = self.quantizer.decode(codes)
+         outputs = self.decoder(embeddings)
+         if scale is not None:
+             outputs = outputs * scale
+         return outputs
+
+     @property
+     def channels(self):
+         return self.config.audio_channels
+
+     @property
+     def sampling_rate(self):
+         return self.config.sampling_rate
+
+     @property
+     def chunk_length(self):
+         if self.config.chunk_length_s is None:
+             return None
+         else:
+             return int(self.config.chunk_length_s * self.config.sampling_rate)
+
+     @property
+     def chunk_stride(self):
+         if self.config.chunk_length_s is None or self.config.overlap is None:
+             return None
+         else:
+             return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
+
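For chunked processing, chunk_length and chunk_stride follow directly from the config, as the two properties above show. A sketch of the arithmetic with illustrative values (chunk_length_s = 1.0 and overlap = 0.01 at 48 kHz, the kind of settings used by the chunked 48 kHz EnCodec variant):

chunk_length_s, overlap, sampling_rate = 1.0, 0.01, 48000  # illustrative

chunk_length = int(chunk_length_s * sampling_rate)          # 48000 samples
chunk_stride = max(1, int((1.0 - overlap) * chunk_length))  # 47520 samples
print(chunk_length, chunk_stride)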
710
+     @classmethod
+     def from_pretrained(cls, path_or_repo: str):
+         """
+         Load the model and audio preprocessor.
+         """
+         path = Path(path_or_repo)
+         if not path.exists():
+             path = Path(
+                 snapshot_download(
+                     repo_id=path_or_repo,
+                     allow_patterns=["*.json", "*.safetensors", "*.model"],
+                 )
+             )
+
+         with open(path / "config.json", "r") as f:
+             config = json.load(f)
+
+         filtered_config = filter_dataclass_fields(config, EncodecConfig)
+         config = EncodecConfig(**filtered_config)
+         model = cls(config)
+         model.load_weights(str(path / "model.safetensors"))
+         processor = functools.partial(
+             preprocess_audio,
+             sampling_rate=config.sampling_rate,
+             chunk_length=model.chunk_length,
+             chunk_stride=model.chunk_stride,
+         )
+         mx.eval(model)
+         return model, processor
+
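A minimal usage sketch for from_pretrained; the repo id below is a placeholder, and the returned processor is just the preprocess_audio partial built above, so its exact behavior depends on that helper:

import mlx.core as mx

# Hypothetical repo id -- substitute a real EnCodec conversion
model, processor = Encodec.from_pretrained("some-org/encodec-24khz-mlx")

print(model.sampling_rate, model.channels, model.chunk_length)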
740
+     def decode(
+         self,
+         audio_codes: mx.array,
+         audio_scales: Union[mx.array, List[mx.array]],
+         padding_mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         """
+         Decodes the given frames into an output audio waveform.
+
+         Note that the output might be a bit bigger than the input. In that
+         case, any extra steps at the end should be trimmed.
+
+         Args:
+             audio_codes (mx.array): Discrete code embeddings of shape
+                 ``(batch_size, nb_chunks, chunk_length)``.
+             audio_scales (mx.array): Scaling factor for each input.
+             padding_mask (mx.array): Padding mask.
+         """
+         chunk_length = self.chunk_length
+         if chunk_length is None:
+             if audio_codes.shape[1] != 1:
+                 raise ValueError(f"Expected one frame, got {audio_codes.shape[1]}")
+             audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
+         else:
+             decoded_frames = []
+
+             for frame, scale in zip(audio_codes, audio_scales):
+                 frames = self._decode_frame(frame, scale)
+                 decoded_frames.append(frames)
+
+             audio_values = self._linear_overlap_add(
+                 decoded_frames, self.chunk_stride or 1
+             )
+
+         # Truncate based on the padding mask
+         if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
+             audio_values = audio_values[:, : padding_mask.shape[1]]
+         return audio_values
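Putting it together, a hedged end-to-end sketch: encode one second of silence and decode it back. The repo id is again a placeholder, the shapes assume the channels-last (batch, samples, channels) layout the code above unpacks, and encode falls back to the smallest target bandwidth when none is given:

import mlx.core as mx

model, processor = Encodec.from_pretrained("some-org/encodec-24khz-mlx")  # placeholder id

audio = mx.zeros((1, model.sampling_rate, model.channels))  # 1 s of silence
codes, scales = model.encode(audio)                         # smallest bandwidth
recon = model.decode(codes, scales)
print(codes.shape, recon.shape)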