nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,870 @@
1
+ from typing import Any, List, Optional, Tuple
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ from einops.array_api import repeat
6
+
7
+ from .config import DiaConfig
8
+
9
+
10
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
11
+ return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
12
+
13
+
14
+ def _str_to_dtype(dtype_str: str):
15
+ # Allow None for default behavior
16
+ if dtype_str is None or dtype_str.lower() == "none":
17
+ return None
18
+ if dtype_str == "float32":
19
+ return mx.float32
20
+ elif dtype_str == "float16":
21
+ return mx.float16
22
+ elif dtype_str == "bfloat16":
23
+ return mx.bfloat16
24
+ else:
25
+ raise ValueError(f"Unsupported dtype string: {dtype_str}")
26
+
27
+
28
+ class DenseGeneral(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_shapes: Tuple[int, ...],
32
+ out_features: Tuple[int, ...],
33
+ axis: Tuple[int, ...] = (-1,),
34
+ dtype: Optional[mx.Dtype] = None,
35
+ weight_dtype: Optional[mx.Dtype] = None,
36
+ ):
37
+ super().__init__()
38
+ self.in_shapes = in_shapes
39
+ self.out_features = out_features
40
+ self.axis = axis
41
+ self.dtype = dtype
42
+ self.kernel_shape = self.in_shapes + self.out_features
43
+
44
+ weight_type = weight_dtype if weight_dtype is not None else dtype
45
+ self.weight = mx.zeros(self.kernel_shape, dtype=weight_type)
46
+
47
+ def __call__(self, inputs: mx.array) -> mx.array:
48
+ norm_axis = _normalize_axes(self.axis, inputs.ndim)
49
+ kernel_contract_axes = tuple(range(len(norm_axis)))
50
+
51
+ output = mx.tensordot(
52
+ inputs,
53
+ self.weight,
54
+ axes=(norm_axis, kernel_contract_axes),
55
+ )
56
+
57
+ if self.dtype is not None and output.dtype != self.dtype:
58
+ output = output.astype(self.dtype)
59
+
60
+ return output
61
+
62
+
63
+ def get_activation_fn(activation_string: str) -> nn.Module:
64
+ if activation_string == "gelu":
65
+ return nn.GELU()
66
+ elif activation_string == "relu":
67
+ return nn.ReLU()
68
+ elif activation_string == "silu" or activation_string == "swish":
69
+ return nn.SiLU()
70
+ elif activation_string == "linear":
71
+ return nn.Identity()
72
+ else:
73
+ raise ValueError(f"Unsupported activation function: {activation_string}")
74
+
75
+
76
+ class MlpBlock(nn.Module):
77
+ def __init__(
78
+ self,
79
+ config: DiaConfig,
80
+ embed_dim: int,
81
+ intermediate_dim: int,
82
+ dropout_rate: float,
83
+ activations: List[str] = ["silu", "linear"],
84
+ use_pre_norm: bool = False,
85
+ ):
86
+ super().__init__()
87
+ self.use_pre_norm = use_pre_norm
88
+ num_activations = len(activations)
89
+
90
+ compute_dtype = _str_to_dtype(config.training.dtype)
91
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
92
+ self.dtype = compute_dtype
93
+
94
+ if use_pre_norm:
95
+ self.pre_norm = nn.RMSNorm(
96
+ embed_dim,
97
+ eps=config.model.normalization_layer_epsilon,
98
+ )
99
+
100
+ self.wi_fused = DenseGeneral(
101
+ in_shapes=(embed_dim,),
102
+ out_features=(
103
+ num_activations,
104
+ intermediate_dim,
105
+ ),
106
+ axis=(-1,),
107
+ dtype=compute_dtype,
108
+ weight_dtype=weight_dtype,
109
+ )
110
+
111
+ self.activation_fn_0 = get_activation_fn(activations[0]) # silu
112
+ self.activation_fn_1 = get_activation_fn(activations[1]) # linear
113
+
114
+ self.dropout = nn.Dropout(dropout_rate)
115
+
116
+ self.wo = DenseGeneral(
117
+ in_shapes=(intermediate_dim,),
118
+ out_features=(embed_dim,),
119
+ axis=(-1,),
120
+ dtype=compute_dtype,
121
+ weight_dtype=weight_dtype,
122
+ )
123
+
124
+ def __call__(self, x: mx.array, deterministic: bool = False) -> mx.array:
125
+ if self.use_pre_norm and hasattr(self, "pre_norm"):
126
+ x = self.pre_norm(x)
127
+
128
+ fused_x = self.wi_fused(x)
129
+
130
+ gate_input = fused_x[..., 0, :]
131
+ up_input = fused_x[..., 1, :]
132
+
133
+ gate = self.activation_fn_0(gate_input)
134
+ up = self.activation_fn_1(up_input)
135
+ hidden = mx.multiply(gate, up)
136
+
137
+ if self.dtype is not None and self.dtype != hidden.dtype:
138
+ hidden = hidden.astype(self.dtype)
139
+
140
+ if not deterministic:
141
+ hidden = self.dropout(hidden)
142
+
143
+ output = self.wo(hidden)
144
+ return output
145
+
146
+
147
+ class RotaryEmbedding(nn.Module):
148
+ def __init__(
149
+ self,
150
+ embedding_dims: int,
151
+ min_timescale: int = 1,
152
+ max_timescale: int = 10000,
153
+ dtype: mx.Dtype = mx.float32,
154
+ ):
155
+ super().__init__()
156
+ if embedding_dims % 2 != 0:
157
+ raise ValueError("Embedding dim must be even for RoPE.")
158
+ self.embedding_dims = embedding_dims
159
+ self.min_timescale = min_timescale
160
+ self.max_timescale = max_timescale
161
+ self.dtype = dtype
162
+ half_embedding_dim = embedding_dims // 2
163
+ fraction = (2.0 * mx.arange(half_embedding_dim)) / embedding_dims
164
+
165
+ self._timescale = (
166
+ self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
167
+ )
168
+
169
+ def __call__(self, inputs: mx.array, position: mx.array):
170
+ """Applies RoPE."""
171
+ position = mx.expand_dims(mx.expand_dims(position, -1), -1)
172
+
173
+ sinusoid_inp = position / self._timescale
174
+
175
+ sin = mx.sin(sinusoid_inp).astype(inputs.dtype)
176
+ cos = mx.cos(sinusoid_inp).astype(inputs.dtype)
177
+
178
+ first_half = inputs[..., : self.embedding_dims // 2]
179
+ second_half = inputs[..., self.embedding_dims // 2 :]
180
+
181
+ first_part = first_half * cos - second_half * sin
182
+ second_part = second_half * cos + first_half * sin
183
+
184
+ return mx.concatenate([first_part, second_part], axis=-1)
185
+
186
+
187
+ class KVCache:
188
+ def __init__(self, num_heads, max_len, head_dim, k=None, v=None):
189
+ self.k = mx.zeros((2, num_heads, max_len, head_dim)) if k is None else k
190
+ self.v = mx.zeros((2, num_heads, max_len, head_dim)) if v is None else v
191
+ self.current_idx = 0
192
+ self.max_len = max_len
193
+
194
+ def update_and_fetch(self, k, v):
195
+ assert self.current_idx < self.max_len
196
+ self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
197
+ self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
198
+ self.current_idx += 1
199
+ return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
200
+
201
+ def prefill_kv(self, k, v):
202
+ prefill_len = k.shape[2]
203
+ assert prefill_len <= self.max_len
204
+ self.k[:, :, :prefill_len, :] = k
205
+ self.v[:, :, :prefill_len, :] = v
206
+ self.current_idx = prefill_len
207
+
208
+
209
+ class Attention(nn.Module):
210
+ def __init__(
211
+ self,
212
+ config: DiaConfig,
213
+ q_embed_dim: int,
214
+ kv_embed_dim: int,
215
+ num_query_heads: int,
216
+ num_kv_heads: int,
217
+ head_dim: int,
218
+ dropout_rate: float,
219
+ is_cross_attn: bool = False,
220
+ out_embed_dim: Optional[int] = None,
221
+ ):
222
+ super().__init__()
223
+ self.num_query_heads = num_query_heads
224
+ self.num_kv_heads = num_kv_heads
225
+ self.head_dim = head_dim
226
+ self.is_cross_attn = is_cross_attn
227
+ self.dropout_rate = dropout_rate
228
+
229
+ compute_dtype = _str_to_dtype(config.training.dtype)
230
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
231
+
232
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
233
+ self.projected_query_dim = num_query_heads * head_dim
234
+
235
+ if num_query_heads % num_kv_heads != 0:
236
+ raise ValueError(
237
+ f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
238
+ )
239
+
240
+ self.num_gqa_groups = num_query_heads // num_kv_heads
241
+
242
+ # --- Projection Layers using DenseGeneral ---
243
+ self.q_proj = DenseGeneral(
244
+ in_shapes=(q_embed_dim,),
245
+ out_features=(num_query_heads, head_dim),
246
+ axis=(-1,),
247
+ dtype=compute_dtype,
248
+ weight_dtype=weight_dtype,
249
+ )
250
+ self.k_proj = DenseGeneral(
251
+ in_shapes=(kv_embed_dim,),
252
+ out_features=(num_kv_heads, head_dim),
253
+ axis=(-1,),
254
+ dtype=compute_dtype,
255
+ weight_dtype=weight_dtype,
256
+ )
257
+ self.v_proj = DenseGeneral(
258
+ in_shapes=(kv_embed_dim,),
259
+ out_features=(num_kv_heads, head_dim),
260
+ axis=(-1,),
261
+ dtype=compute_dtype,
262
+ weight_dtype=weight_dtype,
263
+ )
264
+ self.o_proj = DenseGeneral(
265
+ in_shapes=(num_query_heads, head_dim),
266
+ out_features=(self.output_dim,),
267
+ axis=(-2, -1),
268
+ dtype=compute_dtype,
269
+ weight_dtype=weight_dtype,
270
+ )
271
+
272
+ # --- Rotary Embedding ---
273
+ self.rotary_emb = RotaryEmbedding(
274
+ embedding_dims=self.head_dim,
275
+ min_timescale=config.model.rope_min_timescale,
276
+ max_timescale=config.model.rope_max_timescale,
277
+ dtype=compute_dtype,
278
+ )
279
+
280
+ def __call__(
281
+ self,
282
+ Xq: mx.array, # (B, T, D) T = 1 in AR generation
283
+ Xkv: mx.array, # (B, S, E) S = 1 in AR generation
284
+ q_positions: mx.array, # (B, T)
285
+ kv_positions: Optional[mx.array] = None, # (B, S)
286
+ deterministic: bool = True,
287
+ attn_mask: Optional[
288
+ mx.array
289
+ ] = None, # None in Decoder Self Attention, Valid mask in Others
290
+ cache: Optional[KVCache] = None, # None in Encoder, KVCache in Decoder
291
+ prefill: bool = False, # True only when prefilling KV Cache
292
+ ) -> Tuple[mx.array, Optional[Tuple[mx.array, mx.array]]]:
293
+ """
294
+ Performs attention calculation with optional KV caching.
295
+
296
+ Args:
297
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
298
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
299
+ q_positions: Positions for queries (B, T).
300
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
301
+ deterministic: If True, disable dropout.
302
+ attn_mask: Attention mask.
303
+ cache: KVCache.
304
+ prefill: If True, use prefill mode.
305
+
306
+ Returns:
307
+ A tuple containing:
308
+ - output: The attention output tensor (B, T, output_dim).
309
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)).
310
+ For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
311
+ """
312
+ if kv_positions is None:
313
+ kv_positions = q_positions
314
+ original_dtype = Xq.dtype
315
+
316
+ Xq_BxTxNxH = self.q_proj(Xq)
317
+ Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
318
+ Xq_BxNxTxH = mx.transpose(Xq_BxTxNxH, (0, 2, 1, 3))
319
+
320
+ # Input values into attention calculation
321
+ attn_k = None
322
+ attn_v = None
323
+
324
+ # Decoder Cross Attention
325
+ if self.is_cross_attn:
326
+ # Directly use cache (no need to check index)
327
+ attn_k, attn_v = cache.k, cache.v
328
+ if (
329
+ attn_k.shape[1] != self.num_query_heads
330
+ or attn_v.shape[1] != self.num_query_heads
331
+ ):
332
+ raise ValueError(
333
+ f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
334
+ f"does not match num_query_heads ({self.num_query_heads}). "
335
+ "Cache should be pre-repeated for GQA."
336
+ )
337
+ # Self Attention
338
+ else:
339
+ Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
340
+ Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
341
+ Xk_BxSxKxH = self.rotary_emb(
342
+ Xk_BxSxKxH, position=kv_positions
343
+ ) # (B, S, K, H)
344
+
345
+ Xk_BxKxSxH = mx.transpose(Xk_BxSxKxH, (0, 2, 1, 3)) # (B, K, S, H)
346
+ Xv_BxKxSxH = mx.transpose(Xv_BxSxKxH, (0, 2, 1, 3)) # (B, K, S, H)
347
+ # S=1 for Decode Step
348
+
349
+ if self.num_gqa_groups > 1:
350
+ Xk_BxNxSxH = repeat(
351
+ Xk_BxKxSxH, "b k s h -> b (k g) s h", g=self.num_gqa_groups
352
+ )
353
+ Xv_BxNxSxH = repeat(
354
+ Xv_BxKxSxH, "b k s h -> b (k g) s h", g=self.num_gqa_groups
355
+ )
356
+ else:
357
+ Xk_BxNxSxH = Xk_BxKxSxH
358
+ Xv_BxNxSxH = Xv_BxKxSxH
359
+
360
+ # Encoder Self Attention
361
+ if cache is None:
362
+ attn_k = Xk_BxNxSxH
363
+ attn_v = Xv_BxNxSxH
364
+ # Decoder Self Attention
365
+ else:
366
+ # In prefill mode, we fill in cache until prefill length
367
+ if prefill:
368
+ attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
369
+ cache.prefill_kv(attn_k, attn_v)
370
+ # In decode step, we add current K/V to cache step by step
371
+ else:
372
+ attn_k, attn_v = cache.update_and_fetch(Xk_BxNxSxH, Xv_BxNxSxH)
373
+
374
+ # Attention Calculation
375
+ attn_scores = mx.matmul(Xq_BxNxTxH, attn_k.swapaxes(2, 3))
376
+
377
+ # Apply Scaling
378
+ scale_factor = 1.0
379
+ attn_scores = attn_scores * scale_factor
380
+
381
+ # Apply Attention Mask
382
+ if attn_mask is not None:
383
+ # Add large negative value where mask is False/0
384
+ attn_scores = mx.where(
385
+ attn_mask, attn_scores, -1e9
386
+ ) # Using -1e9 for numerical stability
387
+
388
+ attn_weights = mx.softmax(attn_scores, axis=-1)
389
+ attn_output = mx.matmul(attn_weights, attn_v)
390
+
391
+ attn_output = mx.transpose(attn_output, (0, 2, 1, 3)) # (B, T, N, H)
392
+ output = self.o_proj(attn_output)
393
+
394
+ if output.dtype != original_dtype:
395
+ output = output.astype(original_dtype)
396
+
397
+ return output
398
+
399
+
400
+ class EncoderLayer(nn.Module):
401
+ def __init__(self, config: DiaConfig):
402
+ super().__init__()
403
+ self.config = config
404
+ model_config = config.model
405
+ enc_config = config.model.encoder
406
+ embed_dim = enc_config.n_embd
407
+
408
+ self.pre_sa_norm = nn.RMSNorm(
409
+ embed_dim,
410
+ eps=model_config.normalization_layer_epsilon,
411
+ )
412
+
413
+ self.self_attention = Attention(
414
+ config=config,
415
+ q_embed_dim=embed_dim,
416
+ kv_embed_dim=embed_dim,
417
+ num_query_heads=enc_config.n_head,
418
+ num_kv_heads=enc_config.n_head,
419
+ head_dim=enc_config.head_dim,
420
+ dropout_rate=model_config.dropout,
421
+ is_cross_attn=False,
422
+ out_embed_dim=embed_dim,
423
+ )
424
+
425
+ self.post_sa_norm = nn.RMSNorm(
426
+ embed_dim,
427
+ eps=model_config.normalization_layer_epsilon,
428
+ )
429
+
430
+ self.mlp = MlpBlock(
431
+ config=config,
432
+ embed_dim=embed_dim,
433
+ intermediate_dim=enc_config.n_hidden,
434
+ activations=enc_config.mlp_activations,
435
+ dropout_rate=model_config.dropout,
436
+ use_pre_norm=enc_config.use_pre_norm,
437
+ )
438
+
439
+ self.dropout = nn.Dropout(model_config.dropout)
440
+
441
+ def __call__(
442
+ self,
443
+ x: mx.array,
444
+ src_positions: Optional[mx.array] = None,
445
+ deterministic: bool = True,
446
+ attn_mask: Optional[mx.array] = None,
447
+ ) -> mx.array:
448
+ residual = x
449
+ x_norm = self.pre_sa_norm(x)
450
+
451
+ sa_out = self.self_attention(
452
+ Xq=x_norm,
453
+ Xkv=x_norm,
454
+ q_positions=src_positions,
455
+ kv_positions=src_positions,
456
+ deterministic=deterministic,
457
+ attn_mask=attn_mask,
458
+ )
459
+ x = residual + sa_out
460
+
461
+ residual = x
462
+ x_norm = self.post_sa_norm(x)
463
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
464
+ x = residual + mlp_out
465
+
466
+ if not deterministic:
467
+ x = self.dropout(x)
468
+
469
+ return x
470
+
471
+
472
+ class Encoder(nn.Module):
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         enc_config = config.model.encoder
+ 
+         self.embedding = nn.Embedding(
+             model_config.src_vocab_size,
+             enc_config.n_embd,
+         )
+         self.dropout = nn.Dropout(model_config.dropout)
+         self.layers = [EncoderLayer(config=config) for _ in range(enc_config.n_layer)]
+         self.norm = nn.RMSNorm(
+             enc_config.n_embd,
+             eps=model_config.normalization_layer_epsilon,
+         )
+ 
+     def __call__(
+         self,
+         x_ids: mx.array,
+         src_positions: Optional[mx.array] = None,
+         deterministic: bool = True,
+         attn_mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         x = self.embedding(x_ids)
+ 
+         if not deterministic:
+             x = self.dropout(x)
+ 
+         for layer in self.layers:
+             x = layer(
+                 x,
+                 src_positions=src_positions,
+                 deterministic=deterministic,
+                 attn_mask=attn_mask,
+             )
+ 
+         x = self.norm(x)
+ 
+         if not deterministic:
+             x = self.dropout(x)
+ 
+         return x
+ 
+ 
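+ # A single Transformer decoder block: causal self-attention (grouped-query), cross-attention
+ # over the encoder output, and an MLP, each behind its own pre-norm residual branch.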
+ class DecoderLayer(nn.Module):
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         dec_config = config.model.decoder
+         enc_config = config.model.encoder
+         dec_embed_dim = dec_config.n_embd
+         enc_embed_dim = enc_config.n_embd
+ 
+         # Norms
+         self.pre_sa_norm = nn.RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+         )
+         self.pre_ca_norm = nn.RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+         )
+         self.pre_mlp_norm = nn.RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+         )
+ 
+         # Self-Attention (GQA) with Causal Masking
+         self.self_attention = Attention(
+             config=config,
+             q_embed_dim=dec_embed_dim,
+             kv_embed_dim=dec_embed_dim,
+             num_query_heads=dec_config.gqa_query_heads,
+             num_kv_heads=dec_config.kv_heads,
+             head_dim=dec_config.gqa_head_dim,
+             dropout_rate=model_config.dropout,
+             is_cross_attn=False,
+             out_embed_dim=dec_embed_dim,
+         )
+ 
+         # Cross-Attention (MHA)
+         self.cross_attention = Attention(
+             config=config,
+             q_embed_dim=dec_embed_dim,
+             kv_embed_dim=enc_embed_dim,  # Note kv_embed_dim
+             num_query_heads=dec_config.cross_query_heads,
+             num_kv_heads=dec_config.cross_query_heads,
+             head_dim=dec_config.cross_head_dim,
+             dropout_rate=model_config.dropout,
+             is_cross_attn=True,
+             out_embed_dim=dec_embed_dim,
+         )
+ 
+         # MLP
+         self.mlp = MlpBlock(
+             config=config,
+             embed_dim=dec_embed_dim,
+             intermediate_dim=dec_config.n_hidden,
+             activations=dec_config.mlp_activations,
+             dropout_rate=model_config.dropout,
+             use_pre_norm=dec_config.use_pre_norm,
+         )
+ 
+     def __call__(
+         self,
+         x: mx.array,
+         encoder_out: mx.array,
+         tgt_positions: mx.array,
+         src_positions: Optional[mx.array],
+         deterministic: bool,
+         self_attn_mask: Optional[mx.array],
+         cross_attn_mask: mx.array,
+         self_attn_cache: KVCache,
+         cross_attn_cache: KVCache,
+         prefill: bool = False,
+     ) -> mx.array:
+         # 1. Self-Attention
+         residual = x
+         x_norm = self.pre_sa_norm(x)
+ 
+         sa_out = self.self_attention(
+             Xq=x_norm,  # (2, 1, D)
+             Xkv=x_norm,  # (2, 1, D)
+             q_positions=tgt_positions,  # (2, 1)
+             kv_positions=tgt_positions,  # (2, 1)
+             deterministic=deterministic,
+             attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
+             cache=self_attn_cache,
+             prefill=prefill,
+         )
+         x = residual + sa_out
+ 
+         # 2. Cross-Attention
+         residual = x
+         x_norm = self.pre_ca_norm(x)
+         ca_out = self.cross_attention(
+             Xq=x_norm,
+             Xkv=encoder_out,
+             q_positions=tgt_positions,
+             kv_positions=src_positions,
+             deterministic=deterministic,
+             attn_mask=cross_attn_mask,
+             cache=cross_attn_cache,
+         )
+         x = residual + ca_out
+ 
+         # 3. MLP
+         residual = x
+         x_norm = self.pre_mlp_norm(x)
+         mlp_out = self.mlp(x_norm, deterministic=deterministic)
+         x = residual + mlp_out
+ 
+         return x
+ 
+ 
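+ # Audio-token decoder: per-channel embeddings summed into one hidden stream, a stack of
+ # DecoderLayer blocks, a final RMSNorm, and a DenseGeneral projection to per-channel logits.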
+ class Decoder(nn.Module):
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         dec_config = config.model.decoder
+         train_config = config.training
+         data_config = config.data
+         weight_dtype = _str_to_dtype(config.model.weight_dtype)
+         self.num_channels = data_config.channels
+         self.num_layers = dec_config.n_layer
+ 
+         self.embeddings = [
+             nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd)
+             for _ in range(self.num_channels)
+         ]
+         self.dropout = nn.Dropout(model_config.dropout)
+         self.layers = [DecoderLayer(config=config) for _ in range(self.num_layers)]
+         self.norm = nn.RMSNorm(
+             dec_config.n_embd,
+             eps=model_config.normalization_layer_epsilon,
+         )
+ 
+         # Final Logits Projection using DenseGeneral
+         self.logits_dense = DenseGeneral(
+             in_shapes=(dec_config.n_embd,),
+             out_features=(self.num_channels, model_config.tgt_vocab_size),
+             axis=(-1,),
+             dtype=mx.float32,
+             weight_dtype=weight_dtype,
+         )
+         self.logits_in_fp32 = train_config.logits_dot_in_fp32
+ 
+     def precompute_cross_attention_kv(
+         self,
+         max_len: int,
+         encoder_out: mx.array,  # (B, S, E)
+         src_positions: Optional[mx.array],  # (B, S)
+     ) -> List[KVCache]:
+         """
+         Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
+         """
+         per_layer_kv_cache: List[KVCache] = []
+ 
+         for layer in self.layers:
+             cross_attn_module = layer.cross_attention
+             k_proj = cross_attn_module.k_proj(encoder_out)
+             v_proj = cross_attn_module.v_proj(encoder_out)
+ 
+             k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
+             k = mx.transpose(k_proj, (0, 2, 1, 3))  # equivalent to transpose(1, 2)
+             v = mx.transpose(v_proj, (0, 2, 1, 3))  # equivalent to transpose(1, 2)
+ 
+             # Create KVCache without device parameter
+             per_layer_kv_cache.append(
+                 KVCache(
+                     cross_attn_module.num_kv_heads,
+                     max_len,
+                     cross_attn_module.head_dim,
+                     k=k,
+                     v=v,
+                 )
+             )
+ 
+         return per_layer_kv_cache
+ 
+     def decode_step(
+         self,
+         tgt_ids_Bx1xC: mx.array,  # [B, 1, C]
+         tgt_pos_Bx1: mx.array,  # [B, 1]
+         encoder_out: mx.array,  # [B, S, E]
+         self_attn_mask: Any,  # None
+         cross_attn_mask: mx.array,  # [B, 1, 1, S]
+         self_attention_cache: List[KVCache],
+         cross_attention_cache: List[KVCache],
+     ) -> mx.array:
+         """
+         Performs a single decoding step, managing KV caches layer by layer.
+ 
+         Returns:
+             logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
+         """
+         assert (
+             self_attn_mask is None
+         ), "Self-attention mask should be None, kept for pattern"
+ 
+         # Sum the per-channel token embeddings into a single hidden state
+         x = None
+         for i in range(self.num_channels):
+             channel_tokens = tgt_ids_Bx1xC[..., i]
+             channel_embed = self.embeddings[i](channel_tokens)
+             x = channel_embed if x is None else x + channel_embed
+ 
+         for i, layer in enumerate(self.layers):
+             self_cache = self_attention_cache[i]
+             cross_cache = cross_attention_cache[i]
+             x = layer(
+                 x,  # (2, 1, D)
+                 encoder_out,  # (2, S, E)
+                 src_positions=None,  # CA KV is already computed
+                 tgt_positions=tgt_pos_Bx1,  # (2, 1)
+                 deterministic=True,
+                 self_attn_mask=None,
+                 cross_attn_mask=cross_attn_mask,
+                 self_attn_cache=self_cache,
+                 cross_attn_cache=cross_cache,
+             )
+ 
+         x = self.norm(x)
+         logits_Bx1xCxV = self.logits_dense(x)
+ 
+         # Convert to float32 if needed
+         if logits_Bx1xCxV.dtype != mx.float32:
+             logits_Bx1xCxV = logits_Bx1xCxV.astype(mx.float32)
+ 
+         return logits_Bx1xCxV
+ 
+     def __call__(
+         self,
+         tgt_ids_BxTxC: mx.array,
+         encoder_out: mx.array,
+         tgt_positions: mx.array,
+         src_positions: mx.array,
+         deterministic: bool,
+         self_attn_mask: mx.array,
+         cross_attn_mask: mx.array,
+         self_attention_cache: List[KVCache],
+         cross_attention_cache: List[KVCache],
+     ) -> mx.array:
+         """
+         Forward pass for the Decoder stack, managing KV caches.
+ 
+         Args:
+             tgt_ids_BxTxC: Target token IDs (B, T, C).
+             encoder_out: Output from the encoder (B, S, E).
+             tgt_positions: Positions for target sequence (B, T).
+             src_positions: Positions for source sequence (B, S).
+             deterministic: Disable dropout if True.
+             self_attn_mask: Mask for self-attention.
+             cross_attn_mask: Mask for cross-attention.
+             self_attention_cache: List containing the self-attention KV cache for each layer.
+             cross_attention_cache: List containing the cross-attention KV cache for each layer.
+ 
+         Returns:
+             logits_BxTxCxV: The final output logits (B, T, C, V), cast to float32.
+         """
+         _, _, num_channels_in = tgt_ids_BxTxC.shape
+         assert num_channels_in == self.num_channels, "Input channels mismatch"
+ 
+         # Embeddings: sum the per-channel token embeddings into a single hidden state
+         x = None
+         for i in range(self.num_channels):
+             channel_tokens = tgt_ids_BxTxC[..., i]
+             channel_embed = self.embeddings[i](channel_tokens)
+             x = channel_embed if x is None else x + channel_embed
+ 
+         # Apply dropout if not deterministic
+         if not deterministic:
+             x = self.dropout(x)
+ 
+         # Process through each decoder layer
+         for i, layer in enumerate(self.layers):
+             x = layer(
+                 x,
+                 encoder_out,
+                 tgt_positions=tgt_positions,
+                 src_positions=src_positions,
+                 deterministic=deterministic,
+                 self_attn_mask=self_attn_mask,
+                 cross_attn_mask=cross_attn_mask,
+                 self_attn_cache=self_attention_cache[i],
+                 cross_attn_cache=cross_attention_cache[i],
+                 prefill=True,
+             )
+ 
+         # Final Norm
+         x = self.norm(x)
+         logits_BxTxCxV = self.logits_dense(x)
+ 
+         # Convert to float32 if needed
+         if logits_BxTxCxV.dtype != mx.float32:
+             logits_BxTxCxV = logits_BxTxCxV.astype(mx.float32)
+ 
+         return logits_BxTxCxV
+ 
+ 
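+ # Top-level encoder-decoder model: encodes the source text, builds per-layer KV caches,
+ # and decodes logits for every audio codebook channel.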
+ class DiaModel(nn.Module):
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         self.encoder = Encoder(config)
+         self.decoder = Decoder(config)
+ 
+     def __call__(
+         self,
+         src_BxS: mx.array,
+         tgt_BxTxC: mx.array,
+         src_positions: Optional[mx.array] = None,
+         tgt_positions: Optional[mx.array] = None,
+         enc_self_attn_mask: Optional[mx.array] = None,
+         dec_self_attn_mask: Optional[mx.array] = None,
+         dec_cross_attn_mask: Optional[mx.array] = None,
+         enable_dropout: bool = True,
+     ):
+         deterministic = not enable_dropout
+ 
+         # --- Encoder Pass ---
+         encoder_out = self.encoder(
+             x_ids=src_BxS,
+             src_positions=src_positions,
+             deterministic=deterministic,
+             attn_mask=enc_self_attn_mask,
+         )
+ 
+         # --- Decoder Pass ---
+         max_len = self.config.model.max_sequence_length
+ 
+         # Empty self-attention KV cache for each decoder layer
+         self_attention_cache = []
+         for layer in self.decoder.layers:
+             self_attn_module = layer.self_attention
+             self_attention_cache.append(
+                 KVCache(
+                     self_attn_module.num_query_heads, max_len, self_attn_module.head_dim
+                 )
+             )
+ 
+         # Cross-attention KV cache precomputed once from the encoder output
+         cross_attention_cache = self.decoder.precompute_cross_attention_kv(
+             max_len, encoder_out, src_positions
+         )
+ 
+         logits = self.decoder(
+             tgt_ids_BxTxC=tgt_BxTxC,
+             encoder_out=encoder_out,
+             tgt_positions=tgt_positions,
+             src_positions=src_positions,
+             deterministic=deterministic,
+             self_attn_mask=dec_self_attn_mask,
+             cross_attn_mask=dec_cross_attn_mask,
+             self_attention_cache=self_attention_cache,
+             cross_attention_cache=cross_attention_cache,
+         )
+ 
+         return logits