nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py
@@ -0,0 +1,215 @@
+ # Copied from transformers. Removed video-related code.
+ """
+ Processor class for Qwen2-VL.
+ """
+
+ from typing import Optional, Union
+
+ import numpy as np
+
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.image_utils import ImageInput
+ from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Qwen2VLImagesKwargs(ImagesKwargs):
+     min_pixels: Optional[int]
+     max_pixels: Optional[int]
+     patch_size: Optional[int]
+     temporal_patch_size: Optional[int]
+     merge_size: Optional[int]
+
+
+ class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
+     images_kwargs: Qwen2VLImagesKwargs
+     _defaults = {
+         "text_kwargs": {
+             "padding": False,
+             "return_mm_token_type_ids": False,
+         },
+     }
+
+
+ class Qwen2VLProcessor(ProcessorMixin):
+     r"""
+     Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.
+     [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+     [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.
+     Args:
+         image_processor ([`Qwen2VLImageProcessor`], *optional*):
+             The image processor is a required input.
+         tokenizer ([`Qwen2TokenizerFast`], *optional*):
+             The tokenizer is a required input.
+         chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+             in a chat into a tokenizable string.
+     """
+
+     attributes = ["image_processor", "tokenizer"]
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+     def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+         self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+         self.image_token_id = (
+             tokenizer.image_token_id
+             if getattr(tokenizer, "image_token_id", None)
+             else tokenizer.convert_tokens_to_ids(self.image_token)
+         )
+         super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+     def __call__(
+         self,
+         images: ImageInput = None,
+         text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+         **kwargs: Unpack[Qwen2VLProcessorKwargs],
+     ) -> BatchFeature:
+         """
+         Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+         and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+         the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+         Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+         Args:
+             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                 tensor. Both channels-first and channels-last formats are supported.
+             text (`str`, `list[str]`, `list[list[str]]`):
+                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors of a particular framework. Acceptable values are:
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return NumPy `np.ndarray` objects.
+                 - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+         Returns:
+             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+               `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+               `None`).
+             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+             - **image_grid_thw** -- List of 3D image grids (temporal, height, width) for the LLM. Returned when `images` is not `None`.
+         """
+         output_kwargs = self._merge_kwargs(
+             Qwen2VLProcessorKwargs,
+             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+             **kwargs,
+         )
+
+         image_inputs = {}
+         if images is not None:
+             image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+             image_grid_thw = image_inputs["image_grid_thw"]
+
+         if not isinstance(text, list):
+             text = [text]
+
+         text = text.copy()  # the lines below modify text in place
+
+         if images is not None:
+             merge_length = self.image_processor.merge_size**2
+             index = 0
+             for i in range(len(text)):
+                 while self.image_token in text[i]:
+                     num_image_tokens = image_grid_thw[index].prod() // merge_length
+                     text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+                     index += 1
+                 text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+         return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+         self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+         if return_mm_token_type_ids:
+             array_ids = np.array(text_inputs["input_ids"])
+             mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+             mm_token_type_ids[array_ids == self.image_token_id] = 1
+             text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+         return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+     def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+         """
+         Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+         Args:
+             image_sizes (`list[list[int]]`, *optional*):
+                 The input sizes formatted as (height, width) for each image.
+         Returns:
+             `MultiModalData`: A `MultiModalData` object holding the number of tokens for each of the provided
+             input modalities, along with other useful data.
+         """
+
+         vision_data = {}
+         if image_sizes is not None:
+             images_kwargs = Qwen2VLProcessorKwargs._defaults.get("images_kwargs", {})
+             images_kwargs.update(kwargs)
+             merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
+
+             num_image_patches = [
+                 self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+                 for image_size in image_sizes
+             ]
+             num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
+             vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+         return MultiModalData(**vision_data)
+
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def post_process_image_text_to_text(
+         self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+     ):
+         """
+         Post-process the output of the model to decode the text.
+
+         Args:
+             generated_outputs (`torch.Tensor` or `np.ndarray`):
+                 The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                 or `(sequence_length,)`.
+             skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                 Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+             clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+                 Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+             **kwargs:
+                 Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+         Returns:
+             `list[str]`: The decoded text.
+         """
+         return self.tokenizer.batch_decode(
+             generated_outputs,
+             skip_special_tokens=skip_special_tokens,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+ __all__ = ["Qwen2VLProcessor"]
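
The placeholder expansion in `__call__` above is the heart of this processor: each `<|image_pad|>` in the prompt is replaced by a run of image tokens whose length is the image's patch-grid volume divided by `merge_size**2`. A minimal standalone sketch of that arithmetic (the grid shape and merge size below are illustrative assumptions, not values taken from this package):

import numpy as np

# One image whose vision grid is 1 temporal x 16 x 16 spatial patches, with an
# assumed merge_size of 2 (each LLM-side token covers a 2x2 block of patches).
image_grid_thw = [np.array([1, 16, 16])]
merge_size = 2
image_token = "<|image_pad|>"

text = [f"Describe this image: {image_token}"]
index = 0
for i in range(len(text)):
    while image_token in text[i]:
        num_image_tokens = int(image_grid_thw[index].prod()) // merge_size**2
        # Expand into a temporary placeholder so already-expanded tokens are
        # not matched again by the while-loop condition.
        text[i] = text[i].replace(image_token, "<|placeholder|>" * num_image_tokens, 1)
        index += 1
    text[i] = text[i].replace("<|placeholder|>", image_token)

print(text[0].count(image_token))  # 1*16*16 // 4 = 64 image tokens
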
@@ -0,0 +1,474 @@
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+
6
+ class MessageFormat(Enum):
7
+ """Enum for different message format types."""
8
+
9
+ LIST_WITH_IMAGE = "list_with_image"
10
+ LIST_WITH_IMAGE_FIRST = "list_with_image_first"
11
+ LIST_WITH_IMAGE_TYPE = "list_with_image_type"
12
+ LIST_WITH_IMAGE_TYPE_TEXT = "list_with_image_type_text"
13
+ LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST = "list_with_image_type_text_image_last"
14
+ IMAGE_TOKEN = "image_token"
15
+ IMAGE_TOKEN_PIPE = "image_token_pipe"
16
+ START_IMAGE_TOKEN = "start_image_token"
17
+ IMAGE_TOKEN_NEWLINE = "image_token_newline"
18
+ NUMBERED_IMAGE_TOKENS = "numbered_image_tokens"
19
+ PROMPT_ONLY = "prompt_only"
20
+ PROMPT_WITH_IMAGE_TOKEN = "prompt_with_image_token"
21
+ PROMPT_WITH_START_IMAGE_TOKEN = "prompt_with_start_image_token"
22
+ VIDEO_WITH_TEXT = "video_with_text"
23
+
24
+
25
+ # Model configuration mapping
26
+ MODEL_CONFIG = {
27
+ # List with image format models
28
+ "idefics2": MessageFormat.LIST_WITH_IMAGE,
29
+ "idefics3": MessageFormat.LIST_WITH_IMAGE_FIRST,
30
+ "aya_vision": MessageFormat.LIST_WITH_IMAGE,
31
+ "qwen2_vl": MessageFormat.LIST_WITH_IMAGE,
32
+ "qwen2_5_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
33
+ "mistral3": MessageFormat.LIST_WITH_IMAGE_FIRST,
34
+ "internvl_chat": MessageFormat.LIST_WITH_IMAGE_TYPE,
35
+ "kimi_vl": MessageFormat.LIST_WITH_IMAGE,
36
+ "gemma3": MessageFormat.START_IMAGE_TOKEN,
37
+ "gemma3n": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST,
38
+ "llama4": MessageFormat.LIST_WITH_IMAGE,
39
+ "smolvlm": MessageFormat.LIST_WITH_IMAGE_FIRST,
40
+ "llava": MessageFormat.LIST_WITH_IMAGE,
41
+ "llava_next": MessageFormat.LIST_WITH_IMAGE,
42
+ "mllama": MessageFormat.LIST_WITH_IMAGE,
43
+ "pixtral": MessageFormat.LIST_WITH_IMAGE_TYPE,
44
+ # Token-based models
45
+ "llava-qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,
46
+ "bunny-llama": MessageFormat.IMAGE_TOKEN_NEWLINE,
47
+ "phi3_v": MessageFormat.NUMBERED_IMAGE_TOKENS,
48
+ "multi_modality": MessageFormat.IMAGE_TOKEN,
49
+ "deepseek_vl_v2": MessageFormat.IMAGE_TOKEN_NEWLINE,
50
+ # Prompt-only models
51
+ "florence2": MessageFormat.PROMPT_ONLY,
52
+ "molmo": MessageFormat.PROMPT_ONLY,
53
+ "paligemma": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
54
+ }
55
+
56
+ # Models that don't support multi-image
57
+ SINGLE_IMAGE_ONLY_MODELS = {
58
+ "llava_next",
59
+ "llava-qwen2",
60
+ "bunny-llama",
61
+ "paligemma",
62
+ "multi_modality",
63
+ "mllama",
64
+ }
65
+
66
+
67
+ class MessageBuilder:
68
+ """Builder for creating messages in various formats."""
69
+
70
+ @staticmethod
71
+ def text_message(text: str) -> Dict[str, str]:
72
+ """Create a simple text message."""
73
+ return {"type": "text", "text": text}
74
+
75
+ @staticmethod
76
+ def content_message(content: str) -> Dict[str, str]:
77
+ """Create a content-type text message."""
78
+ return {"type": "text", "content": content}
79
+
80
+ @staticmethod
81
+ def image_message() -> Dict[str, str]:
82
+ """Create an image message."""
83
+ return {"type": "image"}
84
+
85
+ @staticmethod
86
+ def audio_message() -> Dict[str, str]:
87
+ """Create an audio message."""
88
+ return {"type": "audio"}
89
+
90
+ @staticmethod
91
+ def video_message(
92
+ video_path: str, max_pixels: int = 224 * 224, fps: int = 1
93
+ ) -> Dict[str, Any]:
94
+ """Create a video message."""
95
+ return {
96
+ "type": "video",
97
+ "video": video_path,
98
+ "max_pixels": max_pixels,
99
+ "fps": fps,
100
+ }
101
+
102
+
103
+ class MessageFormatter:
104
+ """Handles formatting messages for different model types."""
105
+
106
+ def __init__(self, model_name: str):
107
+ self.model_name = model_name.lower()
108
+ self.format_type = MODEL_CONFIG.get(self.model_name)
109
+ if not self.format_type:
110
+ raise ValueError(f"Unsupported model: {model_name}")
111
+
112
+     def format_message(
+         self,
+         prompt: str,
+         role: str = "user",
+         skip_image_token: bool = False,
+         skip_audio_token: bool = False,
+         num_images: int = 1,
+         num_audios: int = 1,
+         **kwargs,
+     ) -> Union[str, Dict[str, Any]]:
+         """Format a message based on the model type."""
+ 
+         # Check multi-image support
+         if num_images > 1 and self.model_name in SINGLE_IMAGE_ONLY_MODELS:
+             raise ValueError(
+                 f"Model {self.model_name} does not support multi-image chat. "
+                 "Please use only 1 image."
+             )
+ 
+         # Handle video format for specific models; kwargs must be unpacked,
+         # otherwise the dict would bind to the `role` parameter.
+         if self.model_name in ["qwen2_vl", "qwen2_5_vl"] and kwargs.get("video"):
+             return self._format_video_message(prompt, **kwargs)
+ 
+         # Route to the appropriate formatter
+         formatter_map = {
+             MessageFormat.LIST_WITH_IMAGE: self._format_list_with_image,
+             MessageFormat.LIST_WITH_IMAGE_FIRST: partial(
+                 self._format_list_with_image, image_first=True
+             ),
+             MessageFormat.LIST_WITH_IMAGE_TYPE: self._format_list_with_image_type,
+             MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT: partial(
+                 self._format_list_with_image_type, message_type="text"
+             ),
+             MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST: partial(
+                 self._format_list_with_image_type,
+                 message_type="text",
+                 image_first=False,
+             ),
+             MessageFormat.IMAGE_TOKEN: partial(
+                 self._format_with_token, token="<image>"
+             ),
+             MessageFormat.IMAGE_TOKEN_PIPE: partial(
+                 self._format_with_token, token="<|image|>"
+             ),
+             MessageFormat.START_IMAGE_TOKEN: partial(
+                 self._format_with_token, token="<start_of_image>", image_first=False
+             ),
+             MessageFormat.IMAGE_TOKEN_NEWLINE: partial(
+                 self._format_with_token, token="<image>\n"
+             ),
+             MessageFormat.NUMBERED_IMAGE_TOKENS: self._format_numbered_tokens,
+             MessageFormat.PROMPT_ONLY: lambda *args, **kw: prompt,
+             MessageFormat.PROMPT_WITH_IMAGE_TOKEN: lambda *args, **kw: (
+                 "<image>" * num_images + prompt
+             ),
+             MessageFormat.PROMPT_WITH_START_IMAGE_TOKEN: lambda *args, **kw: (
+                 prompt + "<start_of_image>" * num_images
+             ),
+             MessageFormat.VIDEO_WITH_TEXT: self._format_video_message,
+         }
+ 
+         formatter = formatter_map.get(self.format_type)
+         if formatter is None:
+             # Guard against a MODEL_CONFIG entry with no registered formatter.
+             raise ValueError(f"No formatter registered for: {self.format_type}")
+         return formatter(
+             prompt,
+             role,
+             skip_image_token,
+             skip_audio_token,
+             num_images,
+             num_audios,
+             **kwargs,
+         )
+ 
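A hedged sketch of the dispatch above; the assumption that MODEL_CONFIG maps "llava" to a list-style format is for illustration only and is not taken from this diff:

# Illustrative only; assumes MODEL_CONFIG maps "llava" to LIST_WITH_IMAGE.
fmt = MessageFormatter("llava")
fmt.format_message("What is in this image?", num_images=2)
# A list-style formatter would return roughly:
# {"role": "user", "content": [{"type": "text", "text": "What is in this image?"},
#                              {"type": "image"}, {"type": "image"}]}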
183
+     def _format_list_with_image(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         image_first: bool = False,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format as a list with image tokens."""
+         content = [MessageBuilder.text_message(prompt)]
+ 
+         if role == "user" and not skip_image_token:
+             image_tokens = [MessageBuilder.image_message()] * num_images
+             content = image_tokens + content if image_first else content + image_tokens
+ 
+         return {"role": role, "content": content}
+ 
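Worked example, derivable from the method above: two images with image_first=True land before the text entry.

# Illustrative call (arguments: prompt, role, skip_image_token,
# skip_audio_token, num_images, num_audios):
# _format_list_with_image("hi", "user", False, False, 2, 0, image_first=True)
# -> {"role": "user", "content": [{"type": "image"}, {"type": "image"},
#                                 {"type": "text", "text": "hi"}]}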
203
+     def _format_list_with_image_type(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         message_type: str = "content",
+         image_first: bool = True,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format as a list with typed messages."""
+         msg_func = (
+             MessageBuilder.content_message
+             if message_type == "content"
+             else MessageBuilder.text_message
+         )
+         message = {"role": role, "content": [msg_func(prompt)]}
+ 
+         if role == "user":
+             if not skip_image_token:
+                 message["content"] = (
+                     [MessageBuilder.image_message()] * num_images + message["content"]
+                     if image_first
+                     else message["content"]
+                     + [MessageBuilder.image_message()] * num_images
+                 )
+             if not skip_audio_token:
+                 message["content"] = (
+                     message["content"] + [MessageBuilder.audio_message()] * num_audios
+                 )
+ 
+         if role == "assistant":
+             # Assistant turns are flattened back to a plain string.
+             message["content"] = message["content"][0].get(
+                 "content", message["content"][0].get("text")
+             )
+ 
+         return message
+ 
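Worked example, derivable from the method above: with the defaults, image tokens are prepended and audio tokens appended for user turns.

# _format_list_with_image_type("hi", "user", False, False, 1, 1, message_type="text")
# -> {"role": "user", "content": [{"type": "image"},
#                                 {"type": "text", "text": "hi"},
#                                 {"type": "audio"}]}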
243
+     def _format_with_token(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         token: str,
+         image_first: bool = True,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format with image tokens embedded in the text."""
+         content = prompt
+ 
+         if role == "user" and not skip_image_token:
+             prefix = token * num_images
+             content = f"{prefix}{content}" if image_first else f"{content}{prefix}"
+ 
+         return {"role": role, "content": content}
+ 
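Worked example, derivable from the method above: the token is repeated once per image and, by default, prefixed to the prompt.

# _format_with_token("hi", "user", False, False, 2, 0, token="<image>")
# -> {"role": "user", "content": "<image><image>hi"}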
264
+     def _format_numbered_tokens(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format with numbered image tokens."""
+         content = prompt
+ 
+         if role == "user" and not skip_image_token:
+             # phi3_v uses a single token regardless of num_images
+             prefix = (
+                 "<|image_1|>"
+                 if self.model_name == "phi3_v"
+                 else " ".join([f"<|image_{i + 1}|>" for i in range(num_images)])
+             )
+             content = f"{prefix}{content}"
+ 
+         return {"role": role, "content": content}
+ 
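Worked example, derivable from the method above: numbered tokens are 1-indexed and space-separated for models other than phi3_v.

# _format_numbered_tokens("hi", "user", False, False, 2, 0)
# -> {"role": "user", "content": "<|image_1|> <|image_2|>hi"}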
288
+     def _format_video_message(
+         self,
+         prompt: str,
+         role: str = "user",
+         skip_image_token: bool = False,
+         skip_audio_token: bool = False,
+         num_images: int = 0,
+         num_audios: int = 0,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format a video message with text."""
+         return {
+             "role": role,
+             "content": [
+                 MessageBuilder.video_message(
+                     kwargs["video"],
+                     kwargs.get("max_pixels", 224 * 224),
+                     kwargs.get("fps", 1),
+                 ),
+                 MessageBuilder.text_message(prompt),
+             ],
+         }
+ 
+ 
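Worked example, derivable from the method above, using the default max_pixels of 224 * 224 = 50176:

# _format_video_message("Summarize the clip", video="clip.mp4", fps=2)
# -> {"role": "user", "content": [
#        {"type": "video", "video": "clip.mp4", "max_pixels": 50176, "fps": 2},
#        {"type": "text", "text": "Summarize the clip"}]}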
312
+ def get_message_json(
+     model_name: str,
+     prompt: str,
+     role: str = "user",
+     skip_image_token: bool = False,
+     skip_audio_token: bool = False,
+     num_images: int = 0,
+     num_audios: int = 0,
+     **kwargs,
+ ) -> Union[str, Dict[str, Any]]:
+     """
+     Get the appropriate JSON message based on the specified model.
+ 
+     Args:
+         model_name: The model for which to generate the message
+         prompt: The text prompt to be included in the message
+         role: The role of the message (default: "user")
+         skip_image_token: Whether to skip adding image tokens
+         skip_audio_token: Whether to skip adding audio tokens
+         num_images: Number of image tokens to add
+         num_audios: Number of audio tokens to add
+         **kwargs: Additional arguments (e.g., video path, max_pixels, fps)
+ 
+     Returns:
+         A dictionary or string representing the message for the specified model
+     """
+     formatter = MessageFormatter(model_name)
+ 
+     return formatter.format_message(
+         prompt,
+         role,
+         skip_image_token,
+         skip_audio_token,
+         num_images,
+         num_audios,
+         **kwargs,
+     )
+ 
+ 
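A minimal usage sketch; the model name is a placeholder and the returned shape depends on that model's entry in MODEL_CONFIG:

# Placeholder model name for illustration.
msg = get_message_json("llava", "What is shown here?", num_images=1)
# List-style models yield a chat-message dict; PROMPT_ONLY-style models
# yield a plain string.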
351
+ def get_chat_template(
+     processor,
+     messages: List[Dict[str, Any]],
+     add_generation_prompt: bool,
+     tokenize: bool = False,
+     **kwargs,
+ ) -> Any:
+     """Apply the chat template using the processor's tokenizer."""
+     try:
+         # Fall back to the inner tokenizer when the processor itself
+         # does not carry a chat template.
+         processor = (
+             processor
+             if "chat_template" in processor.__dict__
+             else processor.tokenizer
+         )
+ 
+         return processor.apply_chat_template(
+             messages,
+             tokenize=tokenize,
+             add_generation_prompt=add_generation_prompt,
+             **kwargs,
+         )
+     except AttributeError:
+         raise ValueError(
+             "Processor has neither a 'chat_template' nor a 'tokenizer' attribute."
+         )
+ 
+ 
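A hypothetical usage sketch assuming a Hugging Face processor; the model id below is a placeholder, not taken from this package:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("some-org/some-vlm")  # placeholder id
text = get_chat_template(
    processor,
    [{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
)
# Returns the templated prompt string (tokenize=False by default).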
378
+ def apply_chat_template(
+     processor,
+     config: Union[Dict[str, Any], Any],
+     prompt: Union[str, Dict[str, Any], List[Any]],
+     add_generation_prompt: bool = True,
+     return_messages: bool = False,
+     num_images: int = 0,
+     num_audios: int = 0,
+     **kwargs,
+ ) -> Union[List[Dict[str, Any]], str, Any]:
+     """
+     Apply chat template to prompts.
+ 
+     Args:
+         processor: The processor with chat template functionality
+         config: Model configuration
+         prompt: Single prompt string, dict, or list of prompts
+         add_generation_prompt: Whether to add a generation prompt
+         return_messages: Whether to return the messages list instead of the template
+         num_images: Number of images in the input
+         num_audios: Number of audio files in the input
+         **kwargs: Additional arguments for message formatting
+ 
+     Returns:
+         Formatted messages or chat template
+     """
+     config = config if isinstance(config, dict) else config.__dict__
+     model_type = config["model_type"]
+ 
+     # Build messages from prompts
+     messages = []
+ 
+     if isinstance(prompt, str):
+         # Single string prompt
+         messages.append(
+             get_message_json(
+                 model_type,
+                 prompt,
+                 num_images=num_images,
+                 num_audios=num_audios,
+                 **kwargs,
+             )
+         )
+     elif isinstance(prompt, dict):
+         # Single dict prompt
+         messages.append(
+             get_message_json(
+                 model_type,
+                 prompt["content"],
+                 prompt["role"],
+                 num_images=num_images,
+                 num_audios=num_audios,
+                 **kwargs,
+             )
+         )
+     elif isinstance(prompt, list):
+         # List of prompts: only the first user turn carries media tokens
+         for i, p in enumerate(prompt):
+             if isinstance(p, str):
+                 is_first = i == 0
+                 messages.append(
+                     get_message_json(
+                         model_type,
+                         p,
+                         skip_image_token=not is_first,
+                         skip_audio_token=not is_first,
+                         num_images=num_images,
+                         num_audios=num_audios,
+                         **kwargs,
+                     )
+                 )
+             elif isinstance(p, dict):
+                 role = p.get("role", "user")
+                 is_first = i == 0 or (i == 1 and role not in ["system", "assistant"])
+                 messages.append(
+                     get_message_json(
+                         model_type,
+                         p["content"],
+                         role,
+                         skip_image_token=not is_first
+                         or role in ["system", "assistant"],
+                         skip_audio_token=not is_first
+                         or role in ["system", "assistant"],
+                         num_images=num_images,
+                         num_audios=num_audios,
+                         **kwargs,
+                     )
+                 )
+ 
+     if return_messages:
+         return messages
+ 
+     # Some models only need the last message
+     if model_type in ["paligemma", "molmo", "florence2"]:
+         return messages[-1]
+ 
+     return get_chat_template(processor, messages, add_generation_prompt)
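
An end-to-end sketch of the list branch; the config dict and model name are placeholders for illustration:

config = {"model_type": "llava"}  # assumed to be a key in MODEL_CONFIG
messages = apply_chat_template(
    processor=None,  # unused when return_messages=True
    config=config,
    prompt=["Describe the image", "Now compare it to the previous one"],
    return_messages=True,
    num_images=1,
)
# Only the first turn receives image tokens; later turns skip them.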