nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580) hide show
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,49 @@
1
+ """
2
+ Model type mappings for HuggingFace pipeline tags to our internal model types.
3
+
4
+ This module provides centralized model type mapping functionality to avoid
5
+ circular imports between other utility modules.
6
+ """
7
+
8
from enum import Enum
from typing import Dict, Optional
10
+
11
+
12
class ModelTypeMapping(Enum):
    """Pairs of (HuggingFace pipeline_tag, internal model type).

    Every member's value is a two-item tuple; the tuple is unpacked by
    ``__init__`` so each member also exposes ``pipeline_tag`` and
    ``model_type`` attributes.
    """

    TEXT_GENERATION = ("text-generation", "llm")
    IMAGE_TEXT_TO_TEXT = ("image-text-to-text", "vlm")
    ANY_TO_ANY = ("any-to-any", "ata")
    AUTOMATIC_SPEECH_RECOGNITION = ("automatic-speech-recognition", "asr")

    def __init__(self, pipeline_tag: str, model_type: str):
        # Enum passes the items of the member's tuple value as arguments.
        self.pipeline_tag, self.model_type = pipeline_tag, model_type
22
+
23
+
24
# Lookup tables derived from ModelTypeMapping, one per direction.
PIPELINE_TO_MODEL_TYPE: Dict[str, str] = dict(
    (member.pipeline_tag, member.model_type) for member in ModelTypeMapping
)

MODEL_TYPE_TO_PIPELINE: Dict[str, str] = dict(
    (member.model_type, member.pipeline_tag) for member in ModelTypeMapping
)
34
+
35
+
36
def map_pipeline_tag_to_model_type(pipeline_tag: str) -> str:
    """Translate a HuggingFace pipeline_tag into the internal model type.

    Returns "other" when the tag is empty/None or is not present in
    PIPELINE_TO_MODEL_TYPE.
    """
    return PIPELINE_TO_MODEL_TYPE.get(pipeline_tag, "other") if pipeline_tag else "other"
42
+
43
+
44
def map_model_type_to_pipeline_tag(model_type: str) -> Optional[str]:
    """Reverse map an internal model type back to its HuggingFace pipeline_tag.

    Args:
        model_type: Internal model type string (a key of MODEL_TYPE_TO_PIPELINE).

    Returns:
        The matching pipeline_tag, or None when ``model_type`` is empty/None
        or has no entry in MODEL_TYPE_TO_PIPELINE.
    """
    if not model_type:
        return None

    return MODEL_TYPE_TO_PIPELINE.get(model_type)
@@ -0,0 +1,389 @@
1
+ """
2
+ Progress tracking utilities for downloads with tqdm integration.
3
+
4
+ This module provides custom progress tracking classes that can monitor
5
+ download progress with callback support and customizable display options.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Optional, Callable, Dict, Any
12
+ from tqdm.auto import tqdm
13
+
14
+
15
class CustomProgressTqdm(tqdm):
    """Custom tqdm that tracks progress but completely hides terminal output.

    The bar is pointed at ``os.devnull`` and its ``display``/``write`` hooks
    are no-ops, while ``disable`` stays False so the counters keep updating.
    """

    def __init__(self, *args, **kwargs):
        # Filter out 'name' argument which might be passed by newer
        # huggingface_hub versions but isn't supported by tqdm.
        kwargs.pop('name', None)

        # Redirect output to devnull to completely suppress terminal output.
        devnull = open(os.devnull, 'w')
        kwargs['file'] = devnull
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # The bar never took ownership of the handle; release it here so
            # a failed construction does not leak a file descriptor.
            devnull.close()
            raise

    def display(self, msg=None, pos=None):
        """Draw nothing; rendering is intentionally suppressed."""
        pass

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Discard the message.

        Declared as a classmethod, like the tqdm method it overrides, so that
        both ``CustomProgressTqdm.write(msg)`` and ``bar.write(msg)`` work.
        """
        pass

    def close(self):
        """Close the devnull handle and mark the bar disabled without printing."""
        # `!=` (not `is not`) is kept on purpose: the stored file object may be
        # a wrapper that compares equal to the stream it wraps.
        if hasattr(self, 'fp') and self.fp and self.fp != sys.stdout and self.fp != sys.stderr:
            try:
                self.fp.close()
            except Exception:
                # Best effort: a failure to close devnull must not break the
                # download, but interpreter-exit signals are not swallowed.
                pass
        self.disable = True
        # Deliberately starts the lookup *after* `tqdm` in the MRO, exactly as
        # the original code did; the bar is already flagged as disabled.
        super(tqdm, self).close()
