nexaai 1.0.4rc10__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (519):
  1. nexaai/__init__.py +71 -0
  2. nexaai/_version.py +4 -0
  3. nexaai/asr.py +60 -0
  4. nexaai/asr_impl/__init__.py +0 -0
  5. nexaai/asr_impl/mlx_asr_impl.py +91 -0
  6. nexaai/asr_impl/pybind_asr_impl.py +43 -0
  7. nexaai/base.py +39 -0
  8. nexaai/binds/__init__.py +3 -0
  9. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  10. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/libnexa_bridge.dylib +0 -0
  12. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  13. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  17. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  20. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/py-lib/ml.py +842 -0
  22. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  172. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  173. nexaai/common.py +61 -0
  174. nexaai/cv.py +87 -0
  175. nexaai/cv_impl/__init__.py +0 -0
  176. nexaai/cv_impl/mlx_cv_impl.py +88 -0
  177. nexaai/cv_impl/pybind_cv_impl.py +31 -0
  178. nexaai/embedder.py +68 -0
  179. nexaai/embedder_impl/__init__.py +0 -0
  180. nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
  181. nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
  182. nexaai/image_gen.py +136 -0
  183. nexaai/image_gen_impl/__init__.py +0 -0
  184. nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
  185. nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
  186. nexaai/llm.py +89 -0
  187. nexaai/llm_impl/__init__.py +0 -0
  188. nexaai/llm_impl/mlx_llm_impl.py +249 -0
  189. nexaai/llm_impl/pybind_llm_impl.py +207 -0
  190. nexaai/mlx_backend/asr/__init__.py +12 -0
  191. nexaai/mlx_backend/asr/interface.py +122 -0
  192. nexaai/mlx_backend/common/__init__.py +0 -0
  193. nexaai/mlx_backend/common/utils.py +25 -0
  194. nexaai/mlx_backend/cv/__init__.py +0 -0
  195. nexaai/mlx_backend/cv/generate.py +195 -0
  196. nexaai/mlx_backend/cv/interface.py +151 -0
  197. nexaai/mlx_backend/cv/main.py +81 -0
  198. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  199. nexaai/mlx_backend/embedding/__init__.py +0 -0
  200. nexaai/mlx_backend/embedding/generate.py +130 -0
  201. nexaai/mlx_backend/embedding/interface.py +312 -0
  202. nexaai/mlx_backend/embedding/main.py +82 -0
  203. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  204. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  205. nexaai/mlx_backend/llm/__init__.py +0 -0
  206. nexaai/mlx_backend/llm/generate.py +149 -0
  207. nexaai/mlx_backend/llm/interface.py +764 -0
  208. nexaai/mlx_backend/llm/main.py +68 -0
  209. nexaai/mlx_backend/ml.py +842 -0
  210. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  211. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  212. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  213. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  214. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  215. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  216. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  217. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  218. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  219. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  220. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  221. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  222. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  223. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  224. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  225. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  226. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  227. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  228. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  229. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  230. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  231. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  232. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  233. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  234. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  235. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  236. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  247. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  249. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  250. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  251. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  252. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  253. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  254. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  255. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  256. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  257. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  258. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  259. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  260. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  261. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  262. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  264. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  265. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  266. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  267. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  268. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  269. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  270. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  271. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  272. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  273. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  274. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  275. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  276. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  277. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  278. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  279. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  280. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  281. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  282. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  283. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  284. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  285. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  286. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  288. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  289. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  290. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  291. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  292. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  293. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  294. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  295. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  296. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  297. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  298. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  299. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  300. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  301. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  302. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  303. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  304. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  305. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  306. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  307. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  308. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  309. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  310. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  353. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  354. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  355. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  356. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  357. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  358. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  359. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  360. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  361. nexaai/mlx_backend/profiling.py +239 -0
  362. nexaai/mlx_backend/rerank/__init__.py +0 -0
  363. nexaai/mlx_backend/rerank/generate.py +174 -0
  364. nexaai/mlx_backend/rerank/interface.py +287 -0
  365. nexaai/mlx_backend/rerank/main.py +127 -0
  366. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  367. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  368. nexaai/mlx_backend/sd/__init__.py +1 -0
  369. nexaai/mlx_backend/sd/interface.py +362 -0
  370. nexaai/mlx_backend/sd/main.py +286 -0
  371. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  372. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  373. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  374. nexaai/mlx_backend/sd/modeling/model_io.py +330 -0
  375. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  376. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  377. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  378. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  379. nexaai/mlx_backend/tts/__init__.py +12 -0
  380. nexaai/mlx_backend/tts/interface.py +276 -0
  381. nexaai/mlx_backend/vlm/__init__.py +3 -0
  382. nexaai/mlx_backend/vlm/generate.py +572 -0
  383. nexaai/mlx_backend/vlm/interface.py +406 -0
  384. nexaai/mlx_backend/vlm/main.py +157 -0
  385. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  386. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  387. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  388. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  389. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  390. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  391. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  392. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  393. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  394. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  395. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  396. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  397. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  398. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  399. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  400. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  401. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  402. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  403. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  404. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  405. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  406. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  407. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  408. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  409. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  410. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  411. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  412. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  413. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  414. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  415. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  416. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  417. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  418. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  419. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  420. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  421. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  422. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  423. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  424. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  425. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  426. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  427. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  428. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  429. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  430. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  431. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  432. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  433. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  434. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  435. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  436. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  437. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  438. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  439. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  440. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  441. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  443. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  444. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  445. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  447. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  448. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  449. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  450. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  451. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  452. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  453. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  454. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  455. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  457. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  458. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  459. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  460. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  461. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  462. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  463. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  464. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  465. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  466. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  467. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  468. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  470. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  471. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  472. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  473. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  474. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  475. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  476. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  477. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  478. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  479. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  480. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  481. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  482. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  483. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  484. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  485. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  486. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  487. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  488. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  489. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  490. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  491. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  492. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  493. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  494. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  495. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  496. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  497. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  498. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  499. nexaai/rerank.py +51 -0
  500. nexaai/rerank_impl/__init__.py +0 -0
  501. nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
  502. nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
  503. nexaai/runtime.py +64 -0
  504. nexaai/tts.py +70 -0
  505. nexaai/tts_impl/__init__.py +0 -0
  506. nexaai/tts_impl/mlx_tts_impl.py +93 -0
  507. nexaai/tts_impl/pybind_tts_impl.py +42 -0
  508. nexaai/utils/avatar_fetcher.py +104 -0
  509. nexaai/utils/decode.py +18 -0
  510. nexaai/utils/model_manager.py +1195 -0
  511. nexaai/utils/progress_tracker.py +372 -0
  512. nexaai/vlm.py +120 -0
  513. nexaai/vlm_impl/__init__.py +0 -0
  514. nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
  515. nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
  516. nexaai-1.0.4rc10.dist-info/METADATA +26 -0
  517. nexaai-1.0.4rc10.dist-info/RECORD +519 -0
  518. nexaai-1.0.4rc10.dist-info/WHEEL +5 -0
  519. nexaai-1.0.4rc10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,399 @@
