nexaai 1.0.16rc13__cp310-cp310-macosx_15_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic; see the registry's release page for more details.

Files changed (557)
  1. nexaai/__init__.py +83 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +64 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +44 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +4 -0
  10. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/libnexa_bridge.dylib +0 -0
  13. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  17. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  20. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/ml.py +888 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  172. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  173. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  174. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  175. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  176. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  177. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  178. nexaai/binds/nexa_nexaml/libnexa-mm-process.dylib +0 -0
  179. nexaai/binds/nexa_nexaml/libnexa-sampling.dylib +0 -0
  180. nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
  181. nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
  182. nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
  183. nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
  184. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  185. nexaai/common.py +104 -0
  186. nexaai/cv.py +92 -0
  187. nexaai/cv_impl/__init__.py +0 -0
  188. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  189. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  190. nexaai/embedder.py +72 -0
  191. nexaai/embedder_impl/__init__.py +0 -0
  192. nexaai/embedder_impl/mlx_embedder_impl.py +116 -0
  193. nexaai/embedder_impl/pybind_embedder_impl.py +95 -0
  194. nexaai/image_gen.py +140 -0
  195. nexaai/image_gen_impl/__init__.py +0 -0
  196. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  197. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  198. nexaai/llm.py +96 -0
  199. nexaai/llm_impl/__init__.py +0 -0
  200. nexaai/llm_impl/mlx_llm_impl.py +269 -0
  201. nexaai/llm_impl/pybind_llm_impl.py +218 -0
  202. nexaai/log.py +92 -0
  203. nexaai/mlx_backend/asr/__init__.py +12 -0
  204. nexaai/mlx_backend/asr/interface.py +122 -0
  205. nexaai/mlx_backend/common/__init__.py +0 -0
  206. nexaai/mlx_backend/common/utils.py +25 -0
  207. nexaai/mlx_backend/cv/__init__.py +0 -0
  208. nexaai/mlx_backend/cv/generate.py +195 -0
  209. nexaai/mlx_backend/cv/interface.py +151 -0
  210. nexaai/mlx_backend/cv/main.py +81 -0
  211. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  212. nexaai/mlx_backend/embedding/__init__.py +0 -0
  213. nexaai/mlx_backend/embedding/generate.py +333 -0
  214. nexaai/mlx_backend/embedding/interface.py +617 -0
  215. nexaai/mlx_backend/embedding/main.py +173 -0
  216. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  217. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  218. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  219. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  220. nexaai/mlx_backend/image_gen/interface.py +82 -0
  221. nexaai/mlx_backend/image_gen/main.py +281 -0
  222. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  223. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  224. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  225. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  226. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  227. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  228. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  229. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  230. nexaai/mlx_backend/llm/__init__.py +0 -0
  231. nexaai/mlx_backend/llm/generate.py +149 -0
  232. nexaai/mlx_backend/llm/interface.py +764 -0
  233. nexaai/mlx_backend/llm/main.py +68 -0
  234. nexaai/mlx_backend/ml.py +888 -0
  235. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  236. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  247. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  248. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  272. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  273. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  274. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  275. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  276. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  277. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  278. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  279. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  280. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  281. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  282. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  283. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  284. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  286. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  288. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  289. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  290. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  291. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  292. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  293. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  294. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  295. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  296. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  297. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  305. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  306. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  307. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  308. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  309. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  310. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  378. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  379. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  380. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  381. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  382. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  383. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  384. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  385. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  386. nexaai/mlx_backend/profiling.py +239 -0
  387. nexaai/mlx_backend/rerank/__init__.py +0 -0
  388. nexaai/mlx_backend/rerank/generate.py +174 -0
  389. nexaai/mlx_backend/rerank/interface.py +287 -0
  390. nexaai/mlx_backend/rerank/main.py +127 -0
  391. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  392. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  393. nexaai/mlx_backend/sd/__init__.py +1 -0
  394. nexaai/mlx_backend/sd/interface.py +362 -0
  395. nexaai/mlx_backend/sd/main.py +286 -0
  396. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  397. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  398. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  399. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  400. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  401. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  402. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  403. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  404. nexaai/mlx_backend/tts/__init__.py +12 -0
  405. nexaai/mlx_backend/tts/interface.py +276 -0
  406. nexaai/mlx_backend/vlm/__init__.py +3 -0
  407. nexaai/mlx_backend/vlm/generate.py +572 -0
  408. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
  409. nexaai/mlx_backend/vlm/interface.py +415 -0
  410. nexaai/mlx_backend/vlm/main.py +316 -0
  411. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  412. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  413. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  414. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  415. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  416. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  417. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  418. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  419. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  420. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  421. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  422. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  423. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  424. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  425. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  426. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  427. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  429. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  430. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  431. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  432. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  433. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  434. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  435. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  436. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  437. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  438. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  439. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  440. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  441. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  443. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  444. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  445. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  447. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  448. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  449. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  450. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  451. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  452. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  453. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  454. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  455. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  457. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  458. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  459. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  460. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  461. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  462. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  463. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  464. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  465. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  466. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  467. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  468. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  469. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  470. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  471. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  475. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  476. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  477. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  478. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  479. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  480. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  481. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  482. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  483. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  484. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  485. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  486. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  487. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  488. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  489. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  490. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  492. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  493. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  494. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  496. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  497. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  498. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  499. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  500. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  501. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  502. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  503. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  504. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  505. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  506. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  507. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  508. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  509. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  510. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  511. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  512. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  513. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  514. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  515. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  522. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  523. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  524. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  525. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  526. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  527. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  528. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  529. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  530. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  531. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  532. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  533. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  534. nexaai/rerank.py +55 -0
  535. nexaai/rerank_impl/__init__.py +0 -0
  536. nexaai/rerank_impl/mlx_rerank_impl.py +92 -0
  537. nexaai/rerank_impl/pybind_rerank_impl.py +43 -0
  538. nexaai/runtime.py +68 -0
  539. nexaai/tts.py +74 -0
  540. nexaai/tts_impl/__init__.py +0 -0
  541. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  542. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  543. nexaai/utils/avatar_fetcher.py +104 -0
  544. nexaai/utils/decode.py +18 -0
  545. nexaai/utils/manifest_utils.py +324 -0
  546. nexaai/utils/model_manager.py +1353 -0
  547. nexaai/utils/model_types.py +47 -0
  548. nexaai/utils/progress_tracker.py +385 -0
  549. nexaai/utils/quantization_utils.py +245 -0
  550. nexaai/vlm.py +128 -0
  551. nexaai/vlm_impl/__init__.py +0 -0
  552. nexaai/vlm_impl/mlx_vlm_impl.py +258 -0
  553. nexaai/vlm_impl/pybind_vlm_impl.py +230 -0
  554. nexaai-1.0.16rc13.dist-info/METADATA +32 -0
  555. nexaai-1.0.16rc13.dist-info/RECORD +557 -0
  556. nexaai-1.0.16rc13.dist-info/WHEEL +5 -0
  557. nexaai-1.0.16rc13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,979 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import mlx.core as mx
