nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,633 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Dict, List, Optional, Tuple
7
+
8
+ import mlx.core as mx
9
+ import mlx.nn as nn
10
+ import numpy as np
11
+ import soundfile as sf
12
+ from huggingface_hub import hf_hub_download
13
+ from mlx_lm.models.cache import make_prompt_cache
14
+ from mlx_lm.models.llama import LlamaModel
15
+ from mlx_lm.models.llama import ModelArgs as LlamaModelArgs
16
+ from mlx_lm.sample_utils import make_sampler
17
+ from scipy import signal
18
+ from tokenizers.processors import TemplateProcessing
19
+ from tqdm import tqdm
20
+ from transformers import AutoTokenizer
21
+
22
+ from mlx_audio.codec.models.mimi import Mimi, MimiStreamingDecoder
23
+
24
+ from ..base import GenerationResult
25
+ from .attention import Attention
26
+
27
try:
    from .watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
except ImportError:
    import warnings

    # Watermarking is optional. When this import fails, the names above are
    # simply not defined. Emit a warning rather than printing to stdout, so a
    # library import does not write to the caller's stdout and users can
    # filter or silence the message.
    warnings.warn(
        "Watermarking module not found. Please install silentcipher to use watermarking."
    )

# Hugging Face Hub repositories for the Mimi audio codec and the Llama-3 text tokenizer.
MIMI_REPO = "kyutai/moshiko-pytorch-bf16"
TOKENIZER_REPO = "unsloth/Llama-3.2-1B"
36
+
37
+
38
def create_causal_mask(seq_len: int) -> mx.array:
    """Build a boolean lower-triangular mask of shape (seq_len, seq_len).

    Entry [q, k] is True when query position q may attend to key position k (k <= q).
    """
    all_true = mx.ones((seq_len, seq_len), dtype=mx.bool_)
    lower_triangle = mx.tril(all_true)
    return lower_triangle
40
+
41
+
42
def index_causal_mask(mask: mx.array, input_pos: mx.array) -> mx.array:
    """Select the causal-mask rows for the given query positions.

    Args:
        mask: Square boolean mask, e.g. from ``create_causal_mask``.
        input_pos: Integer positions indexed as (batch, seq_len).

    Returns:
        The gathered rows, restricted to the first ``seq_len`` key columns, with a
        singleton head axis inserted at axis 1 for broadcasting.
    """
    num_queries = input_pos.shape[1]
    # Gather one mask row per query position: (batch, seq_len, mask_size).
    rows = mx.take(mask, input_pos, axis=0)
    # Keep only the first num_queries key columns.
    # NOTE(review): with a populated KV cache this is narrower than the cached key
    # length; confirm how the local Attention combines this mask with cached keys.
    rows = rows[:, :, :num_queries]
    # Insert the head axis so the mask broadcasts across attention heads.
    return rows[:, None, :, :]