1
+ # Copyright © Nexa AI
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Union
18
+
19
+ import mlx.core as mx
20
+ import mlx.nn as nn
21
+
22
+ import os
23
+ import sys
24
+
25
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
26
+ llm_common_dir = os.path.join(curr_dir, "..", "..")
27
+ sys.path.append(llm_common_dir)
28
+
29
+ from mlx_lm.models.base import (
30
+ BaseModelArgs,
31
+ scaled_dot_product_attention,
32
+ )
33
+ from tokenizers import Tokenizer
34
+
35
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Jina-BERT style encoder defined in this module.

    Every field has a default, so ``ModelArgs()`` is constructible without
    arguments. Field order is kept as-is because it determines the positional
    order of the generated ``__init__``.
    """

    # --- identification -------------------------------------------------
    model_type: str = "bert"

    # --- sizes ----------------------------------------------------------
    vocab_size: int = 61056  # value taken from the model config
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072

    # --- activations / regularisation -----------------------------------
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1

    # --- positions / vocabulary extras ----------------------------------
    max_position_embeddings: int = 8192  # value taken from the model config
    type_vocab_size: int = 2
    initializer_range: float = 0.02
    layer_norm_eps: float = 1e-12
    pad_token_id: int = 0
    position_embedding_type: str = "alibi"  # value taken from the model config

    # --- misc -----------------------------------------------------------
    use_cache: bool = True
    classifier_dropout: Optional[float] = None
    feed_forward_type: str = "geglu"  # value taken from the model config
    emb_pooler: str = "mean"  # value taken from the model config
    attn_implementation: str = "torch"
57
+
58
+
59
class JinaBertEmbeddings(nn.Module):
    """Token + token-type embeddings followed by a layer norm.

    No positional embedding table is created here; the configured
    ``position_embedding_type`` is only recorded on the instance.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        width = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)
        # Attribute is named ``LayerNorm`` (PyTorch style) so checkpoint keys match.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Return the normalised sum of word and token-type embeddings.

        When ``token_type_ids`` is omitted, an all-zero int64 array with the
        shape of ``input_ids`` is used instead.
        """
        if token_type_ids is None:
            token_type_ids = mx.zeros(input_ids.shape, dtype=mx.int64)

        summed = self.word_embeddings(input_ids) + self.token_type_embeddings(token_type_ids)
        return self.LayerNorm(summed)
83
+
84
+
85
class JinaBertSelfAttention(nn.Module):
    """Multi-head self-attention with an optional additive mask and bias."""

    def __init__(self, config: ModelArgs, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.attn_implementation = config.attn_implementation
        self.num_attention_heads = config.num_attention_heads
        # Exact because divisibility was checked above.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # An explicit (truthy) argument wins over the config value.
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )

    def transpose_for_scores(self, x: mx.array) -> mx.array:
        """Split the last axis into (heads, head_size) and swap axes 1 and 2."""
        per_head_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.reshape(per_head_shape).transpose(0, 2, 1, 3)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        bias: Optional[mx.array] = None,
    ) -> mx.array:
        """Run attention over ``hidden_states``.

        ``attention_mask`` and ``bias`` are both additive; when both are given
        they are summed. The combined mask is cast to the dtype of
        ``hidden_states`` before being handed to the attention kernel.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Merge the two optional additive terms into a single mask.
        if attention_mask is None:
            mask = bias
        elif bias is None:
            mask = attention_mask
        else:
            mask = attention_mask + bias

        if mask is not None:
            mask = mask.astype(hidden_states.dtype)

        attended = scaled_dot_product_attention(
            queries,
            keys,
            values,
            cache=None,
            scale=1.0 / math.sqrt(self.attention_head_size),
            mask=mask,
        )

        # Move heads back next to the feature axis and merge them.
        attended = attended.transpose(0, 2, 1, 3)
        merged_shape = attended.shape[:-2] + (self.all_head_size,)
        return attended.reshape(merged_shape)
