nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,764 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from typing import Any, List, Optional, Sequence, Tuple, Union
6
+ import mlx.core as mx
7
+ import os
8
+ import time
9
+
10
+ # Import necessary modules from mlx_lm
11
+ from mlx_lm import generate, stream_generate, load
12
+ from mlx_lm.sample_utils import make_sampler, make_logits_processors
13
+ from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache
14
+ from mlx_lm.tokenizer_utils import TokenizerWrapper
15
+ from mlx_lm.generate import generate_step
16
+ from mlx_lm.tuner.utils import load_adapters
17
+ import mlx.core as mx
18
+
19
+ # Import configs and callback types from ml.py for API alignment
20
+ from ml import (
21
+ LLM as BaseLLM,
22
+ ModelConfig,
23
+ SamplerConfig,
24
+ GenerationConfig,
25
+ ChatMessage,
26
+ EmbeddingConfig,
27
+ TokenCallback,
28
+ Path,
29
+ Tool
30
+ )
31
+
32
+ # Import profiling module
33
+ from profiling import ProfilingMixin, ProfilingData, StopReason
34
+
35
class LLM(BaseLLM, ProfilingMixin):
    """
    LLM interface for mlx-lm.
    API aligned with ml.py LLM abstract base class.
    """

    def __init__(
        self,
        model_path: Path,
        tokenizer_path: Path,
        config: ModelConfig,
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize the LLM model.

        Args:
            model_path: Model file or directory. A file path is replaced by its
                parent directory, since MLX loads from directories.
            tokenizer_path: Tokenizer path (stored for API parity; mlx-lm's
                load() derives the tokenizer from model_path).
            config: ModelConfig. Most fields are ignored for MLX; n_ctx (when
                > 0) bounds the KV cache size via _reset_cache().
            device: Device label; defaults to "cpu" when None.
                NOTE(review): stored for bookkeeping only — MLX selects the
                actual compute device itself; confirm callers expect that.
        """
        # Initialize profiling mixin
        ProfilingMixin.__init__(self)

        # Check if model_path is a file, if so use its parent directory, since MLX requires loading from a directory
        if os.path.isfile(model_path):
            model_path = os.path.dirname(model_path)

        # Call parent constructor
        super().__init__(model_path, tokenizer_path, config, device)

        # For MLX, we ignore ModelConfig parameters as requested
        # Store the basic parameters
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.config = config  # Store but ignore the values
        self.device = device if device is not None else "cpu"

        # Simulate C handle (would be pointer in C, here just store info)
        self.handle = {
            "model_path": model_path,
            "tokenizer_path": tokenizer_path,
            "device": self.device,
        }

        # Load model and tokenizer using mlx-lm
        self.model, self.tokenizer = load(model_path)
        self.sampler_config = SamplerConfig()
        self.default_generation_config = GenerationConfig()
        self.kv_cache = None
        # Initialize cache and global tracking (similar to reset logic)
        self._reset_cache()
        self.token_generator = None
        self.loras = {}  # lora_id -> (adapter_path, adapters)
        self.current_lora_id = -1  # -1 means base model (no LoRA active)
        self._next_lora_id = 0
        # Track whether KV cache has been used for generation
        self.kv_cache_used = False
        # Track total tokens processed (prompts + responses) for prompt cache functionality
        self.global_n_past = 0
91
+ def destroy(self) -> None:
92
+ """Destroy LLM instance and free associated resources (ml_llm_destroy)."""
93
+ self.model = None
94
+ self.tokenizer = None
95
+ self.kv_cache = None
96
+ self.token_generator = None
97
+ self.sampler_config = SamplerConfig()
98
+ self.default_generation_config = GenerationConfig()
99
+ self.loras.clear()
100
+ self.current_lora_id = -1
101
+ self._next_lora_id = 0
102
+ self.kv_cache_used = False
103
+ self.global_n_past = 0
104
+ self.reset_profiling()
105
+
106
+ def reset(self) -> None:
107
+ """Reset LLM internal state (ml_llm_reset)."""
108
+ mx.clear_cache()
109
+ self._reset_cache()
110
+ self.reset_profiling()
111
+
112
+ def _reset_cache(self) -> None:
113
+ """Reset the KV cache."""
114
+ if self.model is not None:
115
+ # For MLX, let mlx-lm handle cache size automatically since we ignore ModelConfig
116
+ # Use n_ctx if provided and > 0, otherwise let mlx-lm decide
117
+ max_kv_size = self.config.n_ctx if self.config.n_ctx > 0 else None
118
+ if max_kv_size:
119
+ self.kv_cache = make_prompt_cache(self.model, max_kv_size=max_kv_size)
120
+ else:
121
+ self.kv_cache = make_prompt_cache(self.model)
122
+ self.token_generator = None # Reset generator for new conversation
123
+ self.kv_cache_used = False # Reset cache usage flag
124
+ self.global_n_past = 0 # Reset prompt cache tracking
125
+
126
+ # Tokenization methods
127
+ def encode(self, text: str) -> List[int]:
128
+ """Encode UTF-8 text to token IDs (ml_llm_encode)."""
129
+ if not isinstance(self.tokenizer, TokenizerWrapper):
130
+ wrapper = TokenizerWrapper(self.tokenizer)
131
+ return wrapper.encode(text, add_special_tokens=True)
132
+ return self.tokenizer.encode(text, add_special_tokens=True)
133
+
134
+ def decode(self, token_ids: Sequence[int]) -> str:
135
+ """Decode token IDs to UTF-8 text (ml_llm_decode)."""
136
+ if not isinstance(self.tokenizer, TokenizerWrapper):
137
+ wrapper = TokenizerWrapper(self.tokenizer)
138
+ return wrapper.decode(list(token_ids))
139
+ return self.tokenizer.decode(list(token_ids))
140
+
141
+ # KV-cache methods
142
+ def save_kv_cache(self, path: Path) -> bool:
143
+ """Save KV cache to file. Returns True on success, False on error."""
144
+ try:
145
+ if self.kv_cache is not None:
146
+ if not path.endswith('.safetensors'):
147
+ path = path + '.safetensors'
148
+ save_prompt_cache(path, self.kv_cache)
149
+ return True
150
+ return False
151
+ except Exception as e:
152
+ print(f"Error saving KV cache: {e}")
153
+ return False
154
+
155
+ def load_kv_cache(self, path: Path) -> bool:
156
+ """Load KV cache from file. Returns True on success, False on error."""
157
+ try:
158
+ if not path.endswith('.safetensors'):
159
+ path = path + '.safetensors'
160
+ self.kv_cache = load_prompt_cache(path)
161
+ return True
162
+ except Exception as e:
163
+ print(f"Error loading KV cache: {e}")
164
+ return False
165
+
166
+ # LoRA methods
167
+ #
168
+ # LoRA (Low-Rank Adaptation) support for fine-tuned model variants.
169
+ # This implementation supports dynamic switching between different LoRA adapters
170
+ # by reloading the model with the appropriate adapter weights.
171
+ #
172
+ # Usage:
173
+ # 1. Add LoRA adapter: lora_id = model.add_lora("/path/to/adapter")
174
+ # 2. Activate LoRA: model.set_lora(lora_id)
175
+ # 3. Switch back to base model: model.set_lora(-1)
176
+ # 4. Or combine steps 1-2: lora_id = model.load_and_activate_lora("/path/to/adapter")
177
+
178
+ def set_lora(self, lora_id: int) -> None:
179
+ """Set active LoRA adapter by ID (ml_llm_set_lora)."""
180
+ if lora_id == -1:
181
+ if self.current_lora_id != -1:
182
+ self._switch_to_base_model()
183
+ return
184
+ if lora_id not in self.loras:
185
+ raise ValueError(f"LoRA adapter with ID {lora_id} not found")
186
+ if self.current_lora_id != lora_id:
187
+ self._switch_to_lora(lora_id)
188
+
189
+ def add_lora(self, lora_path: Path) -> int:
190
+ """Add LoRA adapter from file (ml_llm_add_lora). Returns LoRA ID on success, negative on error."""
191
+ if not lora_path or not os.path.exists(lora_path):
192
+ return -1
193
+ if not self._validate_lora_adapter(lora_path):
194
+ return -2
195
+ for lora_id, (path, _) in self.loras.items():
196
+ if os.path.abspath(path) == os.path.abspath(lora_path):
197
+ return lora_id
198
+ lora_id = self._next_lora_id
199
+ self._next_lora_id += 1
200
+ try:
201
+ adapters = load_adapters(lora_path)
202
+ self.loras[lora_id] = (lora_path, adapters)
203
+ return lora_id
204
+ except Exception:
205
+ return -99
206
+
207
+ def _validate_lora_adapter(self, lora_path: Path) -> bool:
208
+ """Validate that a path contains a valid LoRA adapter."""
209
+ if not os.path.isdir(lora_path):
210
+ return False
211
+
212
+ # Check for required LoRA files
213
+ required_files = ["adapter_config.json"]
214
+ optional_files = [
215
+ "adapters.safetensors",
216
+ "adapter_model.safetensors",
217
+ "pytorch_model.bin", # PyTorch format
218
+ "adapter_model.bin", # Alternative PyTorch format
219
+ ]
220
+
221
+ # At least adapter_config.json should exist
222
+ config_exists = any(os.path.exists(os.path.join(lora_path, f)) for f in required_files)
223
+ if not config_exists:
224
+ return False
225
+
226
+ # At least one weight file should exist
227
+ weights_exist = any(os.path.exists(os.path.join(lora_path, f)) for f in optional_files)
228
+
229
+ return weights_exist
230
+
231
+ def remove_lora(self, lora_id: int) -> None:
232
+ """Remove LoRA adapter by ID (ml_llm_remove_lora)."""
233
+ if lora_id not in self.loras:
234
+ return
235
+ if self.current_lora_id == lora_id:
236
+ self._switch_to_base_model()
237
+ self.loras.pop(lora_id, None)
238
+
239
+ def list_loras(self) -> List[int]:
240
+ """List all loaded LoRA adapter IDs (ml_llm_list_loras)."""
241
+ return list(self.loras.keys())
242
+
243
+ def _switch_to_base_model(self) -> None:
244
+ """Switch to the base model (no LoRA)."""
245
+ try:
246
+ # Reload the base model
247
+ self.model, self.tokenizer = load(self.model_path)
248
+ self.current_lora_id = -1
249
+ self._reset_cache() # Reset cache when switching models
250
+ except Exception as e:
251
+ raise RuntimeError(f"Failed to switch to base model: {str(e)}")
252
+
253
+ def _switch_to_lora(self, lora_id: int) -> None:
254
+ """Switch to a specific LoRA adapter."""
255
+ if lora_id not in self.loras:
256
+ raise ValueError(f"LoRA adapter with ID {lora_id} not found")
257
+
258
+ try:
259
+ lora_path, adapters = self.loras[lora_id]
260
+
261
+ # Load model with LoRA adapter
262
+ self.model, self.tokenizer = load(self.model_path, adapter_path=lora_path)
263
+ self.current_lora_id = lora_id
264
+ self._reset_cache() # Reset cache when switching models
265
+ except Exception as e:
266
+ raise RuntimeError(f"Failed to switch to LoRA adapter {lora_id} (path: {lora_path}): {str(e)}")
267
+
268
+ def get_current_lora_id(self) -> int:
269
+ """Get the currently active LoRA adapter ID."""
270
+ return self.current_lora_id
271
+
272
+ def get_lora_info(self, lora_id: int) -> dict:
273
+ """Get information about a specific LoRA adapter."""
274
+ if lora_id not in self.loras:
275
+ raise ValueError(f"LoRA adapter with ID {lora_id} not found")
276
+
277
+ lora_path, adapters = self.loras[lora_id]
278
+ return {
279
+ "id": lora_id,
280
+ "path": lora_path,
281
+ "is_active": lora_id == self.current_lora_id,
282
+ "config": getattr(adapters, "config", None) if hasattr(adapters, "config") else None
283
+ }
284
+
285
+ def load_and_activate_lora(self, lora_path: Path) -> int:
286
+ """Load a LoRA adapter and immediately activate it."""
287
+ lora_id = self.add_lora(lora_path)
288
+ self.set_lora(lora_id)
289
+ return lora_id
290
+
291
+ # Sampler methods
292
+ def set_sampler(self, config: SamplerConfig) -> None:
293
+ """Configure text generation sampling parameters (ml_llm_set_sampler)."""
294
+ self.sampler_config = config
295
+
296
+ def reset_sampler(self) -> None:
297
+ """Reset sampling parameters to defaults (ml_llm_reset_sampler)."""
298
+ self.sampler_config = SamplerConfig()
299
+
300
+ # Generation config methods
301
+ def set_generation_config(self, config: GenerationConfig) -> None:
302
+ """Set default generation configuration for token-level generation."""
303
+ self.default_generation_config = config
304
+
305
+ def _make_mlx_sampler_from_config(self, sampler_config: SamplerConfig):
306
+ """Create mlx-lm sampler from specific config."""
307
+ # Set seed if specified
308
+ if sampler_config.seed != -1:
309
+ mx.random.seed(sampler_config.seed)
310
+
311
+ return make_sampler(
312
+ temp=sampler_config.temperature,
313
+ top_p=sampler_config.top_p,
314
+ top_k=sampler_config.top_k,
315
+ )
316
+
317
+ def _make_logits_processors_from_config(self, sampler_config: SamplerConfig):
318
+ """Create logits processors from specific config."""
319
+ # Only use repetition penalty which is natively supported by mlx-lm
320
+ if sampler_config.repetition_penalty != 1.0:
321
+ return make_logits_processors(
322
+ repetition_penalty=sampler_config.repetition_penalty,
323
+ )
324
+ return None
325
+
326
+ def _make_mlx_sampler(self):
327
+ """Create mlx-lm sampler from class config."""
328
+ return self._make_mlx_sampler_from_config(self.sampler_config)
329
+
330
+ def _make_logits_processors(self):
331
+ """Create logits processors from class config."""
332
+ return self._make_logits_processors_from_config(self.sampler_config)
333
+
334
+ def generate_stream(
335
+ self,
336
+ prompt: str,
337
+ config: Optional[GenerationConfig],
338
+ on_token: TokenCallback,
339
+ user_data: Any = None,
340
+ ) -> str:
341
+ """
342
+ Generate text with streaming callback and profiling.
343
+
344
+ The prompt should be the incremental part after applying chat template.
345
+ apply_chat_template now returns only the incremental prompt based on global_n_past:
346
+ - First round (global_n_past = 0): Last user message + last system message (if exists)
347
+ - Subsequent rounds (global_n_past > 0): Only last user message
348
+
349
+ Prompt Cache Behavior:
350
+ - Tracks global_n_past to know how many tokens (prompts + responses) have been processed
351
+ - Passes incremental token arrays directly to stream_generate as prompt cache already contains the past history
352
+ - KV cache retains the conversation context until reset() is called
353
+ """
354
+ # Start profiling
355
+ self._start_profiling()
356
+
357
+ if config is None:
358
+ config = GenerationConfig()
359
+
360
+ # Use sampler config from GenerationConfig if provided, otherwise use class config
361
+ effective_sampler_config = config.sampler_config if config.sampler_config else self.sampler_config
362
+
363
+ # Create sampler from effective config
364
+ sampler = self._make_mlx_sampler_from_config(effective_sampler_config)
365
+ logits_processors = self._make_logits_processors_from_config(effective_sampler_config)
366
+
367
+ is_first_round = self.global_n_past <= 0
368
+
369
+ # Encode prompt to get tokens
370
+ incremental_tokens = self.encode(prompt)
371
+ cached_tokens = 0
372
+
373
+ # Only offset prefix kv-cache at first round
374
+ # if is_first_round:
375
+
376
+ # # Handle KV cache prefix offset if available
377
+ # if self.kv_cache is not None and len(self.kv_cache) > 0:
378
+ # # Get the offset from the first cache layer
379
+ # if hasattr(self.kv_cache[0], 'offset'):
380
+ # cached_tokens = self.kv_cache[0].offset - 1
381
+
382
+ # # Process only the non-cached tokens
383
+ # incremental_tokens = incremental_tokens[cached_tokens:] if cached_tokens > 0 else incremental_tokens
384
+
385
+ # if len(incremental_tokens) == 0:
386
+ # raise ValueError("No tokens to process, KV cache is too long.")
387
+
388
+ # Since apply_chat_template now returns incremental prompts, we can use the prompt directly
389
+ # The prompt is already the incremental part based on global_n_past
390
+ incremental_length = len(incremental_tokens)
391
+
392
+ # Record prompt tokens for profiling (use incremental length for this call)
393
+ self._update_prompt_tokens(incremental_length)
394
+
395
+ generated_tokens = 0
396
+ full_text = ""
397
+ last_response = None
398
+ first_token = True
399
+
400
+ try:
401
+ # End prompt processing, start decode
402
+ self._prompt_end()
403
+ self._decode_start()
404
+
405
+ for response in stream_generate(
406
+ model=self.model,
407
+ tokenizer=self.tokenizer,
408
+ prompt=incremental_tokens,
409
+ max_tokens=config.max_tokens,
410
+ sampler=sampler,
411
+ logits_processors=logits_processors if logits_processors else None,
412
+ prompt_cache=self.kv_cache,
413
+ ):
414
+ # Record TTFT on first token
415
+ if first_token:
416
+ self._record_ttft()
417
+ first_token = False
418
+
419
+ token_text = response.text
420
+ generated_tokens += 1
421
+
422
+ # Call the token callback - if it returns False, stop generation
423
+ if not on_token(token_text, user_data):
424
+ self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
425
+ break
426
+ full_text += token_text
427
+ last_response = response
428
+
429
+ # Set stop reason based on how generation ended
430
+ if generated_tokens >= config.max_tokens:
431
+ self._set_stop_reason(StopReason.ML_STOP_REASON_LENGTH)
432
+ elif self._profiling_context.stop_reason != StopReason.ML_STOP_REASON_USER: # Don't override user stop
433
+ # Check if the last response indicates EOS stop
434
+ if last_response:
435
+ if hasattr(last_response, 'finish_reason') and last_response.finish_reason == "stop":
436
+ self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)
437
+ else:
438
+ self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
439
+ else:
440
+ # Fallback: generation loop ended naturally, likely due to EOS
441
+ self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)
442
+
443
+ # Update global_n_past to reflect the new tokens processed (incremental prompt + response)
444
+ # Use the response metadata to get accurate token counts
445
+ self.global_n_past += cached_tokens + incremental_length + last_response.generation_tokens
446
+
447
+ # Mark cache as used after successful generation
448
+ self.kv_cache_used = True
449
+
450
+ # Update generated tokens and end profiling
451
+ self._update_generated_tokens(generated_tokens)
452
+ self._decode_end()
453
+ self._end_profiling()
454
+
455
+ return full_text
456
+ except Exception as e:
457
+ import traceback
458
+ self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
459
+ self._decode_end()
460
+ self._end_profiling()
461
+ return f"Streaming generation error: {str(e)}\n{traceback.format_exc()}"
462
+
463
+ # Chat template methods
464
+ def get_chat_template(self, template_name: str) -> str:
465
+ """Get chat template by name."""
466
+ # The header expects a template_name argument, but mlx-lm only supports one template.
467
+ # We'll ignore the argument for now.
468
+ return self.tokenizer.chat_template
469
+
470
+ def apply_chat_template(self, messages: Sequence[ChatMessage], tools: Optional[str] = None, enable_thinking: bool = True, add_generation_prompt: bool = True) -> str:
471
+ """
472
+ Apply chat template to messages with incremental prompt support and optional tools.
473
+
474
+ This method now returns only the incremental prompt based on global_n_past:
475
+ - When global_n_past = 0 (first conversation): Last user message + last system message (if exists)
476
+ - When global_n_past > 0 (subsequent rounds): Only last user message
477
+ """
478
+ # TODO: this is temporary solution to account for the no-thinking requirement of GPT-OSS. In the long term we need to revisit the API design of apply_chat_template.
479
+ try:
480
+ # Check global_n_past > 0 to determine if this is the first round of conversation
481
+ is_first_round = self.global_n_past <= 0
482
+
483
+ # Find last user message and last system message
484
+ last_user_msg = None
485
+ last_system_msg = None
486
+
487
+ for msg in messages:
488
+ if msg.role == "user":
489
+ last_user_msg = msg
490
+ elif msg.role == "system":
491
+ last_system_msg = msg
492
+
493
+ # Build incremental message list based on conversation round
494
+ if is_first_round:
495
+ # First round: include system message (if exists) + last user message
496
+ incremental_messages = []
497
+ if last_system_msg:
498
+ incremental_messages.append({
499
+ "role": last_system_msg.role,
500
+ "content": last_system_msg.content
501
+ })
502
+
503
+ if last_user_msg:
504
+ incremental_messages.append({
505
+ "role": last_user_msg.role,
506
+ "content": last_user_msg.content
507
+ })
508
+ else:
509
+ raise ValueError("No user message found for first conversation round")
510
+
511
+ else:
512
+ # Subsequent rounds: only last user message
513
+ if last_user_msg:
514
+ incremental_messages = [{
515
+ "role": last_user_msg.role,
516
+ "content": last_user_msg.content
517
+ }]
518
+ else:
519
+ raise ValueError("No user message found for subsequent conversation round")
520
+
521
+ parsed_tools = None
522
+ if tools is not None:
523
+ parsed_tools = json.loads(tools)
524
+
525
+ return self.tokenizer.apply_chat_template(
526
+ incremental_messages,
527
+ tokenize=False,
528
+ enable_thinking=enable_thinking,
529
+ add_generation_prompt=add_generation_prompt,
530
+ tools=parsed_tools
531
+ )
532
+ except Exception as e:
533
+ import traceback
534
+ raise RuntimeError(f"Error applying chat template: {str(e)}\n{traceback.format_exc()}")
535
+
536
+ # Embeddings - using the model's embedding layer directly
537
+ def embed(
538
+ self,
539
+ texts: Sequence[str],
540
+ config: Optional[EmbeddingConfig] = None,
541
+ ) -> List[List[float]]:
542
+ """Generate embeddings for texts with profiling."""
543
+ # Start profiling
544
+ self._start_profiling()
545
+
546
+ # Calculate total tokens for all texts
547
+ total_tokens = sum(len(self.encode(text)) for text in texts)
548
+ self._update_prompt_tokens(total_tokens)
549
+
550
+ # End prompt processing, start decode
551
+ self._prompt_end()
552
+ self._decode_start()
553
+
554
+ try:
555
+ embeddings = []
556
+
557
+ for text in texts:
558
+ # Tokenize the text
559
+ tokens = self.encode(text)
560
+
561
+ # Convert to mlx array
562
+ token_array = mx.array(tokens)
563
+
564
+ # Get embeddings directly from the model's embedding layer
565
+ embedding_tensor = self.model.model.embed_tokens(token_array)
566
+
567
+ # Average pool across sequence dimension to get a single embedding per text
568
+ # Shape: [seq_len, hidden_size] -> [hidden_size]
569
+ pooled_embedding = mx.mean(embedding_tensor, axis=0)
570
+
571
+ # Convert to Python list of floats
572
+ embedding_list = pooled_embedding.tolist()
573
+ embeddings.append(embedding_list)
574
+
575
+ # End timing and finalize profiling data
576
+ self._update_generated_tokens(0) # No generation in embedding
577
+ self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
578
+ self._decode_end()
579
+ self._end_profiling()
580
+
581
+ return embeddings
582
+
583
+ except Exception as e:
584
+ self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
585
+ self._decode_end()
586
+ self._end_profiling()
587
+ raise RuntimeError(f"Error generating embeddings: {str(e)}")
588
+
589
# =============================================================================
# Test functions
# =============================================================================
def test_kv_cache_save_load():
    """Exercise KV-cache save and load round-trips end to end."""
    print("Testing KV cache save and load...")

    # A small context keeps the test light.
    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    config = ModelConfig()
    config.n_ctx = 512

    llm = LLM(model_path, model_path, config)

    def stream_callback(token, user_data):
        print(token, end="", flush=True)
        return True

    test_prompt = "🥳 🎂 Once upon a time"

    # --- Save path -------------------------------------------------------
    print("Testing KV cache save...")
    gen_config = GenerationConfig()
    gen_config.max_tokens = 20  # enough tokens to populate the cache

    print("Generating text to populate cache:")
    response = llm.generate_stream(test_prompt, gen_config, stream_callback)
    print(f"\nGenerated: {response}")

    cache_path = "./test_kvcache_save.safetensors"
    save_result = llm.save_kv_cache(cache_path)
    print(f"Save result: {save_result}")
    assert save_result == True, "KV cache save should succeed"

    llm.reset()

    # --- Load path -------------------------------------------------------
    print("Testing KV cache load...")
    cache_path = "./test_kvcache_load.safetensors"

    # Populate and persist a fresh cache first.
    response = llm.generate_stream(test_prompt, gen_config, stream_callback)
    save_result = llm.save_kv_cache(cache_path)
    assert save_result == True, "KV cache save should succeed"

    # Then reload it into a reset instance.
    llm.reset()
    load_result = llm.load_kv_cache(cache_path)
    print(f"Load result: {load_result}")
    assert load_result == True, "KV cache load should succeed"

    print("KV cache save/load tests passed!")
644
+
645
def test_tokenization():
    """Exercise the tokenizer's encode/decode round trip."""
    print("Testing tokenization...")

    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    llm = LLM(model_path, model_path, ModelConfig())

    test_text = "🥳 🎂 Once upon a time"

    # Encode: text -> token ids
    token_ids = llm.encode(test_text)
    print(f"Encoded '{test_text}' to {len(token_ids)} tokens")
    assert len(token_ids) > 0, "Encoding should produce tokens"

    # Decode: token ids -> text
    decoded_text = llm.decode(token_ids)
    print(f"Decoded back to: '{decoded_text}'")
    assert len(decoded_text) > 0, "Decoding should produce text"

    print("Tokenization tests passed!")
668
def test_generation():
    """Smoke-test streaming text generation end to end."""
    print("Testing generation...")

    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    llm = LLM(model_path, model_path, ModelConfig())

    def stream_callback(token, user_data):
        # Print tokens as they arrive; True keeps generation going
        print(token, end="", flush=True)
        return True

    gen_config = GenerationConfig()
    gen_config.max_tokens = 10
    test_prompt = "🥳 🎂 Once upon a time"

    print("Generating text:")
    response = llm.generate_stream(test_prompt, gen_config, stream_callback)
    print(f"\nGenerated response length: {len(response)}")
    assert len(response) > 0, "Generation should produce text"

    print("Generation test passed!")
692
def run_tests():
    """Run every test case; print a traceback on the first failure."""
    try:
        # Same order as before: tokenization, generation, then KV cache
        for test_case in (test_tokenization, test_generation, test_kv_cache_save_load):
            test_case()
            print()
        print("All tests passed! ✅")
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
707
# For testing
if __name__ == "__main__":
    import sys

    # `python <file> test` runs the suite and exits without the chat loop
    if len(sys.argv) > 1 and sys.argv[1] == "test":
        run_tests()
        sys.exit(0)

    def on_token(token_text, user_data):
        """Token callback that prints each token as it's generated"""
        print(token_text, end="", flush=True)
        return True  # Continue generation

    # Multi-round conversation test case
    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    tokenizer_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    config = ModelConfig()

    llm = LLM(model_path, tokenizer_path, config)

    # Run tests
    print("================================================")
    print("Running tests")
    run_tests()
    print("================================================")

    # Interactive multi-round conversation loop
    chat = []
    print("Multi-round conversation test. Type 'exit' to quit.")

    while True:
        try:
            user_input = input("User: ").strip()

            # Exit conditions: explicit command or empty input
            if user_input.lower() in ('exit', 'quit', ''):
                break

            # Record the user turn, then render the full history as a prompt
            chat.append(ChatMessage(role="user", content=user_input))
            formatted_prompt = llm.apply_chat_template(chat)

            # Stream the assistant reply token by token
            print("Assistant: ", end="", flush=True)  # Following generate.py pattern
            response = llm.generate_stream(formatted_prompt, None, on_token)

            # Record the assistant turn so the next round has full context
            chat.append(ChatMessage(role="assistant", content=response))
            print()  # New line after response
        except KeyboardInterrupt:
            print("\nConversation interrupted by user.")
            break
        except Exception as e:
            print(f"Error: {e}")
            continue