5
+ import mlx.nn as nn
6
+ import numpy as np
7
+
8
+ from mlx_audio.utils import istft, stft
9
+
10
+ from ..base import check_array_shape
11
+ from ..interpolate import interpolate
12
+
13
+
14
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the per-side padding for a dilated 1D convolution.

    A dilated kernel reaches ``kernel_size * dilation - dilation`` samples
    beyond its first tap; half of that on each side keeps the output length
    equal to the input length for odd kernel sizes at stride 1. The halved
    value is truncated toward zero.
    """
    extent = kernel_size * dilation - dilation
    half_extent = extent / 2
    return int(half_extent)
16
+
17
+
18
def compute_norm(
    x: mx.array,
    p: int,
    dim: Optional[Union[int, List[int]]] = None,
    keepdim: bool = False,
) -> mx.array:
    """Return the L1 or L2 norm of ``x`` reduced over ``dim``.

    Args:
        x: Input array.
        p: Norm order; must be 1 or 2.
        dim: Axis or axes to reduce. ``None`` reduces over every axis.
        keepdim: Keep reduced axes as size-1 dimensions.

    Returns:
        The norm as an MLX array.

    Raises:
        ValueError: If ``p`` is not 1 or 2.
    """
    if p not in (1, 2):
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Normalise the axis specification to something mx.sum accepts.
    if dim is None:
        axes = tuple(range(x.ndim))
    elif isinstance(dim, int):
        axes = (dim,)
    else:
        axes = dim

    if p == 2:
        return mx.sqrt(mx.sum(x * x, axis=axes, keepdims=keepdim))
    return mx.sum(mx.abs(x), axis=axes, keepdims=keepdim)
51
+
52
+
53
def weight_norm(
    weight_v: mx.array, weight_g: mx.array, dim: Optional[int] = None
) -> mx.array:
    """Rebuild a weight from its weight-norm parametrisation: ``w = g * v / ||v||``.

    Args:
        weight_v: Direction tensor ``v``.
        weight_g: Magnitude tensor ``g`` (broadcast against the normalised ``v``).
        dim: Axis kept out of the L2 reduction. ``None`` or ``-1`` reduce over
            every axis; values below ``-1`` are wrapped by adding the rank.

    Returns:
        The reconstructed weight tensor.
    """
    rank = len(weight_v.shape)
    reduce_axes = list(range(rank))

    if dim is not None:
        if dim < -1:
            dim += rank
        # dim == -1 is deliberately treated as "reduce over everything".
        if dim != -1:
            reduce_axes.remove(dim)

    magnitude = compute_norm(weight_v, p=2, dim=reduce_axes, keepdim=True)
    # The small epsilon guards against division by a zero-norm direction.
    return weight_v / (magnitude + 1e-7) * weight_g
94
+
95
+
96
class ConvWeighted(nn.Module):
    """Convolution whose kernel is stored in weight-normalised form.

    The kernel is rebuilt on every call as ``g * v / ||v||`` (normalised over
    all axes except axis 0). The convolution primitive itself (for example
    ``mx.conv1d`` or ``mx.conv_transpose1d``) is passed in by the caller, so
    one set of parameters can drive either kind of convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        encode: bool = False,
    ):
        super().__init__()

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Magnitude g: one scalar per output channel.
        self.weight_g = mx.ones((out_channels, 1, 1))
        # Direction v, stored as (out_channels, kernel_size, in_channels).
        self.weight_v = mx.ones((out_channels, kernel_size, in_channels))

        if bias:
            # With encode=True the bias is sized by in_channels instead.
            self.bias = mx.zeros(in_channels if encode else out_channels)
        else:
            self.bias = None

    def __call__(self, x, conv):
        weight = weight_norm(self.weight_v, self.weight_g, dim=0)
        bias = None if self.bias is None else self.bias.reshape(1, 1, -1)

        def apply_conv(inp, kernel):
            out = conv(
                inp,
                kernel,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            if self.bias is not None:
                out = out + bias
            return out

        try:
            # Use the kernel as stored when its last axis already lines up
            # with the input's last axis (or for grouped convolutions);
            # otherwise fully transpose it.
            if x.shape[-1] == weight.shape[-1] or self.groups > 1:
                return apply_conv(x, weight)
            return apply_conv(x, weight.T)
        except Exception as e:
            print(f"Error: {e}")
            print(f"x.shape: {x.shape}, weight.shape: {weight.shape}")
            raise
171
+
172
+
173
class _InstanceNorm(nn.Module):
    """Base class for PyTorch-style (channels-first) instance normalisation.

    Mean and variance are computed separately for every sample and channel,
    over all axes except the batch axis (0) and the channel axis. The channel
    axis sits ``_get_no_batch_dim()`` axes from the end. Subclasses must
    implement ``_get_no_batch_dim`` and ``_check_input_dim``.

    Args:
        num_features: Number of channels C.
        eps: Added to the variance for numerical stability.
        momentum: Blend factor used when updating the running statistics.
        affine: If True, apply a per-channel scale (``weight``) and shift
            (``bias``) after normalisation.
        track_running_stats: If True, maintain ``running_mean`` /
            ``running_var`` and use them for normalisation outside training.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Initialize parameters
        if self.affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))
        else:
            self.running_mean = None
            self.running_var = None

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _get_no_batch_dim(self):
        raise NotImplementedError

    def _handle_no_batch_input(self, input):
        # Add a batch axis, normalise, then drop the batch axis again.
        expanded = mx.expand_dims(input, axis=0)
        result = self._apply_instance_norm(expanded)
        return mx.squeeze(result, axis=0)

    def _apply_instance_norm(self, input):
        # MLX has no instance_norm primitive, so it is implemented by hand.
        dims = list(range(input.ndim))
        feature_dim = dims[-self._get_no_batch_dim()]

        # Statistics are reduced over every axis except batch and channel.
        reduce_dims = [d for d in dims if d != 0 and d != feature_dim]

        if self.training or not self.track_running_stats:
            mean = mx.mean(input, axis=reduce_dims, keepdims=True)
            var = mx.var(input, axis=reduce_dims, keepdims=True)

            if self.track_running_stats and self.training:
                # `mean`/`var` keep singleton reduced axes, e.g. (N, C, 1).
                # After averaging over the batch they are (C, 1, ...), so we
                # flatten them to (C,) to match the running buffers. Without
                # the reshape, (C,) + (C, 1) would broadcast to (C, C).
                overall_mean = mx.mean(mean, axis=0).reshape(-1)
                overall_var = mx.mean(var, axis=0).reshape(-1)

                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * overall_mean
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * overall_var
        else:
            # Evaluation with tracked statistics: broadcast them along the
            # channel axis.
            stats_shape = [1] * input.ndim
            stats_shape[feature_dim] = self.num_features
            mean = mx.reshape(self.running_mean, stats_shape)
            var = mx.reshape(self.running_var, stats_shape)

        x_norm = (input - mean) / mx.sqrt(var + self.eps)

        if not self.affine:
            return x_norm

        param_shape = [1] * input.ndim
        param_shape[feature_dim] = self.num_features
        weight = mx.reshape(self.weight, param_shape)
        bias = mx.reshape(self.bias, param_shape)
        return x_norm * weight + bias

    def __call__(self, input):
        """Validate ``input`` and normalise it (batched or unbatched).

        Raises:
            ValueError: If ``affine`` is set and the channel size does not
                match ``num_features`` (subclasses may raise on bad rank).
        """
        self._check_input_dim(input)

        feature_dim = input.ndim - self._get_no_batch_dim()
        if input.shape[feature_dim] != self.num_features:
            if self.affine:
                raise ValueError(
                    f"expected input's size at dim={feature_dim} to match num_features"
                    f" ({self.num_features}), but got: {input.shape[feature_dim]}."
                )
            # Route through `warnings` (filterable, de-duplicated per call
            # site) rather than printing on every forward pass.
            import warnings

            warnings.warn(
                f"input's size at dim={feature_dim} does not match num_features. "
                "You can silence this warning by not passing in num_features, "
                "which is not used because affine=False",
                stacklevel=2,
            )

        if input.ndim == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

        return self._apply_instance_norm(input)
291
+
292
+
293
class InstanceNorm1d(_InstanceNorm):
    """Instance normalisation for channels-first 1D signals.

    Accepts unbatched ``(C, L)`` or batched ``(N, C, L)`` input and returns an
    array of the same shape. Each sample and channel is normalised over ``L``
    (Ulyanov et al., "Instance Normalization: The Missing Ingredient for
    Fast Stylization").

    Args:
        num_features: Number of channels ``C``.
        eps: Added to the variance for numerical stability. Default: 1e-5.
        momentum: Update factor for the running statistics. Default: 0.1.
        affine: Learn a per-channel scale and shift. Default: False.
        track_running_stats: Keep running statistics and use them outside
            training mode. Default: False.

    Example:
        >>> m = InstanceNorm1d(100, affine=True)
        >>> out = m(mx.random.normal((20, 100, 40)))  # shape (20, 100, 40)
    """

    def _get_no_batch_dim(self):
        # An unbatched input is (C, L), i.e. rank 2.
        return 2

    def _check_input_dim(self, input):
        if input.ndim in (2, 3):
            return
        raise ValueError(f"expected 2D or 3D input (got {input.ndim}D input)")
325
+
326
+
327
class AdaIN1d(nn.Module):
    """Adaptive instance normalisation conditioned on a style vector.

    A linear layer maps the style ``s`` to ``2 * num_features`` values, split
    along axis 1 into ``gamma`` and ``beta``. The output is
    ``(1 + gamma) * InstanceNorm1d(x) + beta``.
    """

    def __init__(self, style_dim: int, num_features: int):
        super().__init__()
        self.norm = InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def __call__(self, x: mx.array, s: mx.array) -> mx.array:
        # Insert a new axis at position 2 so gamma/beta broadcast over it.
        style_params = self.fc(s)[:, :, None]
        gamma, beta = mx.split(style_params, 2, axis=1)
        return (1 + gamma) * self.norm(x) + beta
339
+
340
+
341
class AdaINResBlock1(nn.Module):
    """Three-stage residual block with style-conditioned normalisation.

    Each stage applies AdaIN, a Snake activation
    (``t + sin(alpha * t)**2 / alpha``), a dilated weight-normalised conv, then
    AdaIN, Snake and a non-dilated conv. The stage output is added back to the
    stage input. Only the first three entries of ``dilation`` are used.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        style_dim: int = 64,
    ):
        super().__init__()
        stage_dilations = [dilation[i] for i in range(3)]

        self.convs1 = [
            ConvWeighted(
                channels,
                channels,
                kernel_size,
                stride=1,
                padding=get_padding(kernel_size, d),
                dilation=d,
            )
            for d in stage_dilations
        ]
        self.convs2 = [
            ConvWeighted(
                channels,
                channels,
                kernel_size,
                stride=1,
                padding=get_padding(kernel_size, 1),
                dilation=1,
            )
            for _ in range(3)
        ]
        self.adain1 = [AdaIN1d(style_dim, channels) for _ in range(3)]
        self.adain2 = [AdaIN1d(style_dim, channels) for _ in range(3)]
        # Learnable Snake frequencies, one per channel and stage.
        self.alpha1 = [mx.ones((1, channels, 1)) for _ in self.convs1]
        self.alpha2 = [mx.ones((1, channels, 1)) for _ in self.convs2]

    def __call__(self, x: mx.array, s: mx.array) -> mx.array:
        def snake(t, alpha):
            return t + (1 / alpha) * (mx.sin(alpha * t) ** 2)

        def run_conv(t, layer):
            # Swap axes 1 and 2 around the conv and back again.
            return layer(t.swapaxes(2, 1), mx.conv1d).swapaxes(2, 1)

        stages = zip(
            self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
        )
        for conv_a, conv_b, norm_a, norm_b, alpha_a, alpha_b in stages:
            h = run_conv(snake(norm_a(x, s), alpha_a), conv_a)
            h = run_conv(snake(norm_b(h, s), alpha_b), conv_b)
            x = h + x
        return x
397
+
398
+
399
def mlx_angle(z, deg=False):
    """Return the element-wise argument of ``z`` (MLX analogue of ``numpy.angle``).

    ``complex64`` input uses its real and imaginary parts. Any other dtype is
    treated as purely real, giving 0 or pi. The result is in radians unless
    ``deg`` is True.
    """
    z = mx.array(z)
    is_complex = z.dtype == mx.complex64

    imag_part = mx.imag(z) if is_complex else mx.zeros_like(z)
    real_part = mx.real(z) if is_complex else z

    angle = mx.arctan2(imag_part, real_part)
    return angle * (180.0 / math.pi) if deg else angle
415
+
416
+
417
def mlx_unwrap(p, discont=None, axis=-1, period=2 * math.pi):
    """Unwrap phase ``p`` along ``axis``; an MLX port of ``numpy.unwrap``.

    Wherever consecutive values jump by more than ``discont``, the jump is
    replaced by its equivalent modulo ``period``. The corrections are then
    accumulated along ``axis``.

    Args:
        p: Input phase array.
        discont: Jump size above which a correction is applied. Defaults to
            ``period / 2``; smaller values behave like ``period / 2``.
        axis: Axis along which to unwrap.
        period: Period of the wrapped quantity (default ``2 * pi``).

    Returns:
        Array with the same shape as ``p``.
    """
    if discont is None:
        discont = period / 2

    discont = max(discont, period / 2)

    slice_indices = [slice(None)] * p.ndim

    slice_indices[axis] = slice(1, None)
    after_slice = tuple(slice_indices)

    slice_indices[axis] = slice(None, -1)
    before_slice = tuple(slice_indices)

    dd = p[after_slice] - p[before_slice]

    interval_high = period / 2
    interval_low = -interval_high

    # Map every difference into [interval_low, interval_high).
    ddmod = dd - period * mx.floor((dd - interval_low) / period)
    # numpy.unwrap boundary rule: a positive jump that lands on the lower
    # edge (dd = pi, 3*pi, ... for period 2*pi) is mapped to +period/2.
    # Testing ddmod, not dd, also catches the odd multiples beyond pi.
    ddmod = mx.where(
        (mx.abs(ddmod - interval_low) < 1e-10) & (dd > 0), interval_high, ddmod
    )

    ph_correct = ddmod - dd
    ph_correct = mx.where(mx.abs(dd) < discont, 0, ph_correct)

    # The first element along `axis` is never corrected.
    padding_shape = list(ph_correct.shape)
    padding_shape[axis] = 1
    zero_padding = mx.zeros(padding_shape)
    padded_corrections = mx.concatenate([zero_padding, ph_correct], axis=axis)
    cumulative_corrections = mx.cumsum(padded_corrections, axis=axis)

    return p + cumulative_corrections
451
+
452
+
453
class MLXSTFT:
    """STFT analysis and iSTFT synthesis built on ``mlx_audio.utils``.

    ``transform`` splits a signal into magnitude and phase. ``inverse``
    rebuilds a waveform from them. ``window`` is forwarded unchanged to
    ``stft`` / ``istft``.
    """

    def __init__(
        self, filter_length=800, hop_length=200, win_length=800, window="hann"
    ):
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length

        self.window = window

    def transform(self, input_data):
        """Return ``(magnitudes, phases)`` stacked over the batch.

        A 1D input is treated as a batch of one. Each per-item STFT result is
        transposed with ``.transpose(1, 0)`` before stacking.
        """
        if input_data.ndim == 1:
            input_data = input_data[None, :]

        magnitudes = []
        phases = []

        for batch_idx in range(input_data.shape[0]):
            x_stft = stft(
                input_data[batch_idx],
                n_fft=self.filter_length,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=True,
                pad_mode="reflect",
            ).transpose(1, 0)

            magnitudes.append(mx.abs(x_stft))
            phases.append(mlx_angle(x_stft))

        return mx.stack(magnitudes, axis=0), mx.stack(phases, axis=0)

    def inverse(self, magnitude, phase):
        """Rebuild audio from magnitude and phase; returns ``(batch, 1, samples)``."""
        reconstructed = []

        for batch_idx in range(magnitude.shape[0]):
            # Unwrap along axis 1 before building the complex spectrum.
            phase_cont = mlx_unwrap(phase[batch_idx], axis=1)

            real_part = magnitude[batch_idx] * mx.cos(phase_cont)
            imag_part = magnitude[batch_idx] * mx.sin(phase_cont)
            x_stft = real_part + 1j * imag_part

            audio = istft(
                x_stft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=True,
                length=None,
            )

            reconstructed.append(audio)

        # Insert a singleton channel axis after the batch axis.
        return mx.stack(reconstructed, axis=0)[:, None, :]

    def __call__(self, input_data: mx.array) -> mx.array:
        """Analyse then resynthesise ``input_data``; returns ``(batch, 1, samples)``.

        The analysis results are cached on ``self.magnitude`` / ``self.phase``.
        ``inverse`` already inserts the channel axis, so no further
        ``expand_dims`` is applied; adding one would yield a 4D array.
        """
        self.magnitude, self.phase = self.transform(input_data)
        return self.inverse(self.magnitude, self.phase)
529
+
530
+
531
class SineGen:
    """Generate sine excitations (fundamental plus overtones) from an F0 track.

    Args:
        samp_rate: Sampling rate in Hz; F0 values are divided by it.
        upsample_scale: Factor between the rate at which phase is integrated
            and the output rate (see ``_f02sine``).
        harmonic_num: Number of overtones; the output has ``harmonic_num + 1``
            channels.
        sine_amp: Amplitude applied to the sines. ``sine_amp / 3`` is also the
            noise std in unvoiced regions.
        noise_std: Noise std in voiced regions.
        voiced_threshold: F0 values above this are considered voiced.
        flag_for_pulse: Selects the pulse-train branch of ``_f02sine``.
    """

    def __init__(
        self,
        samp_rate: int,
        upsample_scale: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        noise_std: float = 0.003,
        voiced_threshold: float = 0,
        flag_for_pulse: bool = False,
    ):
        # Plain Python class (not an nn.Module), so this is object.__init__.
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Number of output channels: the fundamental plus its overtones.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0: mx.array) -> mx.array:
        # Voiced/unvoiced mask: 1.0 where f0 exceeds the threshold, else 0.0.
        return mx.array(f0 > self.voiced_threshold, dtype=mx.float32)

    def _f02sine(self, f0_values: mx.array) -> mx.array:
        """Turn per-channel frequencies into sine (or pulse) waveforms.

        f0_values: (batchsize, length, dim), where dim indexes the
        fundamental tone and its overtones.
        """
        # Per-sample phase increment in cycles. Only the fractional part is
        # kept, because whole cycles (2 * pi * n) do not change the phase.
        rad_values = (f0_values / self.sampling_rate) % 1
        # Random initial phase offset at t = 0 for each overtone; column 0
        # (the fundamental) is zeroed so it gets no offset.
        rand_ini = mx.random.normal((f0_values.shape[0], f0_values.shape[2]))
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantaneous phase: sine[t] = sin(2*pi \sum_{i=1}^{t} rad[i])
        if not self.flag_for_pulse:
            # Integrate the phase at a rate reduced by `upsample_scale`, then
            # interpolate back up. The phase is multiplied by
            # `upsample_scale` because fewer increments were summed.
            rad_values = interpolate(
                rad_values.transpose(0, 2, 1),
                scale_factor=1 / self.upsample_scale,
                mode="linear",
            ).transpose(0, 2, 1)
            phase = mx.cumsum(rad_values, axis=1) * 2 * mx.pi
            phase = interpolate(
                phase.transpose(0, 2, 1) * self.upsample_scale,
                scale_factor=self.upsample_scale,
                mode="linear",
            ).transpose(0, 2, 1)
            sines = mx.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation
            # identify the last time step in unvoiced segments
            # NOTE(review): this branch indexes with boolean masks
            # (u_loc[idx, :, 0]); MLX may not support boolean-mask indexing,
            # so this path may fail at runtime — confirm before enabling
            # flag_for_pulse.
            uv = self._f02uv(f0_values)
            uv_1 = mx.roll(uv, shifts=-1, axis=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantaneous phase
            tmp_cumsum = mx.cumsum(rad_values, axis=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = mx.cumsum(rad_values - tmp_cumsum, axis=1)
            # get the sines
            sines = mx.cos(i_phase * 2 * mx.pi)
        return sines

    def __call__(self, f0: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
        """Return ``(sine_waves, uv, noise)`` for an F0 track.

        Unvoiced regions contain only noise (std ``sine_amp / 3``). Voiced
        regions contain the sines plus noise with std ``noise_std``.
        NOTE(review): assumes f0 is (batch, length, 1) so it broadcasts
        against the harmonic multipliers — confirm with callers.
        """
        # NOTE: f0_buf is allocated but never used below.
        f0_buf = mx.zeros((f0.shape[0], f0.shape[1], self.dim))

        # Harmonic frequencies: f0 * [1, 2, ..., harmonic_num + 1]
        fn = f0 * mx.arange(1, self.harmonic_num + 2)[None, None, :]

        # Generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # Generate UV signal
        uv = self._f02uv(f0)

        # Noise std: noise_std where voiced, sine_amp / 3 where unvoiced.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * mx.random.normal(sine_waves.shape)

        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
624
+
625
+
626
class SourceModuleHnNSF(nn.Module):
    """Harmonic-plus-noise excitation source for an hn-NSF vocoder.

    Args:
        sampling_rate: Sampling rate in Hz, forwarded to ``SineGen``.
        upsample_scale: Upsampling factor, forwarded to ``SineGen``.
        harmonic_num: Number of harmonics above F0 (default: 0).
        sine_amp: Amplitude of the sine source (default: 0.1). It also sets
            the unvoiced noise level inside ``SineGen``.
        add_noise_std: Std of the additive Gaussian noise in voiced regions
            (default: 0.003).
        voiced_threshod: F0 threshold separating voiced from unvoiced
            (default: 0). The misspelt name is kept for caller compatibility.

    Calling the module on ``F0_sampled`` (batchsize, length, 1) returns
    ``(sine_merge, noise, uv)``.
    """

    def __init__(
        self,
        sampling_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen(
            samp_rate=sampling_rate,
            upsample_scale=upsample_scale,
            harmonic_num=harmonic_num,
            sine_amp=sine_amp,
            noise_std=add_noise_std,
            voiced_threshold=voiced_threshod,
        )
        # Collapse the harmonic_num + 1 sine channels into one excitation.
        self.l_linear = nn.Linear(harmonic_num + 1, 1)

    def __call__(self, x):
        """Return ``(sine_merge, noise, uv)`` for the F0 track ``x``.

        ``sine_merge`` is tanh of the linearly merged harmonics. ``noise`` is
        Gaussian with std ``sine_amp / 3`` and has the same shape as ``uv``.
        """
        sine_wavs, uv, _unused_noise = self.l_sin_gen(x)
        sine_merge = mx.tanh(self.l_linear(sine_wavs))
        noise = mx.random.normal(uv.shape) * self.sine_amp / 3
        return sine_merge, noise, uv
681
+
682
+
683
class ReflectionPad1d(nn.Module):
    """Reflection-pad the last axis, matching ``torch.nn.ReflectionPad1d``.

    ``padding`` is ``(left, right)``. Padded values mirror the signal around
    its edge samples without repeating the edge sample itself. For example,
    padding ``[a, b, c, d]`` by ``(2, 1)`` gives ``[c, b, a, b, c, d, c]``.
    The previous implementation zero-padded, which did not match the name
    or the PyTorch layer it mirrors.

    Raises:
        ValueError: If either pad amount is not smaller than the length of
            the last axis (reflection is undefined there, as in PyTorch).
    """

    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def __call__(self, x):
        left, right = self.padding
        axis = x.ndim - 1
        length = x.shape[axis]
        if left >= length or right >= length:
            raise ValueError(
                f"Reflection padding {self.padding} must be smaller than the "
                f"input length {length}"
            )

        pieces = []
        if left > 0:
            # Indices left, left-1, ..., 1 (edge sample 0 is not repeated).
            left_idx = mx.array(list(range(left, 0, -1)))
            pieces.append(mx.take(x, left_idx, axis=axis))
        pieces.append(x)
        if right > 0:
            # Indices length-2, ..., length-1-right (last sample not repeated).
            right_idx = mx.array(list(range(length - 2, length - 2 - right, -1)))
            pieces.append(mx.take(x, right_idx, axis=axis))

        if len(pieces) == 1:
            return x
        return mx.concatenate(pieces, axis=axis)
690
+
691
+
692
def leaky_relu(x, negative_slope=0.01):
    """Leaky ReLU: keep positive values, scale the rest by ``negative_slope``."""
    is_positive = x > 0
    return mx.where(is_positive, x, x * negative_slope)
694
+
695
+
696
class Generator(nn.Module):
    """iSTFT-based waveform generator conditioned on a style vector and F0.

    Each stage upsamples features with a transposed weight-normalised conv
    and adds a harmonic-source branch (an STFT of the sine excitation fed
    through ``noise_convs`` / ``noise_res``). The stage output is the average
    of ``num_kernels`` AdaIN residual blocks. A final conv predicts
    ``post_n_fft + 2`` channels. The first ``post_n_fft // 2 + 1`` become
    ``exp`` magnitudes and the rest become ``sin`` phases for
    ``MLXSTFT.inverse``.
    """

    def __init__(
        self,
        style_dim,
        resblock_kernel_sizes,
        upsample_rates,
        upsample_initial_channel,
        resblock_dilation_sizes,
        upsample_kernel_sizes,
        gen_istft_n_fft,
        gen_istft_hop_size,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        upsample_rates = mx.array(upsample_rates)
        # The excitation source uses a hard-coded 24 kHz rate, 8 harmonics and
        # a voiced threshold of 10.
        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=mx.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8,
            voiced_threshod=10,
        )
        # F0 is upsampled by the product of all rates times the iSTFT hop.
        self.f0_upsamp = nn.Upsample(
            scale_factor=mx.prod(upsample_rates) * gen_istft_hop_size
        )
        self.noise_convs = []
        self.noise_res = []
        self.ups = []
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every stage; used with mx.conv_transpose1d.
            self.ups.append(
                ConvWeighted(
                    upsample_initial_channel // (2 ** (i + 1)),
                    upsample_initial_channel // (2**i),
                    int(k),
                    int(u),
                    padding=int((k - u) // 2),
                    encode=True,
                )
            )
        self.resblocks = []
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            # resblocks is flat: stage i owns entries
            # [i * num_kernels, (i + 1) * num_kernels).
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                # Downsample the harmonic spectrogram by the product of the
                # remaining rates so it matches this stage's length.
                stride_f0 = int(mx.prod(upsample_rates[i + 1 :]))
                self.noise_convs.append(
                    nn.Conv1d(
                        gen_istft_n_fft + 2,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=(stride_f0 + 1) // 2,
                    )
                )
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1, 3, 5], style_dim))
            else:
                self.noise_convs.append(
                    nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)
                )
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1, 3, 5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        # NOTE(review): `ch` is only bound inside the loop above, so this
        # fails if upsample_rates is empty.
        self.conv_post = ConvWeighted(ch, self.post_n_fft + 2, 7, 1, padding=3)
        self.reflection_pad = ReflectionPad1d((1, 0))
        self.stft = MLXSTFT(
            filter_length=gen_istft_n_fft,
            hop_length=gen_istft_hop_size,
            win_length=gen_istft_n_fft,
        )

    def __call__(self, x, s, f0):
        """Synthesize audio from features ``x``, style ``s`` and pitch ``f0``.

        NOTE(review): assumes x is (batch, channels, frames) and f0 is
        (batch, frames) — inferred from the swapaxes/transposes below, confirm
        with callers. Returns the output of ``MLXSTFT.inverse``.
        """
        f0 = self.f0_upsamp(f0[:, None].transpose(0, 2, 1))  # bs,n,t
        # noi_source and uv are not used further.
        har_source, noi_source, uv = self.m_source(f0)
        har_source = mx.squeeze(har_source.transpose(0, 2, 1), axis=1)
        har_spec, har_phase = self.stft.transform(har_source)
        # Stack magnitude and phase along axis 1 (n_fft + 2 channels), then
        # move channels last for nn.Conv1d.
        har = mx.concatenate([har_spec, har_phase], axis=1)
        har = har.swapaxes(2, 1)
        for i in range(self.num_upsamples):
            x = leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = x_source.swapaxes(2, 1)
            x_source = self.noise_res[i](x_source, s)

            x = x.swapaxes(2, 1)
            x = self.ups[i](x, mx.conv_transpose1d)
            x = x.swapaxes(2, 1)

            # The last stage gets one extra sample of left padding before the
            # source branch is added.
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source

            # Average the outputs of this stage's num_kernels residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels

        x = leaky_relu(x, negative_slope=0.01)

        x = x.swapaxes(2, 1)
        x = self.conv_post(x, mx.conv1d)
        x = x.swapaxes(2, 1)

        # First half of the channels -> magnitudes (exp), remainder -> phases (sin).
        spec = mx.exp(x[:, : self.post_n_fft // 2 + 1, :])
        phase = mx.sin(x[:, self.post_n_fft // 2 + 1 :, :])
        result = self.stft.inverse(spec, phase)
        return result
808
+
809
+
810
class UpSample1d(nn.Module):
    """Optional 2x nearest-neighbour upsampling.

    When ``layer_type`` equals ``"none"`` the module returns its input
    unchanged. Any other value (including ``True``) routes the input through
    ``nn.Upsample(scale_factor=2, mode="nearest")``.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type
        self.interpolate = nn.Upsample(
            scale_factor=2, mode="nearest", align_corners=True
        )

    def __call__(self, x):
        if self.layer_type != "none":
            return self.interpolate(x)
        return x
823
+
824
+
825
class AdainResBlk1d(nn.Module):
    """Residual block conditioned on a style vector through AdaIN.

    The output is ``(residual(x, s) + shortcut(x)) / sqrt(2)``.

    Args:
        dim_in: Input channel count.
        dim_out: Output channel count. When it differs from ``dim_in``, the
            shortcut path gets a learned 1x1 projection (``conv1x1``).
        style_dim: Size of the style vector passed to both AdaIN layers.
        actv: Activation module applied after each AdaIN. When omitted, a
            fresh ``nn.LeakyReLU(0.2)`` is created for this instance.
        upsample: ``"none"`` disables upsampling. Any other value (the
            ``Decoder`` passes ``True``) enables the ``pool`` transposed conv
            on the residual path and ``UpSample1d`` on the shortcut.
        dropout_p: Dropout probability applied right before ``conv1``.
        bias: Accepted but not used by this class.
        conv_type: Stored as ``self.conv_type``; not read by this class.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        style_dim=64,
        actv=None,
        upsample="none",
        dropout_p=0.0,
        bias=False,
        conv_type=None,
    ):
        super().__init__()
        # Build the default activation per instance. A module object used as
        # a default argument is created once at definition time and would be
        # shared by every block.
        self.actv = nn.LeakyReLU(0.2) if actv is None else actv
        self.dim_in = dim_in
        self.conv_type = conv_type
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == "none":
            self.pool = nn.Identity()
        else:
            # Grouped (groups=dim_in), stride-2 transposed conv applied to the
            # residual path in _residual.
            self.pool = ConvWeighted(
                1, dim_in, kernel_size=3, stride=2, padding=1, groups=dim_in
            )

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create conv1/conv2, the two AdaIN layers and, if the channel count
        changes, the bias-free 1x1 shortcut projection."""
        self.conv1 = ConvWeighted(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvWeighted(dim_out, dim_out, kernel_size=3, stride=1, padding=1)
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = ConvWeighted(
                dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False
            )

    def _shortcut(self, x):
        """Identity/upsampled skip path, projected by conv1x1 when needed."""
        # Axes 1 and 2 are swapped for the upsample and conv, then swapped
        # back once. This is equivalent to the former swap-back/swap-again
        # sequence around conv1x1.
        x = self.upsample(x.swapaxes(2, 1))
        if self.learned_sc:
            x = self.conv1x1(x, mx.conv1d)
        return x.swapaxes(2, 1)

    def _residual(self, x, s):
        """AdaIN -> act -> [upsample] -> dropout -> conv1 -> AdaIN -> act -> conv2."""
        x = self.actv(self.norm1(x, s))

        # mx.conv1d / mx.conv_transpose1d take (N, L, C) inputs, so axes 1 and
        # 2 are swapped around the convolutions. pool and conv1 run in the same
        # swapped layout, so no swap is needed between them.
        x = x.swapaxes(2, 1)
        if self.upsample_type != "none":
            x = self.pool(x, mx.conv_transpose1d)
            # One element of left padding on axis 1 after the transposed conv.
            x = mx.pad(x, ((0, 0), (1, 0), (0, 0)))
        x = self.conv1(self.dropout(x), mx.conv1d)
        x = x.swapaxes(2, 1)

        x = self.actv(self.norm2(x, s))

        x = self.conv2(x.swapaxes(2, 1), mx.conv1d)
        return x.swapaxes(2, 1)

    def __call__(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / mx.sqrt(2)
        return out
900
+
901
+
902
class Decoder(nn.Module):
    """Style-conditioned AdaIN residual stack that feeds an iSTFT ``Generator``.

    ``asr`` features are concatenated with the ``F0_conv``/``N_conv`` outputs
    and passed through ``encode``. The ``decode`` blocks follow; before each
    block the skip inputs (``asr_res``, F0, N) are re-attached, until a block
    with upsampling has run. The result goes to ``generator`` together with
    the raw F0 curve.

    ``dim_out`` is accepted but not used by this class.
    """

    def __init__(
        self,
        dim_in,
        style_dim,
        dim_out,
        resblock_kernel_sizes,
        upsample_rates,
        upsample_initial_channel,
        resblock_dilation_sizes,
        upsample_kernel_sizes,
        gen_istft_n_fft,
        gen_istft_hop_size,
    ):
        super().__init__()
        hidden = 1024
        # Input width of every decode block: previous output (hidden) plus the
        # two 1-channel F0/N convs and the 64-channel asr_res projection.
        skip_in = hidden + 2 + 64

        self.encode = AdainResBlk1d(dim_in + 2, hidden, style_dim, conv_type=mx.conv1d)
        self.decode = [
            AdainResBlk1d(skip_in, hidden, style_dim, conv_type=mx.conv1d)
            for _ in range(3)
        ]
        # Final block reduces to 512 channels and has upsampling enabled.
        self.decode.append(
            AdainResBlk1d(skip_in, 512, style_dim, upsample=True, conv_type=mx.conv1d)
        )
        self.F0_conv = ConvWeighted(1, 1, kernel_size=3, stride=2, padding=1, groups=1)
        self.N_conv = ConvWeighted(1, 1, kernel_size=3, stride=2, padding=1, groups=1)
        self.asr_res = [ConvWeighted(512, 64, kernel_size=1, padding=0)]
        self.generator = Generator(
            style_dim,
            resblock_kernel_sizes,
            upsample_rates,
            upsample_initial_channel,
            resblock_dilation_sizes,
            upsample_kernel_sizes,
            gen_istft_n_fft,
            gen_istft_hop_size,
        )

    def __call__(self, asr, F0_curve, N, s):
        s = mx.array(s)

        F0 = self.F0_conv(F0_curve[:, None, :].swapaxes(2, 1), mx.conv1d)
        F0 = F0.swapaxes(2, 1)
        N_feat = self.N_conv(N[:, None, :].swapaxes(2, 1), mx.conv1d)
        N_feat = N_feat.swapaxes(2, 1)

        x = self.encode(mx.concatenate([asr, F0, N_feat], axis=1), s)
        asr_res = self.asr_res[0](asr.swapaxes(2, 1), mx.conv1d).swapaxes(2, 1)

        # Re-attach the skip inputs before every block, up to and including
        # the first block whose upsample_type is not "none".
        append_skips = True
        for block in self.decode:
            if append_skips:
                x = mx.concatenate([x, asr_res, F0, N_feat], axis=1)
            x = block(x, s)
            if getattr(block, "upsample_type", "none") != "none":
                append_skips = False

        return self.generator(x, s, F0_curve)

    def sanitize(self, key, weights):
        """Return ``weights`` in the layout this module expects.

        noise_convs ``.weight`` tensors are always transposed with
        ``(0, 2, 1)``. ``weight_v`` tensors are transposed the same way unless
        ``check_array_shape`` accepts them as-is. Everything else is
        returned unchanged.
        """
        if "noise_convs" in key and key.endswith(".weight"):
            return weights.transpose(0, 2, 1)
        if "weight_v" in key and not check_array_shape(weights):
            return weights.transpose(0, 2, 1)
        return weights