46
+
47
+
48
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    Between ``start_tracking()`` and ``stop_tracking()`` the ``tqdm`` class
    imported by this module is monkey patched: every bar that gets created
    registers itself with the tracker, its own terminal output is silenced,
    and each ``update`` call refreshes the aggregated figures that are handed
    to ``progress_callback`` (and optionally rendered as one unified bar).
    """

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        """Create an idle tracker.

        Args:
            progress_callback: Called with the dict built by
                ``get_progress_data`` on every progress event.
            show_progress: Whether the unified text bar may be printed. It is
                only printed when a callback is set as well.
        """
        # Per-bar state keyed by str(id(bar)).
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        # Originals of the patched tqdm attributes, captured by start_tracking.
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking (None means "no previous measurement yet")
        self.last_downloaded = None
        self.last_time = None
        self.speed_history = []
        self.max_speed_history = 10

        # Download status: idle, downloading, completed, error
        self.download_status = "idle"
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring and fire the callback."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance,
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh the stored counters of a registered tqdm instance.

        ``n`` is accepted for signature parity with ``tqdm.update``; the
        current position is read from the instance's ``n`` attribute.
        Unregistered instances are ignored.
        """
        tqdm_id = str(id(tqdm_instance))
        entry = self.progress_data.get(tqdm_id)
        if entry is not None:
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Calculate download speed in bytes per second.

        Samples are taken at most once every 100 ms and averaged over the
        last ``max_speed_history`` samples. Returns 0.0 until a first sample
        exists.
        """
        current_time = time.time()

        if self.last_time is None or self.last_downloaded is None:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = current_time
        else:
            time_diff = current_time - self.last_time

            # Only sample on a meaningful interval (avoid dividing by tiny numbers)
            if time_diff > 0.1:
                bytes_diff = current_downloaded - self.last_downloaded

                # A negative delta is not a speed sample; 0 is (no progress).
                if bytes_diff >= 0:
                    self.speed_history.append(bytes_diff / time_diff)
                    if len(self.speed_history) > self.max_speed_history:
                        self.speed_history.pop(0)

                # Re-baseline on every sampled interval, even when the byte
                # count went backwards, so a shrinking total cannot freeze
                # the measurement against a stale baseline.
                self.last_downloaded = current_downloaded
                self.last_time = current_time

        # Average of the recorded samples: keeps showing the last known speed
        # on calls that were too close together to produce a new sample.
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)

        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string (1024-based, one decimal)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string (1024-based, one decimal)."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Build the progress snapshot passed to the callback.

        Note that this also feeds ``calculate_speed``, so calling it updates
        the speed-tracking state.

        Returns:
            A dict with the keys ``status``, ``error_message``, ``progress``,
            ``speed``, ``formatting`` and ``timing``.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        # Only bars with a known, positive total contribute to the aggregate.
        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Total size: prefer pre-fetched repo size, then the sum of the bars'
        # totals, else 0 meaning "not known yet".
        if self.total_repo_size > 0:
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            total_size = total_file_sizes
        else:
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Percentage is capped at 100 and is 0 while the total is unknown.
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # ETA is only defined while something is left and speed is positive.
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0,
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed),
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size),
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time,
            },
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Display a custom unified progress bar on the current terminal line."""
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)

        # Create progress bar
        bar_width = 30
        if known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size - a single marker that moves with wall time
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        # Format the progress line
        status = progress_data.get('status', 'unknown')
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            progress_line = f"Error: {progress_data.get('error_message', 'Unknown error')}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Clear the progress bar display and move to the next line."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Trigger the progress callback if one is set."""
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                # A faulty callback must not abort the download.
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking (monkey patch tqdm). No-op if already tracking."""
        if self.is_tracking:
            return

        # Store original methods
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # The nested functions receive the tqdm instance as their first
        # argument, so the tracker is captured under a different name.
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            devnull = open(os.devnull, 'w')
            kwargs['file'] = devnull
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar

            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # The bar was never registered, so stop_tracking could not
                # close this handle later; release it now.
                devnull.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(cls_tqdm, s, file=None, end="\n", nolock=False):
            # Override write to prevent any output
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        # tqdm.write is a classmethod; the replacement must be one too,
        # otherwise a class-level call such as tqdm.write("msg") would bind
        # the message to the first parameter and fail with a TypeError.
        tqdm.write = classmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking and restore original tqdm. No-op if not tracking."""
        if not self.is_tracking:
            return

        # Restore original tqdm methods
        if self.original_tqdm_update:
            tqdm.update = self.original_tqdm_update
        if self.original_tqdm_init:
            tqdm.__init__ = self.original_tqdm_init
        if self.original_tqdm_display:
            tqdm.display = self.original_tqdm_display
        if self.original_tqdm_write:
            tqdm.write = self.original_tqdm_write

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            if 'tqdm_obj' in data and hasattr(data['tqdm_obj'], 'fp'):
                try:
                    fp = data['tqdm_obj'].fp
                    # `!=` is kept on purpose: the stored file object may be a
                    # wrapper that compares equal to the stream it wraps.
                    if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                        fp.close()
                except Exception:
                    # Best-effort cleanup; never fail the download over it.
                    pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status, trigger the callback and clear the progress bar."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
@@ -0,0 +1,245 @@
1
+ """
2
+ Quantization utilities for extracting quantization types from model files and configurations.
3
+
4
+ This module provides utilities to extract quantization information from:
5
+ - GGUF model filenames
6
+ - MLX model repository IDs
7
+ - MLX model config.json files
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import re
13
+ import logging
14
+ from enum import Enum
15
+ from typing import Optional
16
+
17
+ # Set up logger
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class QuantizationType(str, Enum):
    """Quantization labels recognised for GGUF and MLX models.

    Members are plain strings as well (``str`` mixin), so they compare equal
    to their textual value.
    """

    # --- GGUF labels -----------------------------------------------------
    BF16 = "BF16"
    F16 = "F16"
    Q2_K = "Q2_K"
    Q2_K_L = "Q2_K_L"
    Q3_K = "Q3_K"
    Q3_K_M = "Q3_K_M"
    Q3_K_S = "Q3_K_S"
    Q4_0 = "Q4_0"
    Q4_1 = "Q4_1"
    Q4_K = "Q4_K"
    Q4_K_M = "Q4_K_M"
    Q4_K_S = "Q4_K_S"
    Q5_K = "Q5_K"
    Q5_K_M = "Q5_K_M"
    Q5_K_S = "Q5_K_S"
    Q6_K = "Q6_K"
    Q8_0 = "Q8_0"
    MXFP4 = "MXFP4"
    MXFP8 = "MXFP8"

    # --- MLX labels, named after the bit width -----------------------------
    BIT_1 = "1BIT"
    BIT_2 = "2BIT"
    BIT_3 = "3BIT"
    BIT_4 = "4BIT"
    BIT_5 = "5BIT"
    BIT_6 = "6BIT"
    BIT_7 = "7BIT"
    BIT_8 = "8BIT"
    BIT_16 = "16BIT"
