nexaai 1.0.16rc13 (nexaai-1.0.16rc13-cp310-cp310-macosx_15_0_universal2.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nexaai might be problematic.

Files changed (557)
  1. nexaai/__init__.py +83 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +64 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +44 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +4 -0
  10. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/libnexa_bridge.dylib +0 -0
  13. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  17. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  20. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/ml.py +888 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  172. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  173. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  174. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  175. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  176. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  177. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  178. nexaai/binds/nexa_nexaml/libnexa-mm-process.dylib +0 -0
  179. nexaai/binds/nexa_nexaml/libnexa-sampling.dylib +0 -0
  180. nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
  181. nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
  182. nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
  183. nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
  184. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  185. nexaai/common.py +104 -0
  186. nexaai/cv.py +92 -0
  187. nexaai/cv_impl/__init__.py +0 -0
  188. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  189. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  190. nexaai/embedder.py +72 -0
  191. nexaai/embedder_impl/__init__.py +0 -0
  192. nexaai/embedder_impl/mlx_embedder_impl.py +116 -0
  193. nexaai/embedder_impl/pybind_embedder_impl.py +95 -0
  194. nexaai/image_gen.py +140 -0
  195. nexaai/image_gen_impl/__init__.py +0 -0
  196. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  197. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  198. nexaai/llm.py +96 -0
  199. nexaai/llm_impl/__init__.py +0 -0
  200. nexaai/llm_impl/mlx_llm_impl.py +269 -0
  201. nexaai/llm_impl/pybind_llm_impl.py +218 -0
  202. nexaai/log.py +92 -0
  203. nexaai/mlx_backend/asr/__init__.py +12 -0
  204. nexaai/mlx_backend/asr/interface.py +122 -0
  205. nexaai/mlx_backend/common/__init__.py +0 -0
  206. nexaai/mlx_backend/common/utils.py +25 -0
  207. nexaai/mlx_backend/cv/__init__.py +0 -0
  208. nexaai/mlx_backend/cv/generate.py +195 -0
  209. nexaai/mlx_backend/cv/interface.py +151 -0
  210. nexaai/mlx_backend/cv/main.py +81 -0
  211. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  212. nexaai/mlx_backend/embedding/__init__.py +0 -0
  213. nexaai/mlx_backend/embedding/generate.py +333 -0
  214. nexaai/mlx_backend/embedding/interface.py +617 -0
  215. nexaai/mlx_backend/embedding/main.py +173 -0
  216. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  217. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  218. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  219. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  220. nexaai/mlx_backend/image_gen/interface.py +82 -0
  221. nexaai/mlx_backend/image_gen/main.py +281 -0
  222. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  223. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  224. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  225. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  226. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  227. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  228. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  229. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  230. nexaai/mlx_backend/llm/__init__.py +0 -0
  231. nexaai/mlx_backend/llm/generate.py +149 -0
  232. nexaai/mlx_backend/llm/interface.py +764 -0
  233. nexaai/mlx_backend/llm/main.py +68 -0
  234. nexaai/mlx_backend/ml.py +888 -0
  235. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  236. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  247. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  248. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  272. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  273. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  274. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  275. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  276. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  277. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  278. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  279. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  280. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  281. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  282. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  283. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  284. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  286. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  288. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  289. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  290. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  291. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  292. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  293. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  294. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  295. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  296. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  297. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  305. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  306. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  307. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  308. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  309. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  310. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  378. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  379. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  380. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  381. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  382. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  383. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  384. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  385. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  386. nexaai/mlx_backend/profiling.py +239 -0
  387. nexaai/mlx_backend/rerank/__init__.py +0 -0
  388. nexaai/mlx_backend/rerank/generate.py +174 -0
  389. nexaai/mlx_backend/rerank/interface.py +287 -0
  390. nexaai/mlx_backend/rerank/main.py +127 -0
  391. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  392. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  393. nexaai/mlx_backend/sd/__init__.py +1 -0
  394. nexaai/mlx_backend/sd/interface.py +362 -0
  395. nexaai/mlx_backend/sd/main.py +286 -0
  396. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  397. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  398. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  399. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  400. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  401. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  402. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  403. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  404. nexaai/mlx_backend/tts/__init__.py +12 -0
  405. nexaai/mlx_backend/tts/interface.py +276 -0
  406. nexaai/mlx_backend/vlm/__init__.py +3 -0
  407. nexaai/mlx_backend/vlm/generate.py +572 -0
  408. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
  409. nexaai/mlx_backend/vlm/interface.py +415 -0
  410. nexaai/mlx_backend/vlm/main.py +316 -0
  411. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  412. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  413. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  414. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  415. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  416. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  417. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  418. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  419. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  420. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  421. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  422. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  423. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  424. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  425. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  426. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  427. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  429. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  430. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  431. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  432. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  433. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  434. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  435. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  436. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  437. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  438. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  439. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  440. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  441. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  443. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  444. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  445. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  447. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  448. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  449. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  450. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  451. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  452. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  453. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  454. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  455. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  457. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  458. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  459. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  460. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  461. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  462. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  463. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  464. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  465. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  466. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  467. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  468. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  469. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  470. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  471. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  475. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  476. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  477. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  478. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  479. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  480. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  481. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  482. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  483. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  484. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  485. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  486. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  487. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  488. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  489. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  490. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  492. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  493. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  494. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  496. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  497. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  498. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  499. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  500. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  501. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  502. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  503. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  504. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  505. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  506. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  507. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  508. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  509. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  510. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  511. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  512. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  513. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  514. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  515. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  522. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  523. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  524. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  525. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  526. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  527. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  528. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  529. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  530. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  531. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  532. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  533. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  534. nexaai/rerank.py +55 -0
  535. nexaai/rerank_impl/__init__.py +0 -0
  536. nexaai/rerank_impl/mlx_rerank_impl.py +92 -0
  537. nexaai/rerank_impl/pybind_rerank_impl.py +43 -0
  538. nexaai/runtime.py +68 -0
  539. nexaai/tts.py +74 -0
  540. nexaai/tts_impl/__init__.py +0 -0
  541. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  542. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  543. nexaai/utils/avatar_fetcher.py +104 -0
  544. nexaai/utils/decode.py +18 -0
  545. nexaai/utils/manifest_utils.py +324 -0
  546. nexaai/utils/model_manager.py +1353 -0
  547. nexaai/utils/model_types.py +47 -0
  548. nexaai/utils/progress_tracker.py +385 -0
  549. nexaai/utils/quantization_utils.py +245 -0
  550. nexaai/vlm.py +128 -0
  551. nexaai/vlm_impl/__init__.py +0 -0
  552. nexaai/vlm_impl/mlx_vlm_impl.py +258 -0
  553. nexaai/vlm_impl/pybind_vlm_impl.py +230 -0
  554. nexaai-1.0.16rc13.dist-info/METADATA +32 -0
  555. nexaai-1.0.16rc13.dist-info/RECORD +557 -0
  556. nexaai-1.0.16rc13.dist-info/WHEEL +5 -0
  557. nexaai-1.0.16rc13.dist-info/top_level.txt +1 -0
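
The list above enumerates every file packed into the wheel. One way to cross-check this listing locally is to open the wheel as a zip archive with the standard library; the sketch below is not part of the package, and the wheel filename and its location are assumptions.

# Minimal sketch (not from the package): list the wheel's contents to cross-check the file listing above.
import zipfile

wheel_name = "nexaai-1.0.16rc13-cp310-cp310-macosx_15_0_universal2.whl"  # hypothetical local copy
with zipfile.ZipFile(wheel_name) as whl:
    names = whl.namelist()
    print(len(names), "entries")        # expected to roughly match the count shown above
    for name in names[:10]:             # first few entries, e.g. nexaai/__init__.py
        print(name)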
nexaai/mlx_backend/vlm/generate_qwen3_vl.py
@@ -0,0 +1,261 @@
+ import argparse
+ import json
+ import sys
+ import os
+ import mlx.core as mx
+ import mlx.nn as nn
+ import time
+ from PIL import Image
+ import requests
+ import numpy as np
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
+
+ # Add current directory to path for imports
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(curr_dir)
+ sys.path.append(os.path.dirname(curr_dir))
+
+ # Add the qwen3vl model directory to path
+ qwen3vl_dir = os.path.join(curr_dir, "modeling", "models", "qwen3_vl")
+ sys.path.append(qwen3vl_dir)
+
+ # Import required modules for quantized loading
+ from transformers import AutoTokenizer
+
+ # Try relative imports first, fallback to sys.path approach for Nuitka compatibility
+ try:
+     from .modeling.models.qwen3_vl.llm_common.generate import nexa_generate_step
+     from .modeling.models.qwen3_vl.llm_common.cache import make_prompt_cache
+     from .modeling.models.qwen3_vl.qwen3vl import (
+         VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+     )
+     from .modeling.models.qwen3_vl.processor import Qwen3VLProcessor
+ except ImportError:
+     # Fallback for Nuitka compiled environment - use sys.path approach
+     from llm_common.generate import nexa_generate_step
+     from llm_common.cache import make_prompt_cache
+     from qwen3vl import VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+     from processor import Qwen3VLProcessor
+
+ from ml import ChatMessage
+ from dataclasses import dataclass
+ from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
+ from .generate import GenerationResult
+
+ @dataclass
+ class Qwen3VLBundledModel:
+     """Container for Qwen3-VL vision and language models."""
+     vision_model: VEGModel
+     llm_model: LLMModel
+
+
+ def _ensure_list(x: Union[str, List[str], None]) -> Optional[List[str]]:
+     if x is None:
+         return None
+     return x if isinstance(x, list) else [x]
+
+
+ def load_qwen3_vl(
+     path_or_repo: str,
+     adapter_path: Optional[str] = None,
+     lazy: bool = False,
+     revision: Optional[str] = None,
+     **kwargs,
+ ) -> Tuple[Qwen3VLBundledModel, Qwen3VLProcessor]:
+     """Load Qwen3-VL quantized models and processor.
+
+     Parameters are aligned with .generate.load for compatibility.
+     """
+     model_path = Path(path_or_repo)
+     if not model_path.exists():
+         if "/" in path_or_repo:
+             model_path = Path(snapshot_download(
+                 repo_id=path_or_repo, repo_type="model", revision=revision))
+         else:
+             # Fallback to local modelfiles directory
+             model_path = Path(qwen3vl_dir) / "modelfiles"
+             if not model_path.exists():
+                 model_path = Path(curr_dir) / "modelfiles"
+
+     # Model configs (kept identical to main)
+     vision_config = VisionConfig(
+         hidden_size=1024,
+         intermediate_size=4096,
+         num_heads=16,
+         num_hidden_layers=24,
+         patch_size=16,
+         temporal_patch_size=2,
+         in_channels=3,
+         hidden_act="gelu",
+         spatial_merge_size=2,
+         out_hidden_size=2560,
+         num_position_embeddings=2304,
+         deepstack_visual_indexes=[5, 11, 17],
+     )
+
+     text_config = TextConfig(
+         model_type="qwen3vl",
+         hidden_size=2560,
+         num_hidden_layers=36,
+         intermediate_size=9728,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         rms_norm_eps=1e-6,
+         vocab_size=151936,
+         max_position_embeddings=32768,
+         rope_theta=5000000.0,
+         head_dim=128,
+         tie_word_embeddings=True,
+         attention_bias=False,
+         attention_dropout=0.0,
+         rope_scaling={"mrope_section": [24, 20, 20],
+                       "rope_type": "default", "type": "default"},
+     )
+
+     vision_model = VEGModel(vision_config)
+     llm_model = LLMModel(text_config)
+
+     # Try to load LLM model from available files in order of preference
+     preferred_order = [
+         ("qwen3vl-llm-4B-q4_0.safetensors", 4),
+         ("qwen3vl-llm-4B-q8_0.safetensors", 8),
+         ("qwen3vl-llm-4B-f32.safetensors", 32)
+     ]
+
+     llm_weights_path = None
+     quantization_bits = None
+
+     # Try loading in order of preference
+     for filename, bits in preferred_order:
+         candidate_path = model_path / filename
+         if candidate_path.exists():
+             llm_weights_path = candidate_path
+             quantization_bits = bits
+             break
+
+     if llm_weights_path is None:
+         # Fallback to original hardcoded path for backward compatibility
+         llm_weights_path = model_path / "qwen3vl-llm-4B-q4_0.safetensors"
+         quantization_bits = 4
+
+     vision_weights_path = model_path / "qwen3vl-vision-4B-f16.safetensors"
+
+     if not vision_weights_path.exists() or not llm_weights_path.exists():
+         raise FileNotFoundError(
+             f"Missing safetensors. Vision: {vision_weights_path}, LLM: {llm_weights_path}"
+         )
+
+     # Load weights (vision fp16, llm with detected quantization)
+     vision_model.set_dtype(mx.float16)
+     vision_model.load_weights(str(vision_weights_path), strict=True)
+
+     # Apply quantization if needed and load LLM weights
+     if quantization_bits in [4, 8]:
+         nn.quantize(llm_model, bits=quantization_bits, group_size=64,
+                     class_predicate=quant_predicate)
+     # For f32 (32-bit), no quantization needed
+
+     llm_model.load_weights(str(llm_weights_path), strict=True)
+
+     # Tokenizer and processor
+     tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
+     processor = Qwen3VLProcessor(tokenizer=tokenizer)
+
+     return Qwen3VLBundledModel(vision_model=vision_model, llm_model=llm_model), processor
+
+ def apply_chat_template_qwen3_vl(messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = False) -> str:
+     """Apply chat template: serialize messages with content as a list of typed items."""
+     messages_dict = []
+     for msg in messages:
+         content_items = [{"type": "text", "text": msg.content}]
+         messages_dict.append({"role": msg.role, "content": content_items})
+     return json.dumps(messages_dict)
+
+
+ def stream_generate_qwen3_vl(
+     model: Qwen3VLBundledModel,
+     processor: Qwen3VLProcessor,
+     prompt: str,
+     image: Union[str, List[str]] = None,
+     audio: Union[str, List[str]] = None,
+     max_tokens: int = 512,
+     **kwargs,
+
+ ) -> Generator[Any, None, None]:
+     """Stream generation yielding .generate.GenerationResult-compatible chunks."""
+     messages = json.loads(prompt)
+     if image is not None:
+         image_list = image if isinstance(image, list) else [image]
+         pil_images = []
+         for p in image_list:
+             try:
+                 pil_images.append(Image.open(p))
+             except Exception:
+                 continue
+         contents = [{"type": "image", "image": img} for img in pil_images]
+         if messages:
+             if "content" not in messages[-1] or not isinstance(messages[-1]["content"], list):
+                 messages[-1]["content"] = []
+             messages[-1]["content"].extend(contents)
+
+     raw_text, processed_images = processor.messages_to_text(
+         messages, add_generation_prompt=True)
+
+     inputs = processor.text_to_input_ids(
+         raw_text, images=processed_images, return_tensors="mlx")
+
+     input_ids = inputs["input_ids"]
+     pixel_values = inputs.get("pixel_values")
+     image_grid_thw = inputs.get("image_grid_thw")
+
+     inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas = handle_multimodal_embeds(
+         model.vision_model, model.llm_model, input_ids, pixel_values, image_grid_thw
+     )
+
+     prompt_cache = make_prompt_cache(model.llm_model, max_kv_size=4096)
+     tokenizer = processor.tokenizer
+
+     # Rough prompt TPS estimation based on input size
+     prompt_start = time.perf_counter()
+     prompt_tps = input_ids.size / max(1e-6, (time.perf_counter() - prompt_start))
+
+     gen_count = 0
+     tic = time.perf_counter()
+
+     for token, logprobs in nexa_generate_step(
+         model=model.llm_model,
+         prompt=None,
+         input_embeddings=inputs_embeds,
+         max_tokens=max_tokens,
+         max_kv_size=4096,
+         prompt_cache=prompt_cache,
+         visual_pos_masks=visual_pos_masks,
+         deepstack_visual_embeds=deepstack_visual_embeds,
+         cos=cos,
+         sin=sin,
+         rope_deltas=rope_deltas,
+     ):
+         if token == tokenizer.eos_token_id:
+             break
+
+         text_piece = tokenizer.decode([token])
+         gen_count += 1
+
+         yield GenerationResult(
+             text=text_piece,
+             token=token,
+             logprobs=logprobs,
+             prompt_tokens=int(input_ids.size),
+             generation_tokens=gen_count,
+             prompt_tps=float(prompt_tps),
+             generation_tps=float(
+                 gen_count / max(1e-6, (time.perf_counter() - tic))),
+             peak_memory=float(mx.get_peak_memory() / 1e9),
+         )
+
+ def quant_predicate(path: str, mod: nn.Module) -> bool:
+     """Quantization predicate to exclude certain layers from quantization."""
+     if path.endswith("lm_head") or "norm" in path.lower() or "embed" in path.lower():
+         return False
+     return isinstance(mod, (nn.Linear, nn.Embedding))
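
The three helpers added above (load_qwen3_vl, apply_chat_template_qwen3_vl, stream_generate_qwen3_vl) form a small streaming pipeline. The sketch below, which is not part of the wheel, shows how they could be wired together; the import paths, the model directory, and the image path are assumptions, and it presumes the directory holds the qwen3vl-*-4B safetensors files the loader searches for.

# Minimal usage sketch (assumptions: import paths, model directory, image file).
from nexaai.mlx_backend.ml import ChatMessage  # assumed import location of the bundled ml.py
from nexaai.mlx_backend.vlm.generate_qwen3_vl import (  # assumed importable as a package module
    load_qwen3_vl,
    apply_chat_template_qwen3_vl,
    stream_generate_qwen3_vl,
)

# Hypothetical directory containing qwen3vl-llm-4B-*.safetensors and qwen3vl-vision-4B-f16.safetensors
model, processor = load_qwen3_vl("/path/to/qwen3vl/modelfiles")

# Serialize the conversation the way stream_generate_qwen3_vl expects (a JSON string of typed items).
prompt = apply_chat_template_qwen3_vl(
    [ChatMessage(role="user", content="Describe this image.")],  # ChatMessage fields as used above
    num_images=1,
)

for chunk in stream_generate_qwen3_vl(
    model,
    processor,
    prompt,
    image="/path/to/example.jpg",  # hypothetical image path
    max_tokens=128,
):
    print(chunk.text, end="", flush=True)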
@@ -0,0 +1,415 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import time
6
+ from typing import Any, List, Optional, Sequence, Tuple, Union
7
+ import mlx.core as mx
8
+ import codecs
9
+ from dataclasses import dataclass
10
+
11
+ # Import configs and callback types from ml.py for API alignment
12
+ from ml import (
13
+ VLM as BaseVLM,
14
+ SamplerConfig,
15
+ GenerationConfig,
16
+ ChatMessage,
17
+ EmbeddingConfig,
18
+ TokenCallback,
19
+ Path,
20
+ Tool, # Add Path alias for type hints
21
+ )
22
+
23
+ # Import profiling module
24
+ from profiling import ProfilingMixin, ProfilingData, StopReason
25
+
26
+ # Import from the actual mlx_vlm structure
27
+ from .generate import generate, stream_generate, load
28
+ from .generate_qwen3_vl import apply_chat_template_qwen3_vl, stream_generate_qwen3_vl, load_qwen3_vl
29
+
30
+ from .modeling.prompt_utils import apply_chat_template
31
+
32
+ # --------------------------------------------------------------------------------------
33
+ # Updated GenerationResult to match the new structure
34
+ # --------------------------------------------------------------------------------------
35
+
36
+ @dataclass
37
+ class GenerationResult:
38
+ text: str = ""
39
+ token: Optional[int] = None
40
+ logprobs: Optional[List[float]] = None
41
+ prompt_tokens: int = 0
42
+ generation_tokens: int = 0
43
+ total_tokens: int = 0
44
+ prompt_tps: float = 0.0
45
+ generation_tps: float = 0.0
46
+ peak_memory: float = 0.0
47
+ # --------------------------------------------------------------------------------------
48
+ # VLM (Vision-Language Model)
49
+ # --------------------------------------------------------------------------------------
50
+
51
+ class VLM(ProfilingMixin):
52
+ """
53
+ Vision-Language Models for mlx-vlm
54
+ API aligned with ml.py VLM abstract base class.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ model_name: Optional[str],
60
+ model_path: Path,
61
+ mmproj_path: Path,
62
+ context_length: int,
63
+ device: Optional[str] = None,
64
+ ) -> None:
65
+ # Initialize profiling mixin
66
+ ProfilingMixin.__init__(self)
67
+
68
+ # Check if model_path is a file, if so use its parent directory
69
+ if os.path.isfile(model_path):
70
+ model_path = os.path.dirname(model_path)
71
+
72
+ self.model_path = model_path
73
+ self.model_name = model_name
74
+ self.mmproj_path = mmproj_path
75
+ self.context_length = context_length
76
+ self.device = device
77
+
78
+ load_impl = load_qwen3_vl if model_name == "qwen3vl" else load
79
+ self.model, self.processor = load_impl(str(model_path))
80
+
81
+ # Init deafutl sampler config with defualt.
82
+ self.sampler_config = SamplerConfig()
83
+
84
+ def destroy(self) -> None:
85
+ """Destroy the model and free resources."""
86
+ self.model = None
87
+ self.processor = None
88
+
89
+ def reset(self) -> None:
90
+ """Reset the model state."""
91
+ self._reset_cache()
92
+
93
+ def _reset_cache(self) -> None:
94
+ """Reset the KV cache."""
95
+ # If the model has a cache, reset it
96
+ if hasattr(self.model, "cache"):
97
+ self.model.cache = None
98
+
99
+ # Tokenization
100
+ def encode(self, text: str) -> List[int]:
101
+ """Encode text to token IDs."""
102
+ return self.processor.encode(text)
103
+
104
+ def decode(self, token_ids: Sequence[int]) -> str:
105
+ """Decode token IDs to text."""
106
+ return self.processor.decode(token_ids)
107
+
108
+ # Sampler
109
+ def set_sampler(self, config: SamplerConfig) -> None:
110
+ """Set sampler configuration."""
111
+ self.sampler_config = config
112
+
113
+ def reset_sampler(self) -> None:
114
+ """Reset sampler to default configuration."""
115
+ self.sampler_config = None
116
+
117
+ # Generation
118
+ def generate(
119
+ self,
120
+ prompt: str,
121
+ config: Optional[GenerationConfig] = None,
122
+ ) -> GenerationResult:
123
+ """Generate text from prompt."""
124
+ # Start profiling
125
+ self._start_profiling()
126
+
127
+ gen_kwargs = {}
128
+ if config is not None:
129
+ gen_kwargs = config.__dict__.copy()
130
+ # Remove image_paths and audio_paths from config as they'll be handled separately
131
+ gen_kwargs.pop('image_paths', None)
132
+ gen_kwargs.pop('audio_paths', None)
133
+ if self.sampler_config is not None:
134
+ gen_kwargs.update(self.sampler_config.__dict__)
135
+
136
+ # Get image and audio paths from config
137
+ image_paths = config.image_paths if config else None
138
+ audio_paths = config.audio_paths if config else None
139
+
140
+ # Convert paths to strings for generate function
141
+ image_list = [str(path) for path in image_paths] if image_paths else None
142
+ audio_list = [str(path) for path in audio_paths] if audio_paths else None
143
+
144
+ # End prompt processing, start decode
145
+ self._prompt_end()
146
+ self._decode_start()
147
+
148
+ try:
149
+ # Start timing for generation
150
+ generation_start_time = time.perf_counter()
151
+
152
+ text, stats = generate(
153
+ self.model,
154
+ self.processor,
155
+ prompt,
156
+ image=image_list,
157
+ audio=audio_list,
158
+ **gen_kwargs,
159
+ )
160
+
161
+ # End timing for generation
162
+ generation_end_time = time.perf_counter()
163
+
164
+ # Calculate average time per token and estimate TTFT
165
+ generated_tokens = stats.get("output_tokens", 0)
166
+ if generated_tokens > 0:
167
+ total_generation_time = generation_end_time - generation_start_time
168
+ avg_time_per_token = total_generation_time / generated_tokens
169
+ # TTFT = prompt processing time + first token generation time
170
+ # This provides a more accurate estimate than the previous approximation
171
+ estimated_ttft = (self._profiling_context.prompt_end_time - self._profiling_context.prompt_start_time) + avg_time_per_token
172
+ # Update the profiling context with estimated TTFT
173
+ self._profiling_context.first_token_time = self._profiling_context.prompt_start_time + estimated_ttft
174
+ self._profiling_context.ttft_recorded = True
175
+ else:
176
+ # If no tokens generated, use total generation time as TTFT
177
+ self._record_ttft()
178
+
179
+ # Update profiling data
180
+ prompt_tokens = stats.get("input_tokens", 0)
181
+ self._update_prompt_tokens(prompt_tokens)
182
+ self._update_generated_tokens(generated_tokens)
183
+ self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
184
+ self._decode_end()
185
+ self._end_profiling()
186
+
187
+ return GenerationResult(
188
+ text=text,
189
+ prompt_tokens=prompt_tokens,
190
+ generation_tokens=generated_tokens,
191
+ total_tokens=stats.get("total_tokens", 0),
192
+ prompt_tps=stats.get("prompt_tps", 0.0),
193
+ generation_tps=stats.get("generation_tps", 0.0),
194
+ peak_memory=stats.get("peak_memory", 0.0),
195
+ )
196
+ except Exception as e:
197
+ self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
198
+ self._decode_end()
199
+ self._end_profiling()
200
+ raise RuntimeError(f"Generation error: {str(e)}")
201
+
+    def generate_stream(
+        self,
+        prompt: str,
+        config: Optional[GenerationConfig],
+        on_token: Optional[TokenCallback],
+    ) -> GenerationResult:
+        """Generate text with streaming callback. Unified method for both text and multimodal generation."""
+        # Start profiling
+        self._start_profiling()
+
+        gen_kwargs = {}
+        if config is not None:
+            gen_kwargs = config.__dict__.copy()
+            # Remove image_paths and audio_paths from config as they'll be handled separately
+            gen_kwargs.pop('image_paths', None)
+            gen_kwargs.pop('audio_paths', None)
+        if self.sampler_config is not None:
+            gen_kwargs.update(self.sampler_config.__dict__)
+
+        # Get image and audio paths from config
+        image_paths = config.image_paths if config else None
+        audio_paths = config.audio_paths if config else None
+
+        # Convert paths to strings for stream_generate function
+        image_list = [str(path) for path in image_paths] if image_paths else None
+        audio_list = [str(path) for path in audio_paths] if audio_paths else None
+
+        # End prompt processing, start decode
+        self._prompt_end()
+        self._decode_start()
+
+        text = ""
+        last_result = None
+        first_token = True
+        stream_generate_impl = stream_generate_qwen3_vl if self.model_name == "qwen3vl" else stream_generate
+
+        try:
+            for result in stream_generate_impl(
+                self.model,
+                self.processor,
+                prompt,
+                image=image_list,
+                audio=audio_list,
+                **gen_kwargs,
+            ):
+                # Record TTFT on first token
+                if first_token:
+                    self._record_ttft()
+                    first_token = False
+
+                # Call the token callback if provided
+                if on_token is not None:
+                    if not on_token(result.text):
+                        self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
+                        break
+                text += result.text
+                last_result = result
+
+            # Set stop reason if not user stop
+            if self._profiling_context.stop_reason != StopReason.ML_STOP_REASON_USER:
+                self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)
+
+            # Update profiling data
+            if last_result:
+                self._update_prompt_tokens(last_result.prompt_tokens)
+                self._update_generated_tokens(last_result.generation_tokens)
+
+            self._decode_end()
+            self._end_profiling()
+
+            return GenerationResult(
+                text=text,
+                token=last_result.token if last_result else None,
+                logprobs=last_result.logprobs if last_result else None,
+                prompt_tokens=last_result.prompt_tokens if last_result else 0,
+                generation_tokens=last_result.generation_tokens if last_result else 0,
+                total_tokens=(last_result.prompt_tokens + last_result.generation_tokens) if last_result else 0,
+                prompt_tps=last_result.prompt_tps if last_result else 0.0,
+                generation_tps=last_result.generation_tps if last_result else 0.0,
+                peak_memory=last_result.peak_memory if last_result else 0.0,
+            )
+        except Exception as e:
+            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
+            self._decode_end()
+            self._end_profiling()
+            raise RuntimeError(f"Streaming generation error: {str(e)}")
+
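For context, a hedged usage sketch of the streaming path: the callback receives each decoded text chunk and can stop generation early by returning False, which is recorded as ML_STOP_REASON_USER. The instance name `llm` and the prompt are placeholders, not part of this release:

    def on_token(chunk: str) -> bool:
        print(chunk, end="", flush=True)
        return True  # returning False stops generation early

    result = llm.generate_stream("Tell me a short story.", config=None, on_token=on_token)
    print(result.generation_tokens, result.generation_tps)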
+    # Legacy multimodal methods - kept for backward compatibility but delegate to unified method
+    def generate_multimodal(
+        self,
+        prompt: str,
+        image_paths: Optional[Sequence[Path]] = None,
+        audio_paths: Optional[Sequence[Path]] = None,
+        config: Optional[GenerationConfig] = None,
+    ) -> str:
+        """Generate text from prompt with multiple images and audio."""
+        # Create config with media paths if not provided
+        if config is None:
+            config = GenerationConfig()
+
+        # Update config with provided paths
+        if image_paths is not None:
+            config.image_paths = image_paths
+        if audio_paths is not None:
+            config.audio_paths = audio_paths
+
+        # Delegate to unified generate method and extract text
+        result = self.generate(prompt, config)
+        return result.text
+
+    def generate_stream_multimodal(
+        self,
+        prompt: str,
+        image_paths: Optional[Sequence[Path]] = None,
+        audio_paths: Optional[Sequence[Path]] = None,
+        config: Optional[GenerationConfig] = None,
+        on_token: Optional[TokenCallback] = None,
+    ) -> str:
+        """Generate text from prompt with multiple images and audio using streaming callback."""
+        # Create config with media paths if not provided
+        if config is None:
+            config = GenerationConfig()
+
+        # Update config with provided paths
+        if image_paths is not None:
+            config.image_paths = image_paths
+        if audio_paths is not None:
+            config.audio_paths = audio_paths
+
+        # Delegate to unified generate_stream method and extract text
+        result = self.generate_stream(prompt, config, on_token)
+        return result.text
+
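A usage sketch for the legacy wrappers, assuming an already-loaded instance `llm` and hypothetical image files; the wrappers only fold the media paths into a GenerationConfig and delegate to the unified methods above:

    from pathlib import Path

    text = llm.generate_multimodal(
        "What do these two pictures have in common?",
        image_paths=[Path("cat.png"), Path("dog.png")],  # hypothetical local files
    )
    print(text)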
+    def get_chat_template(self, template_name: str) -> str:
+        """Get chat template by name."""
+        # This is a stub; actual implementation depends on processor internals
+        if hasattr(self.processor, "get_chat_template"):
+            return self.processor.get_chat_template(template_name)
+        return ""
+
+    def apply_chat_template(self, messages: Sequence[ChatMessage], tools: Optional[str] = None, enable_thinking: bool = True) -> str:
+        """Apply chat template to messages with optional tools support."""
+        if hasattr(self.processor, "apply_chat_template"):
+            # Convert ChatMessage objects to dictionaries for the processor
+            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+
+            parsed_tools = None
+            if tools is not None and tools.strip():
+                parsed_tools = json.loads(tools)
+
+            result = apply_chat_template(self.processor, self.model.config, messages_dict, add_generation_prompt=True, enable_thinking=enable_thinking, tools=parsed_tools)
+            return result
+        # Fallback: join messages
+        return "\n".join([f"{m.role}: {m.content}" for m in messages])
+
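A sketch of how this method might be called. Note that `tools` is passed as a JSON string because the method parses it with json.loads(); the ChatMessage constructor and the tool schema shown here are assumptions about the surrounding API, not verified against this release:

    import json

    messages = [
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="What's the weather in Paris?"),
    ]
    tools = json.dumps([{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }])
    prompt = llm.apply_chat_template(messages, tools=tools, enable_thinking=False)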
+    def apply_chat_template_with_media(self, messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = True) -> str:
+        """Apply chat template to messages with proper image/audio token insertion and optional tools support."""
+        if self.model_name == "qwen3vl":
+            return apply_chat_template_qwen3_vl(messages, num_images=num_images, num_audios=num_audios, tools=tools, enable_thinking=enable_thinking)
+
+        # Convert ChatMessage objects to dictionaries for the processor
+        messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
+
+        parsed_tools = None
+        if tools is not None and tools.strip():
+            parsed_tools = json.loads(tools)
+
+        # Use the same logic as generate.py
+        return apply_chat_template(
+            self.processor,
+            self.model.config,
+            messages_dict,
+            num_images=num_images,
+            num_audios=num_audios,
+            enable_thinking=enable_thinking,
+            tools=parsed_tools
+        )
+
+    # Embeddings
+    def embed(
+        self,
+        texts: Sequence[str],
+        config: Optional[EmbeddingConfig] = None,
+    ) -> List[List[float]]:
+        """Generate embeddings for texts with profiling."""
+        # Start profiling
+        self._start_profiling()
+
+        try:
+            # If the model supports embeddings, use it; otherwise raise
+            if hasattr(self.model, "embed"):
+                embed_kwargs = config.__dict__ if config else {}
+
+                # End prompt processing, start decode
+                self._prompt_end()
+                self._decode_start()
+
+                result = self.model.embed(texts, **embed_kwargs)
+
+                # End timing and finalize profiling data
+                self._update_generated_tokens(0)  # No generation in embedding
+                self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
+                self._decode_end()
+                self._end_profiling()
+
+                return result
+            else:
+                raise NotImplementedError("Embedding not supported for this model.")
+
+        except Exception as e:
+            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
+            self._decode_end()
+            self._end_profiling()
+            raise RuntimeError(f"Error generating embeddings: {str(e)}")