147
+
148
+
149
class JinaBertSelfOutput(nn.Module):
    """Linear projection of the attention output plus residual layer norm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        # Attribute is named ``LayerNorm`` (PyTorch style) so checkpoint keys match.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        """Project ``hidden_states``, add ``input_tensor`` and normalise."""
        projected = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
160
+
161
+
162
class JinaBertAttention(nn.Module):
    """Self-attention followed by its output projection and residual norm."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Sub-module names ``self`` / ``output`` are kept as-is; they are part
        # of the parameter paths of this module.
        self.self = JinaBertSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = JinaBertSelfOutput(config)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        bias: Optional[mx.array] = None,
    ) -> mx.array:
        """Attend over ``hidden_states`` and apply the residual output block."""
        context = self.self(hidden_states, attention_mask, bias)
        return self.output(context, hidden_states)
177
+
178
+
179
class JinaBertGLUMLP(nn.Module):
    """Gated feed-forward block with a residual connection and layer norm.

    A single bias-free linear layer produces ``2 * intermediate_size``
    features. The first half is passed through an activation and multiplied
    element-wise with the second half, then projected back to ``hidden_size``.

    The activation follows ``config.feed_forward_type``:

    * ``"reglu"`` -> ReLU
    * anything else (including the default ``"geglu"``) -> GELU

    Previously both code paths applied GELU, so ``"reglu"`` silently behaved
    like ``"geglu"``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        # Attribute name is kept lowercase; it is part of the parameter path.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Apply the gated MLP to ``hidden_states`` and return the normed sum."""
        residual_connection = hidden_states
        projected = self.gated_layers(hidden_states)

        # First half is the gate, second half is the value it modulates.
        split_at = self.config.intermediate_size
        gated = projected[..., :split_at]
        non_gated = projected[..., split_at:]

        # Unknown feed-forward types keep the historical GELU behaviour rather
        # than raising, so existing configs continue to load.
        activation = nn.relu if self.config.feed_forward_type == "reglu" else nn.gelu
        hidden_states = activation(gated) * non_gated

        hidden_states = self.wo(hidden_states)
        return self.layernorm(hidden_states + residual_connection)
