nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,974 @@
1
+ import importlib.resources
2
+ import unittest
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+ import numpy as np
8
+ from misaki import en
9
+
10
+
11
# Create a patch for the deprecated open_text function
def patched_open_text(package, resource):
    """Replacement for the deprecated ``importlib.resources.open_text``.

    Uses the modern ``files()`` API.  ``open_text`` defaulted to UTF-8,
    whereas ``Traversable.open("r")`` uses the locale-preferred encoding,
    so pass ``encoding="utf-8"`` explicitly to keep the original behavior.
    """
    return (
        importlib.resources.files(package)
        .joinpath(resource)
        .open("r", encoding="utf-8")
    )
15
+
16
+
17
# Apply the patch at the module level
@patch("importlib.resources.open_text", patched_open_text)
class TestSanitizeLSTMWeights(unittest.TestCase):
    """Tests for the PyTorch -> MLX LSTM weight-key renaming helper."""

    def test_sanitize_lstm_weights(self):
        """Each PyTorch-style LSTM key maps to its MLX equivalent.

        Refactored from nine copy-pasted stanzas into a table-driven loop;
        ``subTest`` reports the failing key instead of stopping at the
        first mismatch.
        """
        # Import inside the test method so the module-level
        # importlib.resources patch is active during import.
        from mlx_audio.tts.models.kokoro.kokoro import sanitize_lstm_weights

        # (input key, expected sanitized key); unknown keys pass through.
        cases = [
            ("lstm.weight_ih_l0_reverse", "lstm.Wx_backward"),
            ("lstm.weight_hh_l0_reverse", "lstm.Wh_backward"),
            ("lstm.bias_ih_l0_reverse", "lstm.bias_ih_backward"),
            ("lstm.bias_hh_l0_reverse", "lstm.bias_hh_backward"),
            ("lstm.weight_ih_l0", "lstm.Wx_forward"),
            ("lstm.weight_hh_l0", "lstm.Wh_forward"),
            ("lstm.bias_ih_l0", "lstm.bias_ih_forward"),
            ("lstm.bias_hh_l0", "lstm.bias_hh_forward"),
            ("unknown.key", "unknown.key"),
        ]
        for key, expected in cases:
            with self.subTest(key=key):
                # Matrix weights are 2-D, biases 1-D (mirrors the original
                # per-case fixtures); the shape is otherwise irrelevant to
                # the key-renaming logic under test.
                shape = (10, 10) if "weight" in key else 10
                weights = mx.array(np.zeros(shape))
                result = sanitize_lstm_weights(key, weights)
                self.assertEqual(list(result.keys())[0], expected)
