nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,757 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from collections import UserDict
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from typing import Any, List, Optional, Union
8
+
9
+ import mlx.core as mx
10
+ import numpy as np
11
+
12
+ from mlx_audio.tts.utils import get_model_path
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class TensorType(Enum):
18
+ MX = "mx"
19
+ NP = "np"
20
+
21
+
22
+ class BatchFeature(UserDict):
23
+ def __init__(
24
+ self,
25
+ data=None,
26
+ input_values: Any = None,
27
+ attention_mask: Any = None,
28
+ tensor_type: Union[str, TensorType] = TensorType.MX,
29
+ **kwargs,
30
+ ):
31
+ super().__init__()
32
+ if data:
33
+ self.data.update(data)
34
+
35
+ _input_values_key = "input_values"
36
+ _attention_mask_key = "attention_mask"
37
+
38
+ if input_values is not None:
39
+ # Ensure input_values is a list of items
40
+ if not (
41
+ isinstance(input_values, list)
42
+ and (
43
+ not input_values
44
+ or isinstance(input_values[0], (np.ndarray, mx.array, list, tuple))
45
+ )
46
+ ):
47
+ self.data[_input_values_key] = [input_values]
48
+ else:
49
+ self.data[_input_values_key] = input_values
50
+
51
+ if attention_mask is not None:
52
+ # Ensure attention_mask is a list of items
53
+ if not (
54
+ isinstance(attention_mask, list)
55
+ and (
56
+ not attention_mask
57
+ or isinstance(
58
+ attention_mask[0],
59
+ (np.ndarray, mx.array, list, tuple, type(None)),
60
+ )
61
+ )
62
+ ):
63
+ self.data[_attention_mask_key] = [attention_mask]
64
+ else:
65
+ self.data[_attention_mask_key] = attention_mask
66
+
67
+ if isinstance(tensor_type, str):
68
+ self.tensor_type = TensorType(tensor_type)
69
+ else:
70
+ self.tensor_type = tensor_type
71
+
72
+ # Update with any other kwargs passed
73
+ self.data.update(kwargs)
74
+
75
+
76
+ class PaddingStrategy(Enum):
77
+ LONGEST = "longest"
78
+ MAX_LENGTH = "max_length"
79
+ DO_NOT_PAD = "do_not_pad"
80
+
81
+
82
+ def load_json(path: os.PathLike) -> dict[str, Any]:
83
+ try:
84
+ with open(path, "r") as f:
85
+ return json.load(f)
86
+ except Exception as e:
87
+ raise ValueError(f"Error loading JSON file {path}: {e}")
88
+
89
+
90
+ class Wav2Vec2FeatureExtractor:
91
+ r"""
92
+ Constructs a Wav2Vec2 feature extractor.
93
+
94
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
95
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
96
+
97
+ Args:
98
+ feature_size (`int`, *optional*, defaults to 1):
99
+ The feature dimension of the extracted features.
100
+ sampling_rate (`int`, *optional*, defaults to 16000):
101
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
102
+ padding_value (`float`, *optional*, defaults to 0.0):
103
+ The value that is used to fill the padding values.
104
+ do_normalize (`bool`, *optional*, defaults to `True`):
105
+ Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
106
+ improve the performance for some models, *e.g.*,
107
+ [wav2vec2-lv60](https://huggingface.co/models?search=lv60).
108
+ return_attention_mask (`bool`, *optional*, defaults to `False`):
109
+ Whether or not [`~Wav2Vec2FeatureExtractor.__call__`] should return `attention_mask`.
110
+
111
+ <Tip>
112
+
113
+ Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
114
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
115
+ `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask`
116
+ should be passed.
117
+
118
+ For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
119
+ [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be
120
+ passed for batched inference.
121
+
122
+ </Tip>"""
123
+
124
+ model_input_names = ["input_values", "attention_mask"]
125
+
126
+ def __init__(
127
+ self,
128
+ feature_size=1,
129
+ sampling_rate=16000,
130
+ padding_value=0.0,
131
+ return_attention_mask=False,
132
+ do_normalize=True,
133
+ **kwargs,
134
+ ):
135
+ self.feature_size = feature_size
136
+ self.sampling_rate = sampling_rate
137
+ self.padding_value = padding_value
138
+ self.padding_side = kwargs.get("padding_side", "right")
139
+ self.return_attention_mask = return_attention_mask
140
+ self.do_normalize = do_normalize
141
+
142
+ @staticmethod
143
+ def zero_mean_unit_var_norm(
144
+ input_values: List[np.ndarray],
145
+ attention_mask: List[np.ndarray],
146
+ padding_value: float = 0.0,
147
+ ) -> List[np.ndarray]:
148
+ """
149
+ Every array in the list is normalized to have zero mean and unit variance
150
+ """
151
+ if attention_mask is not None:
152
+ attention_mask = np.array(attention_mask, np.int32)
153
+ normed_input_values = []
154
+
155
+ for vector, length in zip(input_values, attention_mask.sum(-1)):
156
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(
157
+ vector[:length].var() + 1e-7
158
+ )
159
+ if length < normed_slice.shape[0]:
160
+ normed_slice[length:] = padding_value
161
+
162
+ normed_input_values.append(normed_slice)
163
+ else:
164
+ normed_input_values = [
165
+ (x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values
166
+ ]
167
+
168
+ return normed_input_values
169
+
170
+ def _truncate(
171
+ self,
172
+ processed_features: Union[dict[str, np.ndarray], BatchFeature],
173
+ max_length: Optional[int] = None,
174
+ pad_to_multiple_of: Optional[int] = None,
175
+ truncation: Optional[bool] = None,
176
+ ):
177
+ """
178
+ Truncate inputs to predefined length or max length in the batch
179
+
180
+ Args:
181
+ processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
182
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
183
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
184
+ max_length (`int`, *optional*):
185
+ maximum length of the returned list and optionally padding length (see below)
186
+ pad_to_multiple_of (`int`, *optional*) :
187
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
188
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
189
+ which benefit from having sequence lengths be a multiple of 128.
190
+ truncation (`bool`, *optional*):
191
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
192
+ """
193
+ if not truncation:
194
+ return processed_features
195
+ elif truncation and max_length is None:
196
+ raise ValueError(
197
+ "When setting ``truncation=True``, make sure that ``max_length`` is defined."
198
+ )
199
+
200
+ required_input = processed_features[self.model_input_names[0]]
201
+
202
+ # find `max_length` that fits `pad_to_multiple_of`
203
+ if (
204
+ max_length is not None
205
+ and pad_to_multiple_of is not None
206
+ and (max_length % pad_to_multiple_of != 0)
207
+ ):
208
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
209
+
210
+ needs_to_be_truncated = len(required_input) > max_length
211
+
212
+ if needs_to_be_truncated:
213
+ processed_features[self.model_input_names[0]] = processed_features[
214
+ self.model_input_names[0]
215
+ ][:max_length]
216
+ if "attention_mask" in processed_features:
217
+ processed_features["attention_mask"] = processed_features[
218
+ "attention_mask"
219
+ ][:max_length]
220
+
221
+ return processed_features
222
+
223
+ def _get_padding_strategies(self, padding=False, max_length=None):
224
+ """
225
+ Find the correct padding strategy
226
+ """
227
+
228
+ # Get padding strategy
229
+ if padding is not False:
230
+ if padding is True:
231
+ padding_strategy = (
232
+ PaddingStrategy.LONGEST
233
+ ) # Default to pad to the longest sequence in the batch
234
+ elif not isinstance(padding, PaddingStrategy):
235
+ padding_strategy = PaddingStrategy(padding)
236
+ elif isinstance(padding, PaddingStrategy):
237
+ padding_strategy = padding
238
+ else:
239
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
240
+
241
+ # Set max length if needed
242
+ if max_length is None:
243
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
244
+ raise ValueError(
245
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
246
+ )
247
+
248
+ # Test if we have a padding value
249
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
250
+ self.padding_value is None
251
+ ):
252
+ raise ValueError(
253
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
254
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
255
+ )
256
+
257
+ return padding_strategy
258
+
259
+ def pad(
260
+ self,
261
+ processed_features: Union[
262
+ BatchFeature,
263
+ list[BatchFeature],
264
+ dict[str, BatchFeature],
265
+ dict[str, list[BatchFeature]],
266
+ list[dict[str, BatchFeature]],
267
+ ],
268
+ padding: Union[bool, str, PaddingStrategy] = True,
269
+ max_length: Optional[int] = None,
270
+ truncation: bool = False,
271
+ pad_to_multiple_of: Optional[int] = None,
272
+ return_attention_mask: Optional[bool] = None,
273
+ return_tensors: Optional[Union[str, Any]] = None,
274
+ ) -> BatchFeature:
275
+ """
276
+ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
277
+ max sequence length in the batch.
278
+
279
+ Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`,
280
+ `self.padding_value`)
281
+
282
+ <Tip>
283
+
284
+ If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
285
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
286
+ PyTorch tensors, you will lose the specific device of your tensors however.
287
+
288
+ </Tip>
289
+
290
+ Args:
291
+ processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`):
292
+ Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of
293
+ input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str,
294
+ List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
295
+ collate function.
296
+
297
+ Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
298
+ see the note above for the return type.
299
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
300
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
301
+ index) among:
302
+
303
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
304
+ sequence if provided).
305
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
306
+ acceptable input length for the model if that argument is not provided.
307
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
308
+ lengths).
309
+ max_length (`int`, *optional*):
310
+ Maximum length of the returned list and optionally padding length (see above).
311
+ truncation (`bool`):
312
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
313
+ pad_to_multiple_of (`int`, *optional*):
314
+ If set will pad the sequence to a multiple of the provided value.
315
+
316
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
317
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
318
+ return_attention_mask (`bool`, *optional*):
319
+ Whether to return the attention mask. If left to the default, will return the attention mask according
320
+ to the specific feature_extractor's default.
321
+
322
+ [What are attention masks?](../glossary#attention-mask)
323
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
324
+ If set, will return tensors instead of list of python integers. Acceptable values are:
325
+
326
+ - `'mx'`: Return MXNet `mx.ndarray` objects.
327
+ - `'np'`: Return Numpy `np.ndarray` objects.
328
+ """
329
+ # If we have a list of dicts, let's convert it in a dict of lists
330
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
331
+ if isinstance(processed_features, (list, tuple)) and isinstance(
332
+ processed_features[0], (dict, BatchFeature)
333
+ ):
334
+ processed_features = {
335
+ key: [example[key] for example in processed_features]
336
+ for key in processed_features[0].keys()
337
+ }
338
+
339
+ # The model's main input name, usually `input_values`, has be passed for padding
340
+ if self.model_input_names[0] not in processed_features:
341
+ raise ValueError(
342
+ "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
343
+ f" to this method that includes {self.model_input_names[0]}, but you provided"
344
+ f" {list(processed_features.keys())}"
345
+ )
346
+
347
+ required_input = processed_features[self.model_input_names[0]]
348
+ return_attention_mask = (
349
+ return_attention_mask
350
+ if return_attention_mask is not None
351
+ else self.return_attention_mask
352
+ )
353
+
354
+ if len(required_input) == 0:
355
+ if return_attention_mask:
356
+ processed_features["attention_mask"] = []
357
+ return processed_features
358
+
359
+ # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
360
+ # and rebuild them afterwards if no return_tensors is specified
361
+ # Note that we lose the specific device the tensor may be on for PyTorch
362
+
363
+ first_element = required_input[0]
364
+ if isinstance(first_element, (list, tuple)):
365
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
366
+ index = 0
367
+ while len(required_input[index]) == 0:
368
+ index += 1
369
+ if index < len(required_input):
370
+ first_element = required_input[index][0]
371
+
372
+ if return_tensors is None:
373
+ if isinstance(first_element, mx.array):
374
+ return_tensors = "mx"
375
+ elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
376
+ return_tensors = "np"
377
+ else:
378
+ raise ValueError(
379
+ f"type of {first_element} unknown: {type(first_element)}. "
380
+ "Should be one of a python, numpy, pytorch or tensorflow object."
381
+ )
382
+
383
+ for key, value in processed_features.items():
384
+ if isinstance(value[0], (int, float)):
385
+ processed_features[key] = np.array(value)
386
+ else:
387
+ processed_features[key] = [np.array(v) for v in value]
388
+
389
+ # Convert padding_strategy in PaddingStrategy
390
+ padding_strategy = self._get_padding_strategies(
391
+ padding=padding, max_length=max_length
392
+ )
393
+
394
+ required_input = processed_features[self.model_input_names[0]]
395
+
396
+ batch_size = len(required_input)
397
+ if not all(len(v) == batch_size for v in processed_features.values()):
398
+ raise ValueError(
399
+ "Some items in the output dictionary have a different batch size than others."
400
+ )
401
+
402
+ truncated_inputs = []
403
+ for i in range(batch_size):
404
+ inputs = {k: v[i] for k, v in processed_features.items()}
405
+ # truncation
406
+ inputs_slice = self._truncate(
407
+ inputs,
408
+ max_length=max_length,
409
+ pad_to_multiple_of=pad_to_multiple_of,
410
+ truncation=truncation,
411
+ )
412
+ truncated_inputs.append(inputs_slice)
413
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            # make sure that `max_length` cannot be longer than the longest truncated length
+            max_length = max(
+                len(input_slice[self.model_input_names[0]])
+                for input_slice in truncated_inputs
+            )
+            padding_strategy = PaddingStrategy.MAX_LENGTH
+
+        batch_outputs = {}
+        for i in range(batch_size):
+            # padding
+            outputs = self._pad(
+                truncated_inputs[i],
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_attention_mask=return_attention_mask,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                if value.dtype is np.dtype(np.float64):
+                    value = value.astype(np.float32)
+                batch_outputs[key].append(value)
+
+        return BatchFeature(batch_outputs, tensor_type=return_tensors)
+
+    def _pad(
+        self,
+        processed_features: Union[dict[str, np.ndarray], BatchFeature],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+ """
451
+ Pad inputs (on left/right and up to predefined length or max length in the batch)
452
+
453
+ Args:
454
+ processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`):
455
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
456
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
457
+ max_length (`int`, *optional*):
458
+ Maximum length of the returned list and optionally padding length (see below)
459
+ padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`):
460
+ PaddingStrategy to use for padding.
461
+
462
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
463
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
464
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
465
+ The feature_extractor padding sides are defined in self.padding_side:
466
+
467
+ - 'left': pads on the left of the sequences
468
+ - 'right': pads on the right of the sequences
469
+ pad_to_multiple_of (`int`, *optional*):
470
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
471
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
472
+ which benefit from having sequence lengths be a multiple of 128.
473
+ return_attention_mask (`bool`, *optional*):
474
+ Set to False to avoid returning attention mask (default: set to model specifics)
475
+ """
476
+        required_input = processed_features[self.model_input_names[0]]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(required_input)
+
+        if (
+            max_length is not None
+            and pad_to_multiple_of is not None
+            and (max_length % pad_to_multiple_of != 0)
+        ):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        needs_to_be_padded = (
+            padding_strategy != PaddingStrategy.DO_NOT_PAD
+            and len(required_input) < max_length
+        )
+
+        if return_attention_mask and "attention_mask" not in processed_features:
+            processed_features["attention_mask"] = np.ones(
+                len(required_input), dtype=np.int32
+            )
+
+        if needs_to_be_padded:
+            difference = max_length - len(required_input)
+            if self.padding_side == "right":
+                if return_attention_mask:
+                    processed_features["attention_mask"] = np.pad(
+                        processed_features["attention_mask"], (0, difference)
+                    )
+                padding_shape = (
+                    ((0, difference), (0, 0))
+                    if self.feature_size > 1
+                    else (0, difference)
+                )
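+                # np.pad pad_width is (before, after) per axis: append `difference`
+                # frames at the end of the time axis and, for multi-dimensional
+                # features, leave the feature axis unpadded.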
+                processed_features[self.model_input_names[0]] = np.pad(
+                    required_input,
+                    padding_shape,
+                    "constant",
+                    constant_values=self.padding_value,
+                )
+            elif self.padding_side == "left":
+                if return_attention_mask:
+                    processed_features["attention_mask"] = np.pad(
+                        processed_features["attention_mask"], (difference, 0)
+                    )
+                padding_shape = (
+                    ((difference, 0), (0, 0))
+                    if self.feature_size > 1
+                    else (difference, 0)
+                )
+                processed_features[self.model_input_names[0]] = np.pad(
+                    required_input,
+                    padding_shape,
+                    "constant",
+                    constant_values=self.padding_value,
+                )
+            else:
+                raise ValueError(f"Invalid padding side: {self.padding_side}")
+
+        return processed_features
+
+    def __call__(
+        self,
+        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+        padding: Union[bool, str, PaddingStrategy] = False,
+        max_length: Optional[int] = None,
+        truncation: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_tensors: Optional[Union[str, Any]] = None,
+        sampling_rate: Optional[int] = None,
+        **kwargs,
+    ) -> BatchFeature:
+ """
550
+ Main method to featurize and prepare for the model one or several sequence(s).
551
+
552
+ Args:
553
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
554
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
555
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
556
+ stereo, i.e. single float per timestep.
557
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
558
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
559
+ index) among:
560
+
561
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
562
+ sequence if provided).
563
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
564
+ acceptable input length for the model if that argument is not provided.
565
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
566
+ lengths).
567
+ max_length (`int`, *optional*):
568
+ Maximum length of the returned list and optionally padding length (see above).
569
+ truncation (`bool`):
570
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
571
+ pad_to_multiple_of (`int`, *optional*):
572
+ If set will pad the sequence to a multiple of the provided value.
573
+
574
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
575
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
576
+ return_attention_mask (`bool`, *optional*):
577
+ Whether to return the attention mask. If left to the default, will return the attention mask according
578
+ to the specific feature_extractor's default.
579
+
580
+ [What are attention masks?](../glossary#attention-mask)
581
+
582
+ <Tip>
583
+
584
+ Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
585
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
586
+ `attention_mask`. For such models, `input_values` should simply be padded with 0 and no
587
+ `attention_mask` should be passed.
588
+
589
+ For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
590
+ [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should
591
+ be passed for batched inference.
592
+
593
+ </Tip>
594
+
595
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
596
+ If set, will return tensors instead of list of python integers. Acceptable values are:
597
+
598
+ - `'mx'`: Return MXNet `mx.ndarray` objects.
599
+ - `'np'`: Return Numpy `np.ndarray` objects.
600
+ sampling_rate (`int`, *optional*):
601
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
602
+ `sampling_rate` at the forward call to prevent silent errors.
603
+ padding_value (`float`, *optional*, defaults to 0.0):
604
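+
+        Example (a minimal sketch; the checkpoint name, shapes and values are
+        illustrative, not guaranteed by this package):
+
+        >>> import numpy as np
+        >>> fe = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        >>> audio = np.random.uniform(-1, 1, 16000).astype(np.float32)  # 1 s of mono audio
+        >>> inputs = fe(audio, sampling_rate=16000, return_tensors="mx")
+        >>> shape = inputs["input_values"].shape  # (batch, num_samples)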
+ """
605
+
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
+        is_batched_numpy = (
+            isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        )
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(
+                f"Only mono-channel audio is supported for input to {self}"
+            )
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple))
+            and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+        )
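+        # e.g. a (batch, num_samples) array or a list of 1-D arrays/lists counts as
+        # batched input; a single 1-D array or a flat list of floats is wrapped below.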
+
+        # always return batch
+        if not is_batched:
+            raw_speech = [raw_speech]
+
+        # convert into correct format for padding
+        encoded_inputs = BatchFeature({"input_values": raw_speech})
+
+        padded_inputs = self.pad(
+            encoded_inputs,
+            padding=padding,
+            max_length=max_length,
+            truncation=truncation,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_attention_mask=return_attention_mask,
+        )
+
+        # convert input values to correct format
+        input_values = padded_inputs["input_values"]
+        if not isinstance(input_values[0], np.ndarray):
+            padded_inputs["input_values"] = [
+                np.asarray(array, dtype=np.float32) for array in input_values
+            ]
+        elif (
+            not isinstance(input_values, np.ndarray)
+            and isinstance(input_values[0], np.ndarray)
+            and input_values[0].dtype is np.dtype(np.float64)
+        ):
+            padded_inputs["input_values"] = [
+                array.astype(np.float32) for array in input_values
+            ]
+        elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(
+            np.float64
+        ):
+            padded_inputs["input_values"] = input_values.astype(np.float32)
+
+        # convert attention_mask to correct format
+        attention_mask = padded_inputs.get("attention_mask")
+        if attention_mask is not None:
+            padded_inputs["attention_mask"] = [
+                np.asarray(array, dtype=np.int32) for array in attention_mask
+            ]
+
+        # zero-mean and unit-variance normalization
+        if self.do_normalize:
+            attention_mask = (
+                attention_mask
+                if self._get_padding_strategies(padding, max_length=max_length)
+                is not PaddingStrategy.DO_NOT_PAD
+                else None
+            )
+            padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
+                padded_inputs["input_values"],
+                attention_mask=attention_mask,
+                padding_value=self.padding_value,
+            )
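+            # `zero_mean_unit_var_norm` is expected to implement the usual
+            # wav2vec2-style per-sequence normalization, roughly
+            # (x - mean(x)) / sqrt(var(x) + 1e-7) over the unpadded samples
+            # (an assumption based on the upstream transformers implementation).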
+
+        if return_tensors is not None:
+            for k, v in padded_inputs.items():
+                if return_tensors == "mx":
+                    # Convert to numpy array first if it's not already one
+                    if isinstance(v, list):
+                        v = np.array(v)
+                    padded_inputs[k] = mx.array(v)
+                elif return_tensors == "np":
+                    padded_inputs[k] = np.array(v)
+                else:
+                    raise ValueError(f"Invalid return_tensors: {return_tensors}")
+        return padded_inputs
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        file_name: str = "preprocessor_config.json",
+        revision: str = "main",
+        **kwargs,
+    ):
+        if isinstance(pretrained_model_name_or_path, str):
+            path = get_model_path(pretrained_model_name_or_path)
+        else:
+            path = pretrained_model_name_or_path
+
+        if not (path / file_name).exists():
+            raise FileNotFoundError(f"File {file_name} not found in {path}")
+
+        feature_extractor_dict = load_json(path / file_name)
+
+        return cls.from_dict(feature_extractor_dict, **kwargs)
+
+    @classmethod
+    def from_dict(
+        cls, feature_extractor_dict: dict[str, Any], **kwargs
+    ) -> "Wav2Vec2FeatureExtractor":
+        """
+        Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
+        parameters.
+
+        Args:
+            feature_extractor_dict (`Dict[str, Any]`):
+                Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
+                retrieved from a pretrained checkpoint by leveraging the
+                [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
+            kwargs (`Dict[str, Any]`):
+                Additional parameters from which to initialize the feature extractor object.
+
+        Returns:
+            [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
+            parameters.
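+
+        Example (a minimal sketch; the config keys shown are illustrative, not a
+        required schema):
+
+        >>> cfg = {"feature_size": 1, "sampling_rate": 16000, "padding_value": 0.0}
+        >>> feature_extractor = Wav2Vec2FeatureExtractor.from_dict(cfg, do_normalize=True)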
+ """
740
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
741
+
742
+ # Update feature_extractor with kwargs if needed
743
+ to_remove = []
744
+ for key, value in kwargs.items():
745
+ if key in feature_extractor_dict:
746
+ feature_extractor_dict[key] = value
747
+ to_remove.append(key)
748
+ for key in to_remove:
749
+ kwargs.pop(key, None)
750
+
751
+ feature_extractor = cls(**feature_extractor_dict)
752
+
753
+ logger.info(f"Feature extractor {feature_extractor}")
754
+ if return_unused_kwargs:
755
+ return feature_extractor, kwargs
756
+ else:
757
+ return feature_extractor