54
+
55
+
56
def extract_quantization_from_filename(filename: str) -> Optional[QuantizationType]:
    """
    Find the quantization label embedded in a filename.

    The lower-cased filename is searched for each known label followed by a
    "." (so "q4_0." matches but "q4_0_xl" does not). Longer labels are tried
    before shorter ones, which keeps e.g. "q2_k_l." from being reported as
    "q2_k.".

    Args:
        filename: The filename to inspect

    Returns:
        The matching QuantizationType, or None when no label is found
    """
    # (suffix, enum) pairs; the trailing "." is part of the match.
    suffix_table = (
        ('bf16.', QuantizationType.BF16),
        ('f16.', QuantizationType.F16),
        ('q2_k_l.', QuantizationType.Q2_K_L),
        ('q2_k.', QuantizationType.Q2_K),
        ('q3_k.', QuantizationType.Q3_K),
        ('q3_k_m.', QuantizationType.Q3_K_M),
        ('q3_k_s.', QuantizationType.Q3_K_S),
        ('q4_k_m.', QuantizationType.Q4_K_M),
        ('q4_k_s.', QuantizationType.Q4_K_S),
        ('q4_0.', QuantizationType.Q4_0),
        ('q4_1.', QuantizationType.Q4_1),
        ('q4_k.', QuantizationType.Q4_K),
        ('q5_k.', QuantizationType.Q5_K),
        ('q5_k_m.', QuantizationType.Q5_K_M),
        ('q5_k_s.', QuantizationType.Q5_K_S),
        ('q6_k.', QuantizationType.Q6_K),
        ('q8_0.', QuantizationType.Q8_0),
        ('mxfp4.', QuantizationType.MXFP4),
        ('mxfp8.', QuantizationType.MXFP8),
    )

    haystack = filename.lower()

    # Longest suffix first; the sort is stable, so suffixes of equal length
    # keep their table order.
    by_length = sorted(suffix_table, key=lambda entry: len(entry[0]), reverse=True)
    return next((quant for suffix, quant in by_length if suffix in haystack), None)