205
+
206
+
207
class JinaBertLayer(nn.Module):
    """One encoder layer: attention block followed by the gated MLP block."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.attention = JinaBertAttention(config)
        self.feed_forward_type = config.feed_forward_type
        self.mlp = JinaBertGLUMLP(config)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        bias: Optional[mx.array] = None,
    ) -> mx.array:
        """Return the layer output for ``hidden_states``."""
        attended = self.attention(hidden_states, attention_mask, bias=bias)
        return self.mlp(attended)
223
+
224
+
225
class JinaBertEncoder(nn.Module):
    """Stack of ``JinaBertLayer`` modules sharing one ALiBi bias per forward pass."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        # A plain Python list of layers; the attribute name ``layer`` is part
        # of the parameter paths.
        self.layer = [JinaBertLayer(config) for _ in range(config.num_hidden_layers)]
        self.gradient_checkpointing = False
        self.num_attention_heads = config.num_attention_heads
        self._current_alibi_size = config.max_position_embeddings
        # The ALiBi tensor is not stored on the module; it is rebuilt for the
        # actual sequence length on every call (see ``__call__``).

    def rebuild_alibi_tensor(self, size: int) -> mx.array:
        """Build the ALiBi bias for a sequence of length ``size``.

        The result has shape ``(1, num_heads, size, size)`` and holds
        ``-slope[h] * |i - j|`` for every head ``h`` and position pair. As a
        side effect ``self._current_alibi_size`` is set to ``size``.
        """
        head_count = self.num_attention_heads

        def geometric_slopes(n):
            # Geometric sequence whose first term equals its ratio.
            base = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [base * base**k for k in range(n)]

        def head_slopes(n: int) -> List[float]:
            if math.log2(n).is_integer():
                return geometric_slopes(n)
            # Not a power of two: take the slopes of the largest power of two
            # below n, then fill up with every other slope of the next one.
            lower = 2 ** math.floor(math.log2(n))
            filler = head_slopes(2 * lower)[0::2][: n - lower]
            return geometric_slopes(lower) + filler

        positions = mx.arange(size)
        # |j - i| for every (query i, key j) pair.
        distance = mx.abs(positions[None, :] - positions[:, None])
        distance = mx.repeat(mx.expand_dims(distance, axis=0), head_count, axis=0)

        negative_slopes = mx.array(head_slopes(head_count)) * -1
        negative_slopes = mx.expand_dims(mx.expand_dims(negative_slopes, axis=1), axis=2)

        alibi = mx.expand_dims(negative_slopes * distance, axis=0)

        self._current_alibi_size = size
        return alibi

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Run every layer in order and return the final hidden states."""
        # Three-way unpacking also enforces that the input is rank 3.
        _, seqlen, _ = hidden_states.shape
        alibi_bias = self.rebuild_alibi_tensor(seqlen)

        for block in self.layer:
            hidden_states = block(hidden_states, attention_mask, alibi_bias)

        return hidden_states
286
+
287
+
288
class JinaBertPooler(nn.Module):
    """Dense + tanh applied to the hidden state of the first token."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.tanh

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Pool by taking position 0 along axis 1, then project and squash."""
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
301
+
302
+
303
class JinaBertModel(nn.Module):
    """Embeddings + encoder, with a mean-pooling helper for sentence vectors."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.embeddings = JinaBertEmbeddings(config)
        self.encoder = JinaBertEncoder(config)
        # The pooler is instantiated so its weights exist in the module; it is
        # not called by ``__call__`` or ``encode`` below.
        self.pooler = JinaBertPooler(config)

    def get_extended_attention_mask(self, attention_mask: mx.array, input_shape: tuple) -> mx.array:
        """Turn a 0/1 mask into an additive mask broadcastable over heads.

        Accepts a rank-2 or rank-3 mask; positions with value 0 become
        ``-10000.0`` and positions with value 1 become ``0.0``.
        ``input_shape`` is accepted but not used.

        Raises:
            ValueError: if ``attention_mask`` is neither rank 2 nor rank 3.
        """
        rank = attention_mask.ndim
        if rank == 2:
            expanded = attention_mask[:, None, None, :]
        elif rank == 3:
            expanded = attention_mask[:, None, :, :]
        else:
            raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")

        return (1.0 - expanded) * -10000.0

    def mean_pooling(self, token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
        """Average ``token_embeddings`` over axis 1, weighted by ``attention_mask``."""
        # Broadcast the mask to the full embedding shape.
        weights = mx.expand_dims(attention_mask, axis=-1) * mx.ones_like(token_embeddings)
        weighted_sum = mx.sum(token_embeddings * weights, axis=1)
        # Lower bound keeps the division finite for fully masked rows.
        weight_total = mx.clip(mx.sum(weights, axis=1), 1e-9, None)
        return weighted_sum / weight_total

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Return the encoder's per-token hidden states for ``input_ids``."""
        input_shape = input_ids.shape

        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        embedded = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        return self.encoder(embedded, attention_mask=extended_attention_mask)

    def encode(
        self,
        input_ids: mx.array,
        attention_mask: mx.array,
        token_type_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Encode inputs and return mean-pooled embeddings"""
        per_token = self(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        return self.mean_pooling(per_token, attention_mask)
362
+
363
+
364
class Model(nn.Module):
    """Top-level wrapper that exposes ``JinaBertModel`` under ``self.model``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = JinaBertModel(args)

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Forward to the wrapped model and return its per-token outputs."""
        forward_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.model(**forward_kwargs)

    def encode(
        self,
        input_ids: mx.array,
        attention_mask: mx.array,
        token_type_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Encode inputs and return mean-pooled embeddings"""
        encode_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.model.encode(**encode_kwargs)

    def sanitize(self, weights):
        """Return ``weights`` unchanged.

        Nothing is filtered out: the wrapped model defines a pooler, so pooler
        weights in a checkpoint have a destination.
        """
        return weights

    @property
    def layers(self):
        """The list of encoder layers of the wrapped model."""
        return self.model.encoder.layer
File without changes
@@ -0,0 +1,149 @@
1
+ import argparse
2
+ from mlx_lm.models.cache import make_prompt_cache
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ from mlx.utils import tree_reduce
6
+ from transformers import PreTrainedTokenizer
7
+ from mlx_lm.models import cache
8
+ from mlx_lm.models.cache import (
9
+ QuantizedKVCache,
10
+ load_prompt_cache,
11
+ )
12
+ from mlx_lm.sample_utils import make_sampler
13
+ from mlx_lm.tokenizer_utils import TokenizerWrapper
14
+ from mlx_lm.utils import does_model_support_input_embeddings, load
15
+ from mlx_lm.generate import stream_generate
16
+
17
+ DEFAULT_TEMP = 0.0
18
+ DEFAULT_TOP_P = 1.0
19
+ DEFAULT_XTC_PROBABILITY = 0.0
20
+ DEFAULT_XTC_THRESHOLD = 0.0
21
+ DEFAULT_SEED = None
22
+ DEFAULT_MAX_TOKENS = 256
23
+ DEFAULT_MODEL = "mlx-community/Qwen3-1.7B-4bit-DWQ"
24
+
25
+
26
def str2bool(string):
    """Interpret ``string`` as a boolean flag.

    Only ``"false"`` and ``"f"`` (compared case-insensitively) map to
    ``False``; every other string, including the empty string, is ``True``.
    """
    lowered = string.lower()
    is_false_word = lowered == "false" or lowered == "f"
    return not is_false_word
28
+
29
+
30
def setup_arg_parser():
    """Set up and return the argument parser.

    All options are optional; defaults come from the module-level
    ``DEFAULT_*`` constants (``--adapter-path`` and ``--max-kv-size`` default
    to ``None``).
    """
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--xtc-probability",
        type=float,
        default=DEFAULT_XTC_PROBABILITY,
        help="Probability of XTC sampling to happen each next token",
    )
    parser.add_argument(
        "--xtc-threshold",
        type=float,
        # Use the shared constant (same value as the former literal 0.0) so
        # the default cannot drift from DEFAULT_XTC_THRESHOLD.
        default=DEFAULT_XTC_THRESHOLD,
        help="Threshold the probs of each next token candidate to be sampled by XTC",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help="PRNG seed",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser
82
+
83
+
84
def main():
    """Run an interactive multi-turn chat loop on the terminal.

    The full chat history is re-templated and sent to the model on each turn.
    The loop ends on an empty line, ``exit``/``quit``, Ctrl-C, or end of input.

    Fixes over the previous version:

    * ``--seed`` is now applied to the MLX random state (it was parsed but
      ignored).
    * End of input (``EOFError`` from ``input()``) ends the session. It used
      to be caught by the generic handler, which retried ``input()`` forever.
    * If a turn fails, the user message that triggered it is removed from the
      history so the next turn does not send two consecutive user messages.
    """
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.seed is not None:
        mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    # Initialize chat history
    chat = []

    while True:
        try:
            user_input = input("User: ").strip()
        except KeyboardInterrupt:
            print("\nConversation interrupted by user.")
            break
        except EOFError:
            # stdin is closed; retrying input() would raise again immediately.
            print()
            break

        # Exit conditions
        if user_input.lower() in ("exit", "quit", ""):
            break

        chat.append({"role": "user", "content": user_input})

        try:
            formatted_prompt = tokenizer.apply_chat_template(chat, add_generation_prompt=True)

            # Generate response
            pieces = []
            print("Assistant: ", end="", flush=True)

            for chunk in stream_generate(
                model,
                tokenizer,
                formatted_prompt,
                max_tokens=args.max_tokens,
                sampler=make_sampler(
                    args.temp,
                    args.top_p,
                    xtc_threshold=args.xtc_threshold,
                    xtc_probability=args.xtc_probability,
                    xtc_special_tokens=(
                        tokenizer.encode("\n") + list(tokenizer.eos_token_ids)
                    ),
                ),
            ):
                pieces.append(chunk.text)
                print(chunk.text, end="", flush=True)

            print()  # New line after response
        except KeyboardInterrupt:
            print("\nConversation interrupted by user.")
            break
        except Exception as e:
            print(f"Error: {e}")
            # Drop the unanswered user turn to keep the history well-formed.
            chat.pop()
            continue

        # Add assistant response to chat history
        chat.append({"role": "assistant", "content": "".join(pieces)})
142
+
143
+
144
+ if __name__ == "__main__":
145
+ print(
146
+ "Calling `python -m mlx_lm.chat...` directly is deprecated."
147
+ " Use `mlx_lm.chat...` or `python -m mlx_lm chat ...` instead."
148
+ )
149
+ main()