78
+
79
+
80
@patch("importlib.resources.open_text", patched_open_text)
class TestKokoroModel(unittest.TestCase):
    """Tests for the Kokoro TTS model constructor and its Output dataclass."""

    # NOTE: stacked @patch decorators apply bottom-up, so the mock arguments
    # arrive in reverse order: load_weights, mx.load, open, json.load.
    @patch("mlx_audio.tts.models.kokoro.kokoro.json.load")
    @patch("builtins.open", new_callable=MagicMock)
    @patch("mlx_audio.tts.models.kokoro.kokoro.mx.load")
    @patch.object(nn.Module, "load_weights")
    def test_init(self, mock_load_weights, mock_mx_load, mock_open, mock_json_load):
        """Test KokoroModel initialization."""
        # Import inside the test method so the module-level
        # importlib.resources patch is active during import.
        from mlx_audio.tts.models.kokoro.kokoro import Model, ModelConfig

        # Mock the config loading: a minimal config covering the istftnet
        # and plbert sub-configs plus the scalar hyperparameters the
        # constructor reads.
        config = {
            "istftnet": {
                "upsample_kernel_sizes": [20, 12],
                "upsample_rates": [10, 6],
                "gen_istft_hop_size": 5,
                "gen_istft_n_fft": 20,
                "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "resblock_kernel_sizes": [3, 7, 11],
                "upsample_initial_channel": 512,
            },
            "dim_in": 64,
            "dropout": 0.2,
            "hidden_dim": 512,
            "max_conv_dim": 512,
            "max_dur": 50,
            "multispeaker": True,
            "n_layer": 3,
            "n_mels": 80,
            "n_token": 178,
            "style_dim": 128,
            "text_encoder_kernel_size": 5,
            "plbert": {
                "hidden_size": 768,
                "num_attention_heads": 12,
                "intermediate_size": 2048,
                "max_position_embeddings": 512,
                "num_hidden_layers": 12,
                "dropout": 0.1,
            },
            "vocab": {"a": 1, "b": 2},
        }
        mock_json_load.return_value = config

        # Mock the weights loading so no checkpoint file is needed.
        mock_mx_load.return_value = {"key": mx.array(np.zeros(10))}

        # Make load_weights a no-op (weights are never actually applied).
        mock_load_weights.return_value = None

        # Initialize the model with the config parameter.
        model = Model(ModelConfig.from_dict(config))

        # Check that the model was initialized correctly: it is an MLX
        # module and the vocab from the config was stored verbatim.
        self.assertIsInstance(model, nn.Module)
        self.assertEqual(model.vocab, {"a": 1, "b": 2})

    def test_output_dataclass(self):
        """Test KokoroModel.Output dataclass."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.kokoro import Model

        # Create mock audio / duration tensors.
        audio = mx.array(np.zeros((1, 1000)))
        pred_dur = mx.array(np.zeros((1, 100)))

        # Mock __init__ to return None so no real model is ever built.
        # NOTE(review): constructing the nested Output does not call
        # Model.__init__, so this patch is presumably just a safety guard.
        with patch.object(Model, "__init__", return_value=None):
            output = Model.Output(audio=audio, pred_dur=pred_dur)

        # Check that the output holds the exact same array objects
        # (identity, not equality).
        self.assertIs(output.audio, audio)
        self.assertIs(output.pred_dur, pred_dur)
154
+
155
+
156
@patch("importlib.resources.open_text", patched_open_text)
class TestKokoroPipeline(unittest.TestCase):
    """Tests for the Kokoro TTS pipeline: language tables, construction,
    voice loading, token helpers and the Result dataclass."""

    def test_aliases_and_lang_codes(self):
        """Test ALIASES and LANG_CODES constants."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.pipeline import ALIASES, LANG_CODES

        # Check that all aliases map to valid language codes.
        for alias_key, alias_value in ALIASES.items():
            self.assertIn(alias_value, LANG_CODES)

        # Check specific mappings (single-letter internal codes).
        self.assertEqual(ALIASES["en-us"], "a")
        self.assertEqual(ALIASES["ja"], "j")
        self.assertEqual(LANG_CODES["a"], "American English")
        self.assertEqual(LANG_CODES["j"], "Japanese")

    def test_init(self):
        """Test KokoroPipeline initialization."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.pipeline import LANG_CODES, KokoroPipeline

        # Mock the KokoroModel - fix the import path
        with patch("mlx_audio.tts.models.kokoro.kokoro.Model") as mock_kokoro_model:
            # Patch isinstance inside the pipeline module so the mock model
            # passes the pipeline's type check.
            with patch(
                "mlx_audio.tts.models.kokoro.pipeline.isinstance"
            ) as mock_isinstance:
                mock_model = MagicMock()
                mock_kokoro_model.return_value = mock_model

                # Simply make isinstance always return True when checking for KokoroModel
                mock_isinstance.return_value = True

                # Initialize with default model
                pipeline = KokoroPipeline(
                    lang_code="a", model=mock_model, repo_id="mock"
                )
                self.assertEqual(pipeline.lang_code, "a")
                self.assertEqual(LANG_CODES[pipeline.lang_code], "American English")

                # Initialize with provided model
                model = mock_model
                pipeline = KokoroPipeline(lang_code="a", model=model, repo_id="mock")
                self.assertEqual(pipeline.model, model)

                # Initialize with no model (model=False disables model use).
                pipeline = KokoroPipeline(lang_code="a", model=False, repo_id="mock")
                self.assertIs(pipeline.model, False)

    def test_load_voice(self):
        """Test load_voice method."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.pipeline import KokoroPipeline

        # Build the pipeline without running __init__, then hand-set the
        # attributes load_single_voice/load_voice read.
        with patch.object(KokoroPipeline, "__init__", return_value=None):
            with patch(
                "mlx_audio.tts.models.kokoro.pipeline.load_voice_tensor"
            ) as load_voice_tensor:
                with patch(
                    "mlx_audio.tts.models.kokoro.pipeline.hf_hub_download"
                ) as mock_hf_hub_download:
                    pipeline = KokoroPipeline.__new__(KokoroPipeline)
                    pipeline.lang_code = "a"
                    pipeline.voices = {}
                    # Add the missing repo_id attribute
                    pipeline.repo_id = "mlx-community/kokoro-tts"

                    # Mock the load voice return value
                    load_voice_tensor.return_value = mx.zeros((512, 1, 256))

                    # Test loading a single voice: one hub download, voice
                    # cached under its name.
                    pipeline.load_single_voice("voice1")
                    mock_hf_hub_download.assert_called_once()
                    self.assertIn("voice1", pipeline.voices)

                    # Test loading multiple voices: comma-separated names,
                    # one download per voice.
                    mock_hf_hub_download.reset_mock()
                    pipeline.voices = {}  # Reset voices
                    result = pipeline.load_voice("voice1,voice2")
                    self.assertEqual(mock_hf_hub_download.call_count, 2)
                    self.assertIn("voice1", pipeline.voices)
                    self.assertIn("voice2", pipeline.voices)

    def test_tokens_to_ps(self):
        """Test tokens_to_ps method."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.pipeline import KokoroPipeline

        # Create mock tokens with whitespace attribute
        token1 = MagicMock(spec=en.MToken)
        token1.ps = "p1"
        token1.whitespace = " "
        token1.phonemes = "p1"

        token2 = MagicMock(spec=en.MToken)
        token2.ps = "p2"
        token2.whitespace = ""
        token2.phonemes = "p2"

        tokens = [token1, token2]

        # NOTE(review): tokens_to_ps itself is patched to return "p1 p2",
        # so this asserts the mock's return value, not the real
        # implementation — consider calling the unpatched method instead.
        with patch.object(KokoroPipeline, "__init__", return_value=None):
            with patch.object(KokoroPipeline, "tokens_to_ps", return_value="p1 p2"):
                result = KokoroPipeline.tokens_to_ps(tokens)
                self.assertEqual(result, "p1 p2")

    def test_tokens_to_text(self):
        """Test tokens_to_text method."""
        # Import inside the test method
        from mlx_audio.tts.models.kokoro.pipeline import KokoroPipeline

        # Create mock tokens with whitespace attribute
        token1 = MagicMock(spec=en.MToken)
        token1.text = "Hello"
        token1.whitespace = " "

        token2 = MagicMock(spec=en.MToken)
        token2.text = "world"
        token2.whitespace = ""

        tokens = [token1, token2]

        # NOTE(review): same tautology as test_tokens_to_ps — the method
        # under test is replaced by the mock providing the expected value.
        with patch.object(KokoroPipeline, "__init__", return_value=None):
            with patch.object(
                KokoroPipeline, "tokens_to_text", return_value="Hello world"
            ):
                result = KokoroPipeline.tokens_to_text(tokens)
                self.assertEqual(result, "Hello world")

    def test_result_dataclass(self):
        """Test KokoroPipeline.Result dataclass."""
        # Import inside the test methods
        from mlx_audio.tts.models.kokoro.kokoro import Model
        from mlx_audio.tts.models.kokoro.pipeline import KokoroPipeline

        # Create a mock model output wrapping audio and durations.
        audio = mx.array(np.zeros((1, 1000)))
        pred_dur = mx.array(np.zeros((1, 100)))
        model_output = Model.Output(audio=audio, pred_dur=pred_dur)

        # Create a Result instance
        result = KokoroPipeline.Result(
            graphemes="Hello",
            phonemes="HH EH L OW",
            tokens=[MagicMock()],
            output=model_output,
            text_index=0,
        )

        # Check properties: audio/pred_dur are forwarded from the output.
        self.assertEqual(result.graphemes, "Hello")
        self.assertEqual(result.phonemes, "HH EH L OW")
        self.assertIs(result.audio, audio)
        self.assertIs(result.pred_dur, pred_dur)

        # Test backward compatibility: Result still behaves like the old
        # (graphemes, phonemes, audio) 3-tuple via indexing.
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], "Hello")
        self.assertEqual(result[1], "HH EH L OW")
        self.assertIs(result[2], audio)

        # Test iteration yields the same 3-tuple order.
        items = list(result)
        self.assertEqual(items[0], "Hello")
        self.assertEqual(items[1], "HH EH L OW")
        self.assertIs(items[2], audio)
325
+
326
+
327
@patch("importlib.resources.open_text", patched_open_text)
class TestBarkModel(unittest.TestCase):
    """Tests for the Bark model constructor and weight sanitization."""

    @patch("mlx_audio.tts.models.bark.bark.BertTokenizer")
    def test_init(self, mock_tokenizer):
        """Test BarkModel initialization."""
        from mlx_audio.tts.models.bark.bark import (
            CoarseAcousticsConfig,
            CodecConfig,
            FineAcousticsConfig,
            Model,
            ModelConfig,
            SemanticConfig,
        )

        # Assemble a config from default sub-configs and build the model.
        bark_model = Model(
            ModelConfig(
                semantic_config=SemanticConfig(),
                coarse_acoustics_config=CoarseAcousticsConfig(),
                fine_acoustics_config=FineAcousticsConfig(),
                codec_config=CodecConfig(),
            )
        )

        # Every major sub-module must have been constructed.
        for component in (
            bark_model.semantic,
            bark_model.coarse_acoustics,
            bark_model.fine_acoustics,
            bark_model.tokenizer,
        ):
            self.assertIsNotNone(component)

    def test_sanitize_weights(self):
        """Test weight sanitization."""
        from mlx_audio.tts.models.bark.bark import Model, ModelConfig

        # A minimal config (empty sub-configs) is enough to call sanitize.
        bark_model = Model(
            ModelConfig(
                semantic_config={},
                coarse_acoustics_config={},
                fine_acoustics_config={},
                codec_config={},
            )
        )

        # Checkpoint-style keys: two compiled-transformer layers + lm_head.
        sanitized = bark_model.sanitize(
            {
                "_orig_mod.transformer.h.0.mlp.weight": mx.zeros((10, 10)),
                "_orig_mod.transformer.h.1.mlp.weight": mx.zeros((10, 10)),
                "lm_head.weight": mx.zeros((10, 10)),
            }
        )

        # Transformer layers are renamed; lm_head passes through untouched.
        for expected_key in (
            "layers.0.mlp.weight",
            "layers.1.mlp.weight",
            "lm_head.weight",
        ):
            self.assertIn(expected_key, sanitized)
390
+
391
+
392
@patch("importlib.resources.open_text", patched_open_text)
class TestBarkPipeline(unittest.TestCase):
    """Tests for the Bark generation pipeline (semantic -> coarse -> fine)
    driven entirely by mocked model components."""

    def setUp(self):
        """Set up test fixtures."""
        from mlx_audio.tts.models.bark.bark import (
            CoarseAcousticsConfig,
            CodecConfig,
            FineAcousticsConfig,
            Model,
            ModelConfig,
            SemanticConfig,
        )
        from mlx_audio.tts.models.bark.pipeline import Pipeline

        # Create mock model with required attributes; spec=Model keeps
        # attribute access restricted to the real Model surface.
        self.mock_model = MagicMock(spec=Model)

        # Add the required mock attributes/methods
        self.mock_model.semantic = MagicMock()
        self.mock_model.coarse_acoustics = MagicMock()
        self.mock_model.fine_acoustics = MagicMock()
        self.mock_model.codec_model = MagicMock()

        self.mock_tokenizer = MagicMock()

        # Initialize pipeline with default sub-configs.
        self.pipeline = Pipeline(
            model=self.mock_model,
            tokenizer=self.mock_tokenizer,
            config=ModelConfig(
                semantic_config=SemanticConfig(),
                coarse_acoustics_config=CoarseAcousticsConfig(),
                fine_acoustics_config=FineAcousticsConfig(),
                codec_config=CodecConfig(),
            ),
        )

    def test_generate_text_semantic(self):
        """Test semantic token generation."""
        # Mock tokenizer output
        self.mock_tokenizer.encode.return_value = [1, 2, 3]

        # Create logits with proper shape including SEMANTIC_PAD_TOKEN
        logits = mx.zeros((1, 1, 129596))  # Large enough to include SEMANTIC_PAD_TOKEN
        # Mock model output: the semantic transformer returns (logits, kv_cache).
        self.mock_model.semantic.return_value = (
            logits,  # logits with correct shape
            None,  # kv_cache
        )

        # Test generation
        semantic_tokens, text_tokens = self.pipeline.generate_text_semantic(
            "test text",
            temperature=0.7,
            use_kv_caching=True,
            voice=None,
        )

        # Verify tokenizer was called
        self.mock_tokenizer.encode.assert_called_once_with(
            "test text", add_special_tokens=False
        )

        # Verify model was called
        self.mock_model.semantic.assert_called()

        # Check output types
        self.assertIsInstance(semantic_tokens, mx.array)
        self.assertIsInstance(text_tokens, mx.array)

    @patch("mlx.core.random.categorical")  # Add this patch since we use mx alias
    def test_generate_coarse(self, mock_mlx_categorical):
        """Test coarse token generation."""
        # Create mock semantic tokens
        semantic_tokens = mx.array([1, 2, 3])

        # Create logits with proper shape
        logits = mx.zeros((1, 1, 12096))

        # Mock both categorical functions to return predictable values
        mock_mlx_categorical.return_value = mx.array([10000])  # Return token index

        # Set up the mock to return proper values for each call
        self.mock_model.coarse_acoustics.return_value = (logits, None)

        # Test generation with minimal parameters to reduce test time
        coarse_tokens = self.pipeline.generate_coarse(
            semantic_tokens,
            temperature=0.7,
            use_kv_caching=True,
            voice=None,
            max_coarse_history=60,
            sliding_window_len=2,  # Reduce this to minimum
        )

        # Verify model was called at least once
        self.mock_model.coarse_acoustics.assert_called()

        # Check output type and shape: first axis is the two coarse codebooks.
        self.assertIsInstance(coarse_tokens, mx.array)
        self.assertEqual(coarse_tokens.shape[0], 2)  # N_COARSE_CODEBOOKS

    def test_generate_fine(self):
        """Test fine token generation."""
        # Create mock coarse tokens
        coarse_tokens = mx.zeros((2, 100))  # N_COARSE_CODEBOOKS x sequence_length

        # Mock model output with proper shape
        self.mock_model.fine_acoustics.return_value = mx.zeros((1, 1024, 1024))

        # Test generation
        fine_tokens = self.pipeline.generate_fine(coarse_tokens, temperature=0.7)

        # Verify model was called
        self.mock_model.fine_acoustics.assert_called()

        # Check output type and shape
        self.assertIsInstance(fine_tokens, mx.array)
        self.assertEqual(
            fine_tokens.shape[0], 8
        )  # N_FINE_CODEBOOKS (corrected from 10 to 8)
        self.assertEqual(fine_tokens.shape[1], 100)  # sequence_length
514
+
515
+
516
class TestLlamaModel(unittest.TestCase):
    """Tests for the llama-based TTS model: construction, batched input
    preparation / forward shapes, and checkpoint sanitization."""

    @property
    def _default_config(self):
        # Small llama-3.2-style configuration shared by the tests below;
        # rope_scaling follows the llama3 long-context scheme.
        return {
            "attention_bias": False,
            "head_dim": 128,
            "hidden_size": 3072,
            "intermediate_size": 8192,
            "max_position_embeddings": 131072,
            "mlp_bias": False,
            "model_type": "llama",
            "num_attention_heads": 24,
            "num_hidden_layers": 28,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_scaling": {
                "factor": 32.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
            "rope_theta": 500000.0,
            "tie_word_embeddings": True,
            "vocab_size": 156940,
        }

    @patch("transformers.LlamaTokenizer")
    def test_init(self, mock_tokenizer):
        """Test LlamaModel initialization."""
        from mlx_audio.tts.models.llama.llama import Model, ModelConfig

        # Mock the tokenizer instance so no tokenizer files are needed.
        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.return_value = mock_tokenizer_instance

        # Create a minimal config
        config = ModelConfig(**self._default_config)

        # Initialize model
        model = Model(config)

        # Check that model was created
        self.assertIsInstance(model, Model)

    @patch("transformers.LlamaTokenizer")
    def test_generate(self, mock_tokenizer):
        """Test generate method."""
        from mlx_audio.tts.models.llama.llama import Model, ModelConfig

        # Mock tokenizer instance
        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.return_value = mock_tokenizer_instance

        config = ModelConfig(**self._default_config)
        model = Model(config)

        # Verify batched input creation with a voice: batch of two prompts
        # yields batch-2 ids and mask.
        input_ids, input_mask = model.prepare_input_ids(["Foo", "Bar Baz"], voice="zoe")
        self.assertEqual(input_ids.shape[0], 2)
        self.assertEqual(input_mask.shape[0], 2)

        # Forward pass shape: sequence length 9 here is the padded prompt
        # length produced by prepare_input_ids for these inputs.
        logits = model(input_ids)
        self.assertEqual(logits.shape, (2, 9, config.vocab_size))

        # Verify batched input creation with reference audio (voice cloning
        # path); the longer sequence (22) includes the reference tokens.
        input_ids, input_mask = model.prepare_input_ids(
            ["Foo", "Bar Baz"], ref_audio=mx.zeros((100,)), ref_text="Caption"
        )
        self.assertEqual(input_ids.shape[0], 2)
        self.assertEqual(input_mask.shape[0], 2)

        logits = model(input_ids)
        self.assertEqual(logits.shape, (2, 22, config.vocab_size))

    @patch("transformers.LlamaTokenizer")
    def test_sanitize(self, mock_tokenizer):
        """Test sanitize method."""
        from mlx_audio.tts.models.llama.llama import Model, ModelConfig

        # Mock tokenizer instance
        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.return_value = mock_tokenizer_instance

        # Create a config with tie_word_embeddings=True
        config = ModelConfig(
            model_type="llama",
            hidden_size=4096,
            num_hidden_layers=32,
            intermediate_size=16384,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=32000,
            head_dim=128,
            max_position_embeddings=1024,
            num_key_value_heads=32,
            attention_bias=True,
            mlp_bias=True,
            rope_theta=500000.0,
            rope_traditional=False,
            rope_scaling=None,
            tie_word_embeddings=True,
        )

        # Initialize the model with a patched __init__ so no weights or
        # tokenizer are required.
        with patch.object(Model, "__init__", return_value=None):
            model = Model.__new__(Model)
            model.config = config

            # NOTE(review): this replaces model.sanitize with a local
            # reimplementation, so the real sanitize is never exercised —
            # the test verifies the mock's own filtering rules.
            def mock_sanitize(weights):
                result = {}
                for k, v in weights.items():
                    # Rotary embedding buffers are recomputed, not loaded.
                    if "rotary_emb" in k:
                        continue
                    # With tied embeddings, lm_head shares the embedding
                    # weights and is dropped from the checkpoint.
                    if "lm_head.weight" in k and config.tie_word_embeddings:
                        continue
                    result[k] = v
                return result

            model.sanitize = mock_sanitize

            # Create test weights with rotary embeddings and lm_head
            weights = {
                "self_attn.rotary_emb.inv_freq": mx.zeros(10),
                "lm_head.weight": mx.zeros((32000, 4096)),
                "model.layers.0.input_layernorm.weight": mx.zeros(4096),
            }

            # Test sanitize method
            sanitized = model.sanitize(weights)

            # Assert rotary embeddings are removed
            self.assertNotIn("self_attn.rotary_emb.inv_freq", sanitized)

            # Assert lm_head weights are removed with tie_word_embeddings=True
            self.assertNotIn("lm_head.weight", sanitized)

            # Assert other weights remain
            self.assertIn("model.layers.0.input_layernorm.weight", sanitized)

            # Now test with tie_word_embeddings=False; the closure reads
            # config at call time, so flipping the flag changes the result.
            config.tie_word_embeddings = False

            # Test sanitize again
            sanitized2 = model.sanitize(weights)

            # lm_head should be kept with tie_word_embeddings=False
            self.assertIn("lm_head.weight", sanitized2)
665
+
666
+
667
class TestOuteTTSModel(unittest.TestCase):
    """Tests for the OuteTTS variant of the llama TTS model."""

    @property
    def _default_config(self):
        # Small llama-style configuration; rope_scaling follows the
        # llama3 long-context scheme.
        return {
            "model_type": "llama",
            "vocab_size": 134400,
            "hidden_size": 2048,
            "intermediate_size": 8192,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "head_dim": 64,
            "max_position_embeddings": 131072,
            "rms_norm_eps": 1e-05,
            "rope_theta": 500000.0,
            "rope_scaling": {
                "factor": 32.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
            "attention_bias": False,
            "mlp_bias": False,
            "tie_word_embeddings": True,
        }

    @patch("transformers.LlamaTokenizer")
    def test_init(self, mock_tokenizer):
        """Test initialization."""
        from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig

        # No tokenizer files needed — the class is patched with a mock.
        mock_tokenizer.return_value = MagicMock()

        tts_model = Model(ModelConfig(**self._default_config))
        self.assertIsInstance(tts_model, Model)

    @patch("transformers.LlamaTokenizer")
    def test_generate(self, mock_tokenizer):
        """Test generate method."""
        from mlx_audio.tts.models.outetts.outetts import Model, ModelConfig

        mock_tokenizer.return_value = MagicMock()

        cfg = ModelConfig(**self._default_config)
        tts_model = Model(cfg)

        # Forward a random batch (2 sequences of 9 tokens) and check the
        # logits come back as (batch, seq_len, vocab).
        token_ids = mx.random.randint(0, cfg.vocab_size, (2, 9))
        logits = tts_model(token_ids)
        self.assertEqual(logits.shape, (2, 9, cfg.vocab_size))
727
+
728
+
729
class TestDiaModel(unittest.TestCase):
    """Tests for the Dia TTS model."""

    @property
    def _default_config(self):
        """Nested Dia config dict (encoder/decoder/data sections)."""
        return {
            "version": "0.1",
            "model": {
                "encoder": {
                    "n_layer": 12,
                    "n_embd": 1024,
                    "n_hidden": 4096,
                    "n_head": 16,
                    "head_dim": 128,
                },
                "decoder": {
                    "n_layer": 18,
                    "n_embd": 2048,
                    "n_hidden": 8192,
                    "gqa_query_heads": 16,
                    "cross_query_heads": 16,
                    "kv_heads": 4,
                    "gqa_head_dim": 128,
                    "cross_head_dim": 128,
                },
                "src_vocab_size": 256,
                "tgt_vocab_size": 1028,
                "dropout": 0.0,
            },
            "training": {},
            "data": {
                "text_length": 1024,
                "audio_length": 3072,
                "channels": 9,
                "text_pad_value": 0,
                "audio_eos_value": 1024,
                "audio_pad_value": 1025,
                "audio_bos_value": 1026,
                "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
            },
        }

    def test_init(self):
        """Model(config) constructs from the default nested config dict."""
        from mlx_audio.tts.models.dia.dia import Model

        dia_model = Model(self._default_config)

        self.assertIsInstance(dia_model, Model)
779
+
780
+
781
class TestSparkTTSModel(unittest.TestCase):
    """Tests for the Spark TTS model (Qwen2 backbone + BiCodec tokenizer)."""

    @property
    def _default_config(self):
        """Qwen2-style Spark config dict for ModelConfig."""
        return {
            "model_path": "/fake/model/path",
            "sample_rate": 16000,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 896,
            "initializer_range": 0.02,
            "intermediate_size": 4864,
            "max_position_embeddings": 32768,
            "max_window_layers": 21,
            "model_type": "qwen2",
            "num_attention_heads": 14,
            "num_hidden_layers": 24,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": True,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.43.1",
            "use_sliding_window": False,
            "vocab_size": 166000,
            "rope_traditional": False,
            "rope_scaling": None,
        }

    @patch("mlx_audio.tts.models.spark.spark.load_tokenizer")
    @patch("mlx_audio.tts.models.spark.spark.BiCodecTokenizer")
    @patch("mlx_audio.tts.models.spark.spark.Qwen2Model")
    def test_init(
        self,
        mock_qwen2_model,
        mock_bicodec_tokenizer,
        mock_load_tokenizer,
    ):
        """Model(config) wires up tokenizers and backbone exactly once each."""
        from pathlib import Path

        from mlx_audio.tts.models.spark.spark import Model, ModelConfig

        # All three collaborators are patched out; each returns a stub.
        for patched in (mock_load_tokenizer, mock_bicodec_tokenizer, mock_qwen2_model):
            patched.return_value = MagicMock()

        config = ModelConfig(**self._default_config)
        config.model_path = Path("/fake/model/path")

        spark_model = Model(config)

        self.assertIsInstance(spark_model, Model)

        # The text tokenizer is loaded from the model path with the EOS id.
        mock_load_tokenizer.assert_called_once_with(
            config.model_path, eos_token_ids=config.eos_token_id
        )
        # The audio codec tokenizer is built from the same path.
        mock_bicodec_tokenizer.assert_called_once_with(config.model_path)

        # The Qwen2 backbone receives the full config.
        mock_qwen2_model.assert_called_once_with(config)
848
+
849
+
850
class TestIndexTTS(unittest.TestCase):
    """Tests for the IndexTTS model (GPT + BigVGAN vocoder + DVAE pipeline)."""

    @property
    def _default_config(self):
        """Full IndexTTS config dict.

        Mirrors the shipped checkpoint configuration: BigVGAN vocoder
        hyper-parameters, dataset/mel settings, the conformer-conditioned
        GPT, and the VQ-VAE mel codebook. Returned fresh on every access
        so tests cannot leak mutations into each other.
        """
        return {
            "tokenizer_name": "mlx-community/IndexTTS",
            # BigVGAN vocoder (generator + discriminator/training fields).
            "bigvgan": {
                "adam_b1": 0.8,
                "adam_b2": 0.99,
                "lr_decay": 0.999998,
                "seed": 1234,
                "resblock": "1",
                "upsample_rates": [4, 4, 4, 4, 2, 2],
                "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
                "upsample_initial_channel": 1536,
                "resblock_kernel_sizes": [3, 7, 11],
                "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "feat_upsample": False,
                "speaker_embedding_dim": 512,
                "cond_d_vector_in_each_upsampling_layer": True,
                "gpt_dim": 1024,
                "activation": "snakebeta",
                "snake_logscale": True,
                "use_cqtd_instead_of_mrd": True,
                "cqtd_filters": 128,
                "cqtd_max_filters": 1024,
                "cqtd_filters_scale": 1,
                "cqtd_dilations": [1, 2, 4],
                "cqtd_hop_lengths": [512, 256, 256],
                "cqtd_n_octaves": [9, 9, 9],
                "cqtd_bins_per_octaves": [24, 36, 48],
                "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
                "mpd_reshapes": [2, 3, 5, 7, 11],
                "use_spectral_norm": False,
                "discriminator_channel_mult": 1,
                "use_multiscale_melloss": True,
                "lambda_melloss": 15,
                "clip_grad_norm": 1000,
                "segment_size": 16384,
                "num_mels": 100,
                "num_freq": 1025,
                "n_fft": 1024,
                "hop_size": 256,
                "win_size": 1024,
                "sampling_rate": 24000,
                "fmin": 0,
                "fmax": None,
                "fmax_for_loss": None,
                "mel_type": "pytorch",
                "num_workers": 2,
                "dist_config": {
                    "dist_backend": "nccl",
                    "dist_url": "tcp://localhost:54321",
                    "world_size": 1,
                },
            },
            "bigvgan_checkpoint": "bigvgan_generator.pth",
            # Dataset / mel-spectrogram extraction settings.
            "dataset": {
                "bpe_model": "checkpoints/bpe.model",
                "sample_rate": 24000,
                "squeeze": False,
                "mel": {
                    "sample_rate": 24000,
                    "n_fft": 1024,
                    "hop_length": 256,
                    "win_length": 1024,
                    "n_mels": 100,
                    "mel_fmin": 0,
                    "normalize": False,
                },
            },
            "dvae_checkpoint": "dvae.pth",
            # Autoregressive GPT over mel codes, conditioned via a
            # conformer-perceiver speaker/condition module.
            "gpt": {
                "model_dim": 1024,
                "max_mel_tokens": 605,
                "max_text_tokens": 402,
                "heads": 16,
                "use_mel_codes_as_input": True,
                "mel_length_compression": 1024,
                "layers": 20,
                "number_text_tokens": 12000,
                "number_mel_codes": 8194,
                "start_mel_token": 8192,
                "stop_mel_token": 8193,
                "start_text_token": 0,
                "stop_text_token": 1,
                "train_solo_embeddings": False,
                "condition_type": "conformer_perceiver",
                "condition_module": {
                    "output_size": 512,
                    "linear_units": 2048,
                    "attention_heads": 8,
                    "num_blocks": 6,
                    "input_layer": "conv2d2",
                    "perceiver_mult": 2,
                },
            },
            "gpt_checkpoint": "gpt.pth",
            # VQ-VAE that discretizes mel frames into codebook tokens.
            "vqvae": {
                "channels": 100,
                "num_tokens": 8192,
                "hidden_dim": 512,
                "num_resnet_blocks": 3,
                "codebook_dim": 512,
                "num_layers": 2,
                "positional_dims": 1,
                "kernel_size": 3,
                "smooth_l1_loss": True,
                "use_transposed_convs": False,
            },
        }

    def test_init(self):
        """Test IndexTTS initialization."""
        from mlx_audio.tts.models.indextts.indextts import Model

        # Initialize model from the full checkpoint-style config dict.
        config = self._default_config
        model = Model(config)  # type: ignore

        # Check that model was created
        self.assertIsInstance(model, Model)
971
+
972
+
973
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()