99
+
100
+
101
def extract_quantization_from_repo_id(repo_id: str) -> Optional[QuantizationType]:
    """
    Derive an MLX bit-width quantization type from a repository ID.

    Two spellings are recognised in the lower-cased ID, in this order:
    "<N>bit" (e.g. "4bit") and "-q<N>" / "_q<N>" (e.g. "-q8"). The first
    occurrence whose number is a supported width (1-8 or 16) wins.

    Args:
        repo_id: The repository ID to inspect

    Returns:
        The matching QuantizationType, or None when nothing usable is found
    """
    # Supported widths mapped to their "<N>BIT" enum members.
    supported_widths = {
        width: QuantizationType(f"{width}BIT")
        for width in (1, 2, 3, 4, 5, 6, 7, 8, 16)
    }

    lowered = repo_id.lower()

    # Spelling 1: "4bit", "8bit", ...
    for found in re.finditer(r'(\d+)bit', lowered):
        try:
            bit_number = int(found.group(1))
        except ValueError:
            continue
        quant = supported_widths.get(bit_number)
        if quant is not None:
            logger.debug(f"Found {bit_number}bit quantization in repo_id: {repo_id}")
            return quant

    # Spelling 2: "-q8", "_q4", ...
    for found in re.finditer(r'[-_]q(\d+)', lowered):
        try:
            bit_number = int(found.group(1))
        except ValueError:
            continue
        quant = supported_widths.get(bit_number)
        if quant is not None:
            logger.debug(f"Found Q{bit_number} quantization in repo_id: {repo_id}")
            return quant

    return None
151
+
152
+
153
def extract_quantization_from_mlx_config(mlx_folder_path: str) -> Optional[QuantizationType]:
    """
    Read the quantization bit width from an MLX model's config.json.

    The value is taken from the integer field ``quantization.bits``. A missing
    file, a missing or non-integer field, an unsupported width, or any error
    while reading/parsing the file all yield None (errors are logged as
    warnings).

    Args:
        mlx_folder_path: Path to the MLX model folder

    Returns:
        The matching QuantizationType, or None
    """
    config_path = os.path.join(mlx_folder_path, "config.json")

    if not os.path.exists(config_path):
        logger.debug(f"Config file not found: {config_path}")
        return None

    try:
        with open(config_path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)

        # Only a dict-valued "quantization" section can carry a "bits" field.
        section = parsed.get("quantization", {})
        bits = section.get("bits") if isinstance(section, dict) else None

        if isinstance(bits, int):
            supported_widths = {
                1: QuantizationType.BIT_1,
                2: QuantizationType.BIT_2,
                3: QuantizationType.BIT_3,
                4: QuantizationType.BIT_4,
                5: QuantizationType.BIT_5,
                6: QuantizationType.BIT_6,
                7: QuantizationType.BIT_7,
                8: QuantizationType.BIT_8,
                16: QuantizationType.BIT_16,
            }

            quant = supported_widths.get(bits)
            if quant is not None:
                logger.debug(f"Found {bits}bit quantization in config.json: {config_path}")
                return quant
            logger.debug(f"Unsupported quantization bits value: {bits}")

    except (json.JSONDecodeError, IOError) as e:
        logger.warning(f"Error reading config.json from {config_path}: {e}")
    except Exception as e:
        logger.warning(f"Unexpected error reading config.json from {config_path}: {e}")

    return None
203
+
204
+
205
def extract_gguf_quantization(filename: str) -> str:
    """
    Return the quantization label of a GGUF filename as a plain string.

    Thin string-returning wrapper around extract_quantization_from_filename,
    kept for backward compatibility.

    Args:
        filename: The GGUF filename

    Returns:
        The enum member's value, or "UNKNOWN" when no label is found
    """
    detected = extract_quantization_from_filename(filename)
    return detected.value if detected else "UNKNOWN"
222
+
223
+
224
def detect_quantization_for_mlx(repo_id: str, directory_path: str) -> Optional[QuantizationType]:
    """
    Detect the quantization of an MLX model, trying sources in priority order.

    The repository ID is consulted first; config.json inside
    ``directory_path`` is only read when the ID gives no answer.

    Args:
        repo_id: The repository ID
        directory_path: Path to the model directory

    Returns:
        QuantizationType enum value or None if neither source yields one
    """
    # `or` short-circuits, so the config file is not touched when the
    # repo_id already identifies the quantization.
    return (
        extract_quantization_from_repo_id(repo_id)
        or extract_quantization_from_mlx_config(directory_path)
        or None
    )