50
+
51
+
52
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` with polyphase filtering.

    The rate ratio is reduced by its greatest common divisor before calling
    ``scipy.signal.resample_poly``. The signal is edge-padded, and resampling runs
    along axis 0.
    """
    divisor = np.gcd(orig_sr, target_sr)
    upsample_factor, downsample_factor = target_sr // divisor, orig_sr // divisor
    return signal.resample_poly(
        audio, upsample_factor, downsample_factor, padtype="edge"
    )
58
+
59
+
60
@dataclass
class SesameModelArgs:
    """Top-level configuration for the Sesame (CSM) backbone/decoder pair.

    The explicit ``__init__`` takes precedence over the dataclass-generated one.
    It accepts arbitrary extra keyword arguments and discards them, so a full
    model config dict can be splatted in directly. The dataclass still supplies
    ``__repr__`` and ``__eq__`` over the six fields below.
    """

    model_type: str
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int

    def __init__(
        self,
        model_type,
        backbone_flavor,
        decoder_flavor,
        text_vocab_size,
        audio_vocab_size,
        audio_num_codebooks,
        **kwargs,
    ):
        # Unknown config keys land in **kwargs and are intentionally ignored.
        known = (
            ("model_type", model_type),
            ("backbone_flavor", backbone_flavor),
            ("decoder_flavor", decoder_flavor),
            ("text_vocab_size", text_vocab_size),
            ("audio_vocab_size", audio_vocab_size),
            ("audio_num_codebooks", audio_num_codebooks),
        )
        for attr_name, attr_value in known:
            setattr(self, attr_name, attr_value)
85
+
86
+
87
def create_llama_model_args(flavor: str) -> LlamaModelArgs:
    """Return the Llama architecture arguments for a named CSM flavor.

    Supported flavors are ``"llama-1B"`` and ``"llama-100M"``. They differ only in
    depth, head layout and width; every other hyper-parameter is shared.

    Raises:
        ValueError: If ``flavor`` is not one of the supported names.
    """
    if flavor == "llama-1B":
        shape = dict(
            num_hidden_layers=16,
            num_attention_heads=32,
            num_key_value_heads=8,
            head_dim=64,
            hidden_size=2048,
        )
    elif flavor == "llama-100M":
        shape = dict(
            num_hidden_layers=4,
            num_attention_heads=8,
            num_key_value_heads=2,
            head_dim=128,
            hidden_size=1024,
        )
    else:
        raise ValueError(f"Unknown flavor: {flavor}")

    # Settings common to both flavors. A fresh rope_scaling dict is built per call.
    return LlamaModelArgs(
        model_type="llama",
        intermediate_size=8192,
        rms_norm_eps=1e-5,
        vocab_size=128_256,
        max_position_embeddings=2048,
        attention_bias=False,
        mlp_bias=False,
        rope_theta=500_000,
        rope_scaling={
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        **shape,
    )
136
+
137
+
138
class SesameModel(nn.Module):
    """Backbone + decoder transformer pair used by CSM.

    For each position, the backbone consumes the masked sum of the per-frame
    embeddings (audio codebook columns plus one text column) and predicts audio
    codebook 0. The decoder then predicts codebooks 1..K-1 of that same frame,
    one at a time.
    """

    def __init__(self, config):
        super().__init__()
        # Extra keys in ``config`` are ignored by SesameModelArgs.
        args = SesameModelArgs(**config)
        self.args = args

        backbone_args = create_llama_model_args(args.backbone_flavor)
        decoder_args = create_llama_model_args(args.decoder_flavor)

        self.backbone = LlamaModel(backbone_args)
        self.decoder = LlamaModel(decoder_args)

        backbone_dim = backbone_args.hidden_size
        decoder_dim = decoder_args.hidden_size

        # Token embedding layers are replaced with identities. Inputs are embedded
        # by this class (text/audio tables below) before entering backbone/decoder.
        self.backbone.embed_tokens = nn.Identity()
        self.decoder.embed_tokens = nn.Identity()

        # Swap every layer's attention for the local Attention implementation.
        for layer in self.backbone.layers:
            layer.self_attn = Attention(backbone_args)
        for layer in self.decoder.layers:
            layer.self_attn = Attention(decoder_args)

        self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
        # One shared table for all codebooks: codebook c uses rows
        # [c * audio_vocab_size, (c + 1) * audio_vocab_size) (see _embed_audio).
        self.audio_embeddings = nn.Embedding(
            args.audio_vocab_size * args.audio_num_codebooks, backbone_dim
        )

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
        # Output projections for codebooks 1..K-1, indexed by codebook - 1.
        # Initialized to zeros; presumably overwritten when weights are loaded.
        self.audio_head = mx.zeros(
            (args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size)
        )

        self._backbone_causal_mask = None
        self._decoder_causal_mask = None

        self.backbone_cache = None
        self.decoder_cache = None
        self.caches_enabled = False

    def setup_caches(self, max_batch_size: int):
        """Build the causal masks and fresh KV caches. Required before generate_frame.

        ``max_batch_size`` is accepted but not used by this implementation.
        """
        backbone_args = create_llama_model_args(self.args.backbone_flavor)

        # Backbone mask covers every position the backbone supports.
        self._backbone_causal_mask = create_causal_mask(
            backbone_args.max_position_embeddings
        )
        # Decoder mask only needs one row per codebook within a single frame.
        self._decoder_causal_mask = create_causal_mask(self.args.audio_num_codebooks)

        self.backbone_cache = make_prompt_cache(self.backbone)
        self.decoder_cache = make_prompt_cache(self.decoder)
        self.caches_enabled = True

    def caches_are_enabled(self):
        """Return True once setup_caches has been called."""
        return self.caches_enabled

    def reset_caches(self):
        """Replace existing KV caches with empty ones; caches never set up stay None."""
        if self.backbone_cache is not None:
            self.backbone_cache = make_prompt_cache(self.backbone)

        if self.decoder_cache is not None:
            self.decoder_cache = make_prompt_cache(self.decoder)

    def generate_frame(
        self,
        tokens: mx.array,
        tokens_mask: mx.array,
        input_pos: mx.array,
        sampler: Callable[..., mx.array],
    ) -> mx.array:
        """Sample one audio frame (all codebooks) following the given inputs.

        Args:
            tokens: Frame token ids indexed (batch, seq, K + 1). Audio codebook ids
                are in the first K columns and the text id is in the last column.
            tokens_mask: Boolean mask with the same layout. It selects which
                embeddings are summed into each position's input.
            input_pos: Absolute positions of ``tokens``, indexed (batch, seq).
            sampler: Maps logits to sampled token ids.

        Returns:
            Sampled ids for codebooks 0..K-1, concatenated along axis 1
            (presumably (batch, K) — confirm the sampler's output shape).

        Side effects:
            Appends to ``backbone_cache`` and replaces ``decoder_cache``.
        """
        assert self.caches_are_enabled(), "backbone caches are not enabled"

        curr_backbone_mask = index_causal_mask(self._backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        # Zero out unused columns, then sum the column embeddings per position.
        masked_embeds = embeds * mx.expand_dims(tokens_mask, -1)
        h = mx.sum(masked_embeds, axis=2)
        h = self.backbone(h, mask=curr_backbone_mask, cache=self.backbone_cache)

        # Codebook 0 comes from the backbone's last hidden state.
        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = mx.expand_dims(sampler(c0_logits), axis=-1)
        c0_embed = self._embed_audio(0, c0_sample)

        # Decoder input starts as [backbone hidden state, codebook-0 embedding].
        curr_h = mx.concat([mx.expand_dims(last_h, 1), c0_embed], axis=1)
        curr_sample = c0_sample
        curr_pos = mx.arange(curr_h.shape[1], dtype=mx.int32)
        curr_pos = mx.expand_dims(curr_pos, 0)
        curr_pos = mx.broadcast_to(curr_pos, (curr_h.shape[0], curr_h.shape[1]))

        # reset decoder cache for new frame

        self.decoder_cache = make_prompt_cache(self.decoder)

        # Autoregressively predict codebooks 1..K-1. After the first step, only the
        # newest codebook embedding is fed in (the rest lives in decoder_cache).
        for i in range(1, self.args.audio_num_codebooks):
            curr_decoder_mask = index_causal_mask(self._decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h),
                mask=curr_decoder_mask,
                cache=self.decoder_cache,
            )

            ci_logits = mx.matmul(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = mx.expand_dims(sampler(ci_logits), axis=-1)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = mx.concat([curr_sample, ci_sample], axis=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def _embed_audio(self, codebook: int, tokens: mx.array) -> mx.array:
        """Embed ids of a single codebook by offsetting into the shared audio table."""
        return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)

    def _embed_tokens(self, tokens: mx.array) -> mx.array:
        """Embed every column of a (batch, seq, K + 1) frame tensor.

        Returns:
            Embeddings indexed (batch, seq, K + 1, dim). The K audio codebook
            embeddings come first and the text embedding is last.
        """
        text_embeds = self.text_embeddings(tokens[:, :, -1])
        text_embeds = mx.expand_dims(text_embeds, axis=-2)

        # Shift codebook c's ids by c * audio_vocab_size into the shared table.
        codebook_indices = mx.arange(self.args.audio_num_codebooks, dtype=mx.int32)
        codebook_offsets = codebook_indices * self.args.audio_vocab_size

        audio_tokens = tokens[:, :, :-1] + mx.reshape(codebook_offsets, (1, 1, -1))
        audio_embeds_flat = self.audio_embeddings(audio_tokens.flatten())

        audio_embeds = mx.reshape(
            audio_embeds_flat,
            (tokens.shape[0], tokens.shape[1], self.args.audio_num_codebooks, -1),
        )

        return mx.concat([audio_embeds, text_embeds], axis=-2)
268
+
269
+
270
@dataclass(init=True, repr=True, eq=True)
class Segment:
    """One conditioning turn: who spoke, what was said, and the matching audio.

    Attributes:
        speaker: Integer speaker id (rendered as ``[speaker]`` before the text).
        text: Transcript of ``audio``.
        audio: Mono waveform, expected as (num_samples,) at 24 kHz.
    """

    speaker: int
    text: str
    audio: mx.array
276
+
277
+
278
def load_llama3_tokenizer(path_or_hf_repo: str):
    """Load a Llama-3 tokenizer whose post-processor wraps input in BOS/EOS.

    Single sequences become ``BOS A EOS``. Pairs become ``BOS A EOS BOS B EOS``,
    with the second sequence assigned type id 1.
    """
    tokenizer = AutoTokenizer.from_pretrained(path_or_hf_repo)
    bos, eos = tokenizer.bos_token, tokenizer.eos_token

    single_template = f"{bos}:0 $A:0 {eos}:0"
    pair_template = f"{single_template} {bos}:1 $B:1 {eos}:1"
    special = [
        (f"{bos}", tokenizer.bos_token_id),
        (f"{eos}", tokenizer.eos_token_id),
    ]

    # Replace the backend tokenizer's post-processor in place.
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=single_template,
        pair=pair_template,
        special_tokens=special,
    )
    return tokenizer
291
+
292
+
293
class Model(nn.Module):
    """Sesame CSM text-to-speech model.

    Wraps :class:`SesameModel` (backbone + decoder) together with the Llama-3 text
    tokenizer, the Mimi codec (used to encode reference audio and to decode
    generated codebook frames), and an optional watermarker.
    """

    def __init__(
        self,
        config: Dict,
    ):
        super().__init__()
        self.model = SesameModel(config)
        self.model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer(TOKENIZER_REPO)
        mimi = Mimi.from_pretrained(MIMI_REPO)
        self._audio_tokenizer = mimi
        self._streaming_decoder = MimiStreamingDecoder(mimi)

        # Best effort: if the optional watermarking module failed to import
        # (load_watermarker undefined) or loading fails, generate without it.
        try:
            self._watermarker = load_watermarker()
        except Exception:
            self._watermarker = None

        self._sample_rate = mimi.cfg.sample_rate

    def model_quant_predicate(self, p, m, config):
        """Quantization predicate.

        Returns False for modules whose path starts with ``_audio_tokenizer`` (the
        Mimi codec), so they are skipped. Returns True for every other module.
        """
        return not p.startswith("_audio_tokenizer")

    @property
    def layers(self):
        """Return the backbone layers of the model."""
        return self.model.backbone.layers

    @property
    def sample_rate(self):
        """Output sample rate in Hz, taken from the Mimi codec config."""
        return self._sample_rate

    def _frame_width(self) -> int:
        """Columns per input frame: one per audio codebook plus one text column."""
        return self.model.args.audio_num_codebooks + 1

    def _tokenize_text_segment(
        self, text: str, speaker: int
    ) -> Tuple[mx.array, mx.array]:
        """Tokenize ``[speaker]text`` into text-only frames.

        Returns:
            (tokens, mask), each (num_text_tokens, audio_num_codebooks + 1). The text
            ids sit in the last column; audio columns are zero and masked out.
        """
        width = self._frame_width()
        text_tokens = self._text_tokenizer.encode(
            f"[{speaker}]{text}", return_tensors="mlx"
        ).squeeze(0)
        num_tokens = len(text_tokens)

        text_frame = mx.zeros((num_tokens, width), dtype=mx.int32)
        text_frame_mask = mx.zeros((num_tokens, width), dtype=mx.bool_)
        text_frame[:, -1] = text_tokens
        text_frame_mask[:, -1] = True
        return text_frame, text_frame_mask

    def _tokenize_audio(self, audio: mx.array) -> Tuple[mx.array, mx.array]:
        """Encode a waveform with Mimi into audio-only frames, plus a trailing EOS frame.

        Returns:
            (tokens, mask), each (num_frames + 1, audio_num_codebooks + 1). Codebook
            ids fill all but the last column; the text column is zero and masked out.
        """
        width = self._frame_width()

        # (K, T) codebook ids for the single clip.
        audio_tokens = self._audio_tokenizer.encode(
            mx.expand_dims(mx.expand_dims(audio, 0), 0)
        )[0]

        # Append an all-zero frame. generate() treats an all-zero sampled frame as
        # end of audio. Keep the token dtype so ids are not promoted to float.
        eos_frame = mx.zeros((audio_tokens.shape[0], 1), dtype=audio_tokens.dtype)
        audio_tokens = mx.concat([audio_tokens, eos_frame], axis=1)

        num_frames = audio_tokens.shape[1]
        audio_frame = mx.zeros((num_frames, width), dtype=mx.int32)
        audio_frame_mask = mx.zeros((num_frames, width), dtype=mx.bool_)
        audio_frame[:, :-1] = audio_tokens.swapaxes(0, 1)
        audio_frame_mask[:, :-1] = True
        return audio_frame, audio_frame_mask

    def _tokenize_segment(self, segment: Segment) -> Tuple[mx.array, mx.array]:
        """Tokenize a context segment as its text frames followed by its audio frames.

        Returns:
            (tokens, mask), each (seq_len, audio_num_codebooks + 1).
        """
        text_tokens, text_masks = self._tokenize_text_segment(
            segment.text, segment.speaker
        )
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        tokens = mx.concat([text_tokens, audio_tokens], axis=0)
        masks = mx.concat([text_masks, audio_masks], axis=0)
        return tokens, masks

    def sanitize(self, weights):
        """Rename checkpoint keys to the module layout used here.

        Prefixes keys with ``model.`` and maps attention, MLP (w1/w2/w3) and norm
        parameter names to the Llama naming used by mlx_lm.
        """
        sanitized_weights = {}

        for k, v in weights.items():
            if not k.startswith("model."):
                k = "model." + k

            if "attn" in k and "self_attn" not in k:
                k = k.replace("attn", "self_attn")
                k = k.replace("output_proj", "o_proj")

            if "mlp" in k:
                k = k.replace("w1", "gate_proj")
                k = k.replace("w2", "down_proj")
                k = k.replace("w3", "up_proj")

            if "sa_norm" in k or "mlp_norm" in k:
                k = k.replace("sa_norm", "input_layernorm").replace("scale", "weight")
                k = k.replace("mlp_norm", "post_attention_layernorm").replace(
                    "scale", "weight"
                )

            if "decoder.norm" in k or "backbone.norm" in k:
                k = k.replace("scale", "weight")

            sanitized_weights[k] = v

        return sanitized_weights

    def prepare_prompt(
        self, text: str, speaker: int, audio_path: str, sample_rate: int
    ) -> Segment:
        """Load ``audio_path`` as a mono waveform at ``sample_rate`` and wrap it in a Segment."""
        audio, sr = sf.read(audio_path)
        # soundfile returns (frames, channels) for multi-channel files. Downmix to
        # mono, because _tokenize_audio feeds the clip to Mimi as a single channel.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        if sr != sample_rate:
            audio = resample_audio(audio, sr, sample_rate)
        return Segment(text=text, speaker=speaker, audio=mx.array(audio))

    def default_speaker_prompt(self, voice: str) -> List[Segment]:
        """Return the built-in conditioning prompt for ``voice``.

        The prompt audio is downloaded from the ``sesame/csm-1b`` Hub repo.

        Raises:
            ValueError: If ``voice`` is not a known built-in voice. This check runs
                before any download is attempted.
        """
        SPEAKER_PROMPTS = {
            "conversational_a": {
                "text": (
                    "like revising for an exam I'd have to try and like keep up the momentum because I'd "
                    "start really early I'd be like okay I'm gonna start revising now and then like "
                    "you're revising for ages and then I just like start losing steam I didn't do that "
                    "for the exam we had recently to be fair that was a more of a last minute scenario "
                    "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
                    "sort of start the day with this not like a panic but like a"
                ),
            },
            "conversational_b": {
                "text": (
                    "like a super Mario level. Like it's very like high detail. And like, once you get "
                    "into the park, it just like, everything looks like a computer game and they have all "
                    "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
                    "will have like a question block. And if you like, you know, punch it, a coin will "
                    "come out. So like everyone, when they come into the park, they get like this little "
                    "bracelet and then you can go punching question blocks around."
                ),
            },
        }

        if voice not in SPEAKER_PROMPTS:
            raise ValueError(
                f"Unknown voice {voice!r}; expected one of {sorted(SPEAKER_PROMPTS)}"
            )

        prompt_path = hf_hub_download(
            repo_id="sesame/csm-1b", filename=f"prompts/{voice}.wav"
        )
        prompt = self.prepare_prompt(
            SPEAKER_PROMPTS[voice]["text"], 0, prompt_path, 24_000
        )
        return [prompt]

    def generate_result(
        self, samples, start_time: float, stream: bool = False
    ) -> GenerationResult:
        """Decode sampled frames to audio and package timing/throughput statistics.

        Args:
            samples: Sampled frames from generate_frame, one per step (each (1, K)).
            start_time: ``time.perf_counter()`` value at which this chunk started.
            stream: Decode with the stateful streaming decoder instead of a one-shot decode.
        """
        token_count = len(samples)
        # (T, 1, K) -> (1, K, T), the layout passed to the Mimi decoders.
        transposed = mx.transpose(mx.stack(samples), axes=[1, 2, 0])
        if stream:
            audio = (
                self._streaming_decoder.decode_frames(transposed).squeeze(0).squeeze(0)
            )
        else:
            audio = self._audio_tokenizer.decode(transposed).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        if self._watermarker is not None:
            audio = watermark(
                self._watermarker,
                audio,
                self._sample_rate,
                CSM_1B_GH_WATERMARK,
            )
            audio = mx.array(audio, dtype=mx.float32)

        mx.eval(audio)

        segment_time = time.perf_counter() - start_time

        num_samples = audio.shape[0] if audio is not None else 0
        assert num_samples > 0, "No audio generated"

        # Use the codec's actual rate (the same one passed to the watermarker).
        sample_rate = self._sample_rate
        audio_duration_seconds = num_samples / sample_rate

        # Real-time factor: processing time per second of generated audio.
        rtf = segment_time / audio_duration_seconds if audio_duration_seconds > 0 else 0

        # Format duration as HH:MM:SS.mmm (minutes wrap at 60 once past an hour).
        hours, remainder = divmod(audio_duration_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        duration_ms = int((audio_duration_seconds % 1) * 1000)
        duration_str = (
            f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{duration_ms:03d}"
        )

        return GenerationResult(
            audio=audio,
            samples=num_samples,
            sample_rate=sample_rate,
            segment_idx=0,
            token_count=token_count,
            audio_duration=duration_str,
            real_time_factor=round(rtf, 2),
            prompt={
                "tokens": token_count,
                "tokens-per-sec": (
                    round(token_count / segment_time, 2) if segment_time > 0 else 0
                ),
            },
            audio_samples={
                "samples": num_samples,
                "samples-per-sec": (
                    round(num_samples / segment_time, 2) if segment_time > 0 else 0
                ),
            },
            processing_time_seconds=segment_time,
            peak_memory_usage=mx.get_peak_memory() / 1e9,
        )

    def generate(
        self,
        text: List[str] | str,
        voice: Optional[str] = None,
        speaker: int = 0,
        context: Optional[List[Segment]] = None,
        split_pattern: Optional[str] = r"\n+",
        sampler: Optional[Callable[..., mx.array]] = None,
        max_audio_length_ms: float = 90_000,
        ref_audio: Optional[mx.array] = None,
        ref_text: Optional[str] = None,
        stream: bool = False,
        streaming_interval: float = 2.0,
        **kwargs,
    ):
        """Synthesize speech for ``text``, yielding ``GenerationResult`` objects.

        Args:
            text: A string (split into prompts with ``split_pattern``) or a list of prompts.
            voice: Built-in speaker prompt, used only when neither ``context`` nor
                ``ref_audio`` is given. Defaults to ``"conversational_a"``.
            speaker: Speaker id for the generated text (and for ``ref_audio``).
            context: Conditioning segments. When non-empty, they are used as-is.
            split_pattern: Regex used to split a string ``text``; falsy disables splitting.
            sampler: Token sampler. Defaults to ``make_sampler(temp=0.9, top_k=50)``.
            max_audio_length_ms: Upper bound on generated audio per prompt.
            ref_audio: Reference waveform. Together with ``ref_text``, it forms the
                context when ``context`` is empty.
            ref_text: Transcript of ``ref_audio``.
            stream: Yield partial results roughly every ``streaming_interval`` seconds.
            streaming_interval: Seconds of audio per streamed chunk.
            **kwargs: Accepted for API compatibility and ignored.

        Yields:
            One result per prompt, or several per prompt when streaming.

        Raises:
            ValueError: If the prompt leaves no room for ``max_audio_length_ms`` of
                audio within the backbone's position limit, or ``voice`` is unknown.
        """
        # A caller-supplied context always takes precedence. Previously it was
        # silently replaced by the default voice prompt whenever ref_audio was None.
        context = [] if context is None else list(context)
        if not context:
            if ref_audio is not None and ref_text is not None:
                context = [Segment(speaker=speaker, text=ref_text, audio=ref_audio)]
            elif ref_audio is None:
                if voice is None:
                    voice = "conversational_a"
                context = self.default_speaker_prompt(voice)

        sampler = sampler or make_sampler(temp=0.9, top_k=50)
        # Frame arithmetic assumes 12.5 codec frames per second (80 ms per frame).
        max_audio_frames = int(max_audio_length_ms / 80)
        streaming_interval_tokens = int(streaming_interval * 12.5)
        # Position budget of the backbone (its causal mask is built with this size).
        max_positions = create_llama_model_args(
            self.model.args.backbone_flavor
        ).max_position_embeddings

        if isinstance(text, str):
            text = re.split(split_pattern, text.strip()) if split_pattern else [text]

        for prompt in text:
            start_time = time.perf_counter()

            self.model.reset_caches()
            if stream:
                self._streaming_decoder.reset()

            # Prompt = all context segments followed by the text to speak.
            tokens, tokens_mask = [], []
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)

            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
                prompt, speaker
            )
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)

            prompt_tokens = mx.concat(tokens, axis=0).astype(mx.int32)
            prompt_tokens_mask = mx.concat(tokens_mask, axis=0).astype(mx.bool_)

            samples = []
            curr_tokens = mx.expand_dims(prompt_tokens, axis=0)
            curr_tokens_mask = mx.expand_dims(prompt_tokens_mask, axis=0)
            curr_pos = mx.expand_dims(
                mx.arange(0, prompt_tokens.shape[0]), axis=0
            ).astype(mx.int32)
            generated_frame_count = 0
            yielded_frame_count = 0

            max_seq_len = max_positions - max_audio_frames
            if curr_tokens.shape[1] >= max_seq_len:
                raise ValueError(
                    f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}"
                )

            for _ in tqdm(range(max_audio_frames)):
                sample = self.model.generate_frame(
                    curr_tokens, curr_tokens_mask, curr_pos, sampler
                )
                if mx.all(sample == 0):
                    break  # eos: an all-zero frame ends the utterance

                samples.append(sample)

                # Next backbone input: the sampled audio frame plus an empty,
                # masked-out text column.
                curr_tokens = mx.expand_dims(
                    mx.concat([sample, mx.zeros((1, 1), dtype=mx.int32)], axis=1),
                    axis=1,
                )
                curr_tokens_mask = mx.expand_dims(
                    mx.concat(
                        [
                            mx.ones_like(sample).astype(mx.bool_),
                            mx.zeros((1, 1), dtype=mx.bool_),
                        ],
                        axis=1,
                    ),
                    axis=1,
                )
                curr_pos = curr_pos[:, -1:] + 1
                generated_frame_count += 1

                # send a partial result in streaming mode
                if (
                    stream
                    and (generated_frame_count - yielded_frame_count)
                    >= streaming_interval_tokens
                ):
                    yielded_frame_count = generated_frame_count
                    yield self.generate_result(samples, start_time, stream=True)
                    samples = []
                    start_time = time.perf_counter()

            if len(samples) > 0:
                yield self.generate_result(samples, start_time, stream=stream)

            # Clear cache after each segment to avoid memory leaks
            mx.clear_cache()