nexaai-1.0.16rc13-cp310-cp310-macosx_15_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (557)
  1. nexaai/__init__.py +83 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +64 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +44 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +4 -0
  10. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/libnexa_bridge.dylib +0 -0
  13. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  17. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  20. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/ml.py +888 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  172. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  173. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  174. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  175. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  176. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  177. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  178. nexaai/binds/nexa_nexaml/libnexa-mm-process.dylib +0 -0
  179. nexaai/binds/nexa_nexaml/libnexa-sampling.dylib +0 -0
  180. nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
  181. nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
  182. nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
  183. nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
  184. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  185. nexaai/common.py +104 -0
  186. nexaai/cv.py +92 -0
  187. nexaai/cv_impl/__init__.py +0 -0
  188. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  189. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  190. nexaai/embedder.py +72 -0
  191. nexaai/embedder_impl/__init__.py +0 -0
  192. nexaai/embedder_impl/mlx_embedder_impl.py +116 -0
  193. nexaai/embedder_impl/pybind_embedder_impl.py +95 -0
  194. nexaai/image_gen.py +140 -0
  195. nexaai/image_gen_impl/__init__.py +0 -0
  196. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  197. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  198. nexaai/llm.py +96 -0
  199. nexaai/llm_impl/__init__.py +0 -0
  200. nexaai/llm_impl/mlx_llm_impl.py +269 -0
  201. nexaai/llm_impl/pybind_llm_impl.py +218 -0
  202. nexaai/log.py +92 -0
  203. nexaai/mlx_backend/asr/__init__.py +12 -0
  204. nexaai/mlx_backend/asr/interface.py +122 -0
  205. nexaai/mlx_backend/common/__init__.py +0 -0
  206. nexaai/mlx_backend/common/utils.py +25 -0
  207. nexaai/mlx_backend/cv/__init__.py +0 -0
  208. nexaai/mlx_backend/cv/generate.py +195 -0
  209. nexaai/mlx_backend/cv/interface.py +151 -0
  210. nexaai/mlx_backend/cv/main.py +81 -0
  211. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  212. nexaai/mlx_backend/embedding/__init__.py +0 -0
  213. nexaai/mlx_backend/embedding/generate.py +333 -0
  214. nexaai/mlx_backend/embedding/interface.py +617 -0
  215. nexaai/mlx_backend/embedding/main.py +173 -0
  216. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  217. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  218. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  219. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  220. nexaai/mlx_backend/image_gen/interface.py +82 -0
  221. nexaai/mlx_backend/image_gen/main.py +281 -0
  222. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  223. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  224. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  225. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  226. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  227. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  228. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  229. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  230. nexaai/mlx_backend/llm/__init__.py +0 -0
  231. nexaai/mlx_backend/llm/generate.py +149 -0
  232. nexaai/mlx_backend/llm/interface.py +764 -0
  233. nexaai/mlx_backend/llm/main.py +68 -0
  234. nexaai/mlx_backend/ml.py +888 -0
  235. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  236. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  247. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  248. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  272. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  273. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  274. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  275. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  276. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  277. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  278. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  279. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  280. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  281. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  282. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  283. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  284. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  286. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  288. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  289. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  290. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  291. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  292. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  293. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  294. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  295. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  296. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  297. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  305. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  306. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  307. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  308. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  309. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  310. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  378. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  379. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  380. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  381. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  382. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  383. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  384. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  385. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  386. nexaai/mlx_backend/profiling.py +239 -0
  387. nexaai/mlx_backend/rerank/__init__.py +0 -0
  388. nexaai/mlx_backend/rerank/generate.py +174 -0
  389. nexaai/mlx_backend/rerank/interface.py +287 -0
  390. nexaai/mlx_backend/rerank/main.py +127 -0
  391. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  392. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  393. nexaai/mlx_backend/sd/__init__.py +1 -0
  394. nexaai/mlx_backend/sd/interface.py +362 -0
  395. nexaai/mlx_backend/sd/main.py +286 -0
  396. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  397. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  398. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  399. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  400. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  401. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  402. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  403. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  404. nexaai/mlx_backend/tts/__init__.py +12 -0
  405. nexaai/mlx_backend/tts/interface.py +276 -0
  406. nexaai/mlx_backend/vlm/__init__.py +3 -0
  407. nexaai/mlx_backend/vlm/generate.py +572 -0
  408. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
  409. nexaai/mlx_backend/vlm/interface.py +415 -0
  410. nexaai/mlx_backend/vlm/main.py +316 -0
  411. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  412. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  413. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  414. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  415. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  416. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  417. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  418. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  419. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  420. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  421. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  422. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  423. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  424. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  425. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  426. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  427. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  429. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  430. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  431. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  432. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  433. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  434. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  435. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  436. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  437. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  438. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  439. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  440. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  441. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  443. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  444. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  445. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  447. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  448. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  449. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  450. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  451. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  452. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  453. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  454. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  455. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  457. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  458. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  459. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  460. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  461. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  462. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  463. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  464. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  465. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  466. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  467. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  468. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  469. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  470. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  471. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  475. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  476. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  477. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  478. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  479. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  480. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  481. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  482. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  483. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  484. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  485. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  486. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  487. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  488. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  489. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  490. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  492. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  493. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  494. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  496. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  497. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  498. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  499. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  500. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  501. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  502. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  503. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  504. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  505. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  506. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  507. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  508. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  509. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  510. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  511. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  512. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  513. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  514. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  515. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  522. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  523. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  524. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  525. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  526. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  527. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  528. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  529. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  530. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  531. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  532. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  533. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  534. nexaai/rerank.py +55 -0
  535. nexaai/rerank_impl/__init__.py +0 -0
  536. nexaai/rerank_impl/mlx_rerank_impl.py +92 -0
  537. nexaai/rerank_impl/pybind_rerank_impl.py +43 -0
  538. nexaai/runtime.py +68 -0
  539. nexaai/tts.py +74 -0
  540. nexaai/tts_impl/__init__.py +0 -0
  541. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  542. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  543. nexaai/utils/avatar_fetcher.py +104 -0
  544. nexaai/utils/decode.py +18 -0
  545. nexaai/utils/manifest_utils.py +324 -0
  546. nexaai/utils/model_manager.py +1353 -0
  547. nexaai/utils/model_types.py +47 -0
  548. nexaai/utils/progress_tracker.py +385 -0
  549. nexaai/utils/quantization_utils.py +245 -0
  550. nexaai/vlm.py +128 -0
  551. nexaai/vlm_impl/__init__.py +0 -0
  552. nexaai/vlm_impl/mlx_vlm_impl.py +258 -0
  553. nexaai/vlm_impl/pybind_vlm_impl.py +230 -0
  554. nexaai-1.0.16rc13.dist-info/METADATA +32 -0
  555. nexaai-1.0.16rc13.dist-info/RECORD +557 -0
  556. nexaai-1.0.16rc13.dist-info/WHEEL +5 -0
  557. nexaai-1.0.16rc13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1353 @@
1
+ import os
2
+ import shutil
3
+ import json
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Callable, Dict, Any, List, Union
7
+ import functools
8
+ from enum import Enum
9
+ from tqdm.auto import tqdm
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
12
+
13
+ from .progress_tracker import CustomProgressTqdm, DownloadProgressTracker
14
+ from .avatar_fetcher import get_avatar_url_for_repo
15
+ from .manifest_utils import (
16
+ load_download_metadata,
17
+ save_download_metadata,
18
+ save_manifest_with_files_metadata,
19
+ )
20
+
21
+ # Default path for model storage
22
+ DEFAULT_MODEL_SAVING_PATH = "~/.cache/nexa.ai/nexa_sdk/models/"
23
+
24
+
25
+ @dataclass
26
+ class MMProjInfo:
27
+ """Data class for mmproj file information."""
28
+ mmproj_path: Optional[str] = None
29
+ size: int = 0
30
+
31
+ @dataclass
32
+ class DownloadedModel:
33
+ """Data class representing a downloaded model with all its metadata."""
34
+ repo_id: str
35
+ files: List[str]
36
+ folder_type: str # 'owner_repo' or 'direct_repo'
37
+ local_path: str
38
+ size_bytes: int
39
+ file_count: int
40
+ full_repo_download_complete: bool = True # True if no incomplete downloads detected
41
+ pipeline_tag: Optional[str] = None # Pipeline tag from HuggingFace model info
42
+ download_time: Optional[str] = None # ISO format timestamp of download
43
+ avatar_url: Optional[str] = None # Avatar URL for the model author
44
+ mmproj_info: Optional[MMProjInfo] = None # mmproj file information
45
+
46
+ def to_dict(self) -> Dict[str, Any]:
47
+ """Convert to dictionary format for backward compatibility."""
48
+ result = {
49
+ 'repo_id': self.repo_id,
50
+ 'files': self.files,
51
+ 'folder_type': self.folder_type,
52
+ 'local_path': self.local_path,
53
+ 'size_bytes': self.size_bytes,
54
+ 'file_count': self.file_count,
55
+ 'full_repo_download_complete': self.full_repo_download_complete,
56
+ 'pipeline_tag': self.pipeline_tag,
57
+ 'download_time': self.download_time,
58
+ 'avatar_url': self.avatar_url,
59
+ 'mmproj_info': {
60
+ 'mmproj_path': self.mmproj_info.mmproj_path,
61
+ 'size': self.mmproj_info.size
62
+ } if self.mmproj_info else None
63
+ }
64
+ return result
65
+
66
+
67
+ ##########################################################################
68
+ # List downloaded models #
69
+ ##########################################################################
70
+
71
+
72
+ def _check_for_incomplete_downloads(directory_path: str) -> bool:
73
+ """
74
+ Check if there are incomplete downloads in the model directory.
75
+
76
+ This function checks for the presence of .incomplete or .lock files
77
+ in the .cache/huggingface/download directory within the model folder,
78
+ which indicates that the model download has not completed.
79
+
80
+ Args:
81
+ directory_path: Path to the model directory
82
+
83
+ Returns:
84
+ bool: True if download is complete (no incomplete files found),
85
+ False if incomplete downloads are detected
86
+ """
87
+ # Check for .cache/huggingface/download directory
88
+ cache_dir = os.path.join(directory_path, '.cache', 'huggingface', 'download')
89
+
90
+ # If the cache directory doesn't exist, assume download is complete
91
+ if not os.path.exists(cache_dir):
92
+ return True
93
+
94
+ try:
95
+ # Walk through the cache directory to find incomplete or lock files
96
+ for root, dirs, files in os.walk(cache_dir):
97
+ for filename in files:
98
+ # Check for .incomplete or .lock files
99
+ if filename.endswith('.incomplete'):
100
+ return False # Found incomplete download
101
+
102
+ # No incomplete files found
103
+ return True
104
+ except (OSError, IOError):
105
+ # If we can't access the directory, assume download is complete
106
+ return True
107
+
108
+ def _get_directory_size_and_files(directory_path: str) -> tuple[int, List[str]]:
109
+ """Get total size and list of files in a directory."""
110
+ total_size = 0
111
+ files = []
112
+
113
+ try:
114
+ for root, dirs, filenames in os.walk(directory_path):
115
+ for filename in filenames:
116
+ file_path = os.path.join(root, filename)
117
+ try:
118
+ file_size = os.path.getsize(file_path)
119
+ total_size += file_size
120
+ # Store relative path from the directory
121
+ rel_path = os.path.relpath(file_path, directory_path)
122
+ files.append(rel_path)
123
+ except (OSError, IOError):
124
+ # Skip files that can't be accessed
125
+ continue
126
+ except (OSError, IOError):
127
+ # Skip directories that can't be accessed
128
+ pass
129
+
130
+ return total_size, files
131
+
132
+
133
+ def _has_valid_metadata(directory_path: str) -> bool:
134
+ """Check if directory has either nexa.manifest or download_metadata.json (for backward compatibility)."""
135
+ manifest_path = os.path.join(directory_path, 'nexa.manifest')
136
+ old_metadata_path = os.path.join(directory_path, 'download_metadata.json')
137
+ return os.path.exists(manifest_path) or os.path.exists(old_metadata_path)
138
+
139
+
140
+ def _extract_mmproj_info(manifest: Dict[str, Any], local_path: str) -> Optional[MMProjInfo]:
141
+ """
142
+ Extract mmproj information from manifest data.
143
+
144
+ Args:
145
+ manifest: Dictionary containing manifest data
146
+ local_path: Local path to the model directory
147
+
148
+ Returns:
149
+ MMProjInfo object if mmproj file exists, None otherwise
150
+ """
151
+ # Check if manifest has MMProjFile information
152
+ mmproj_file_info = manifest.get('MMProjFile')
153
+ if not mmproj_file_info or not mmproj_file_info.get('Downloaded') or not mmproj_file_info.get('Name'):
154
+ return None
155
+
156
+ mmproj_filename = mmproj_file_info.get('Name', '')
157
+ if not mmproj_filename:
158
+ return None
159
+
160
+ # Construct full path to mmproj file
161
+ mmproj_path = os.path.join(local_path, mmproj_filename)
162
+
163
+ # Get size from manifest, but verify file exists
164
+ mmproj_size = mmproj_file_info.get('Size', 0)
165
+ if os.path.exists(mmproj_path):
166
+ try:
167
+ # Verify size matches actual file size
168
+ actual_size = os.path.getsize(mmproj_path)
169
+ mmproj_size = actual_size # Use actual size if different
170
+ except (OSError, IOError):
171
+ # If we can't get actual size, use size from manifest
172
+ pass
173
+ else:
174
+ # File doesn't exist, don't include mmproj info
175
+ return None
176
+
177
+ return MMProjInfo(mmproj_path=mmproj_path, size=mmproj_size)
178
+
179
+
180
+ def _scan_for_repo_folders(base_path: str) -> List[DownloadedModel]:
181
+ """Scan a directory for repository folders and return model information."""
182
+ models = []
183
+
184
+ try:
185
+ if not os.path.exists(base_path):
186
+ return models
187
+
188
+ for item in os.listdir(base_path):
189
+ item_path = os.path.join(base_path, item)
190
+
191
+ # Skip non-directory items
192
+ if not os.path.isdir(item_path):
193
+ continue
194
+
195
+ # Check if this might be an owner folder by looking for subdirectories
196
+ has_subdirs = False
197
+ direct_files = []
198
+
199
+ try:
200
+ for subitem in os.listdir(item_path):
201
+ subitem_path = os.path.join(item_path, subitem)
202
+ if os.path.isdir(subitem_path):
203
+ has_subdirs = True
204
+ # This looks like owner/repo structure
205
+ # Only include if nexa.manifest or download_metadata.json exists (backward compatibility)
206
+ if _has_valid_metadata(subitem_path):
207
+ size_bytes, files = _get_directory_size_and_files(subitem_path)
208
+ if files: # Only include if there are files
209
+ # Check if the download is complete
210
+ download_complete = _check_for_incomplete_downloads(subitem_path)
211
+ # Load metadata if it exists
212
+ repo_id = f"{item}/{subitem}"
213
+ metadata = load_download_metadata(subitem_path, repo_id)
214
+
215
+ # Extract mmproj information
216
+ mmproj_info = _extract_mmproj_info(metadata, subitem_path)
217
+
218
+ models.append(DownloadedModel(
219
+ repo_id=repo_id,
220
+ files=files,
221
+ folder_type='owner_repo',
222
+ local_path=subitem_path,
223
+ size_bytes=size_bytes,
224
+ file_count=len(files),
225
+ full_repo_download_complete=download_complete,
226
+ pipeline_tag=metadata.get('pipeline_tag'),
227
+ download_time=metadata.get('download_time'),
228
+ avatar_url=metadata.get('avatar_url'),
229
+ mmproj_info=mmproj_info
230
+ ))
231
+ else:
232
+ direct_files.append(subitem)
233
+ except (OSError, IOError):
234
+ # Skip directories that can't be accessed
235
+ continue
236
+
237
+ # Direct repo folder (no owner structure)
238
+ if not has_subdirs and direct_files:
239
+ # Only include if nexa.manifest or download_metadata.json exists (backward compatibility)
240
+ if _has_valid_metadata(item_path):
241
+ size_bytes, files = _get_directory_size_and_files(item_path)
242
+ if files: # Only include if there are files
243
+ # Check if the download is complete
244
+ download_complete = _check_for_incomplete_downloads(item_path)
245
+ # Load metadata if it exists
246
+ repo_id = item
247
+ metadata = load_download_metadata(item_path, repo_id)
248
+
249
+ # Extract mmproj information
250
+ mmproj_info = _extract_mmproj_info(metadata, item_path)
251
+
252
+ models.append(DownloadedModel(
253
+ repo_id=repo_id,
254
+ files=files,
255
+ folder_type='direct_repo',
256
+ local_path=item_path,
257
+ size_bytes=size_bytes,
258
+ file_count=len(files),
259
+ full_repo_download_complete=download_complete,
260
+ pipeline_tag=metadata.get('pipeline_tag'),
261
+ download_time=metadata.get('download_time'),
262
+ avatar_url=metadata.get('avatar_url'),
263
+ mmproj_info=mmproj_info
264
+ ))
265
+
266
+ except (OSError, IOError):
267
+ # Skip if base path can't be accessed
268
+ pass
269
+
270
+ return models
271
+
272
+
273
+ def list_downloaded_models(local_dir: Optional[str] = None) -> List[DownloadedModel]:
274
+ """
275
+ List all downloaded models in the specified directory.
276
+
277
+ This function scans the local directory for downloaded models and returns
278
+ information about each repository including files, size, and folder structure.
279
+
280
+ It handles different folder naming conventions:
281
+ - Owner/repo structure (e.g., "microsoft/DialoGPT-small")
282
+ - Direct repo folders (repos without owner prefix)
283
+
284
+ Args:
285
+ local_dir (str, optional): Directory to scan for downloaded models.
286
+ If None, uses DEFAULT_MODEL_SAVING_PATH.
287
+
288
+ Returns:
289
+ List[DownloadedModel]: List of DownloadedModel objects with attributes:
290
+ - repo_id: str - Repository ID (e.g., "owner/repo")
291
+ - files: List[str] - List of relative file paths in the repository
292
+ - folder_type: str - 'owner_repo' or 'direct_repo'
293
+ - local_path: str - Full path to the model directory
294
+ - size_bytes: int - Total size of all files in bytes
295
+ - file_count: int - Number of files in the repository
296
+ - full_repo_download_complete: bool - True if no incomplete downloads detected,
297
+ False if .incomplete or .lock files exist
298
+ - pipeline_tag: Optional[str] - Pipeline tag from HuggingFace model info
299
+ - download_time: Optional[str] - ISO format timestamp when the model was downloaded
300
+ - avatar_url: Optional[str] - Avatar URL for the model author
301
+ - mmproj_info: Optional[MMProjInfo] - mmproj file information with mmproj_path and size
302
+ """
303
+
304
+ # Set up local directory
305
+ if local_dir is None:
306
+ local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
307
+
308
+ local_dir = os.path.abspath(local_dir)
309
+
310
+ if not os.path.exists(local_dir):
311
+ return []
312
+
313
+ # Scan for repository folders
314
+ models = _scan_for_repo_folders(local_dir)
315
+
316
+ # Sort by repo_id for consistent output
317
+ models.sort(key=lambda x: x.repo_id)
318
+
319
+ return models
320
+
321
+
322
+ ##########################################################################
323
+ # Remove model functions #
324
+ ##########################################################################
325
+
326
+
327
+ def _parse_model_path(model_path: str) -> tuple[str, str | None]:
328
+ """
329
+ Parse model_path to extract repo_id and optional filename.
330
+
331
+ Examples:
332
+ "microsoft/DialoGPT-small" -> ("microsoft/DialoGPT-small", None)
333
+ "microsoft/DialoGPT-small/pytorch_model.bin" -> ("microsoft/DialoGPT-small", "pytorch_model.bin")
334
+ "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" -> ("Qwen/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf")
335
+
336
+ Args:
337
+ model_path: The model path string
338
+
339
+ Returns:
340
+ Tuple of (repo_id, filename) where filename can be None
341
+ """
342
+ parts = model_path.strip().split('/')
343
+
344
+ if len(parts) < 2:
345
+ # Invalid format, assume it's just a repo name without owner
346
+ return model_path, None
347
+ elif len(parts) == 2:
348
+ # Format: "owner/repo"
349
+ return model_path, None
350
+ else:
351
+ # Format: "owner/repo/file" or "owner/repo/subdir/file"
352
+ repo_id = f"{parts[0]}/{parts[1]}"
353
+ filename = '/'.join(parts[2:])
354
+ return repo_id, filename
355
+
356
+
357
+ def _validate_and_parse_input(model_path: str) -> tuple[str, Optional[str]]:
358
+ """Validate input and parse model path."""
359
+ if not model_path or not isinstance(model_path, str) or not model_path.strip():
360
+ raise ValueError("model_path is required and must be a non-empty string")
361
+
362
+ model_path = model_path.strip()
363
+ return _parse_model_path(model_path)
364
+
365
+
366
+ def _find_target_model(repo_id: str, local_dir: str) -> DownloadedModel:
367
+ """Find and validate the target model exists."""
368
+ downloaded_models = list_downloaded_models(local_dir)
369
+
370
+ for model in downloaded_models:
371
+ if model.repo_id == repo_id:
372
+ return model
373
+
374
+ available_repos = [model.repo_id for model in downloaded_models]
375
+ raise FileNotFoundError(
376
+ f"Repository '{repo_id}' not found in downloaded models. "
377
+ f"Available repositories: {available_repos}"
378
+ )
379
+
380
+
381
+ def _clean_empty_owner_directory(target_model: DownloadedModel) -> None:
382
+ """Remove empty owner directory if applicable."""
383
+ if target_model.folder_type != 'owner_repo':
384
+ return
385
+
386
+ parent_dir = os.path.dirname(target_model.local_path)
387
+ try:
388
+ if os.path.exists(parent_dir) and not os.listdir(parent_dir):
389
+ os.rmdir(parent_dir)
390
+ except OSError:
391
+ pass
392
+
393
+
394
+ def _remove_specific_file(target_model: DownloadedModel, file_name: str, local_dir: str) -> DownloadedModel:
395
+ """Remove a specific file from the repository."""
396
+ # Validate file exists in model
397
+ if file_name not in target_model.files:
398
+ raise FileNotFoundError(
399
+ f"File '{file_name}' not found in repository '{target_model.repo_id}'. "
400
+ f"Available files: {target_model.files[:10]}{'...' if len(target_model.files) > 10 else ''}"
401
+ )
402
+
403
+ # Construct full file path and validate it exists on disk
404
+ file_path = os.path.join(target_model.local_path, file_name)
405
+ if not os.path.exists(file_path):
406
+ raise FileNotFoundError(f"File does not exist on disk: {file_path}")
407
+
408
+ # Get file size before removal
409
+ try:
410
+ file_size = os.path.getsize(file_path)
411
+ except OSError:
412
+ file_size = 0
413
+
414
+ # Remove the file
415
+ try:
416
+ os.remove(file_path)
417
+ except OSError as e:
418
+ raise OSError(f"Failed to remove file '{file_path}': {e}")
419
+
420
+ # Create updated model object
421
+ updated_files = [f for f in target_model.files if f != file_name]
422
+ updated_size = target_model.size_bytes - file_size
423
+ # Re-check download completeness after file removal
424
+ download_complete = _check_for_incomplete_downloads(target_model.local_path)
425
+ updated_model = DownloadedModel(
426
+ repo_id=target_model.repo_id,
427
+ files=updated_files,
428
+ folder_type=target_model.folder_type,
429
+ local_path=target_model.local_path,
430
+ size_bytes=updated_size,
431
+ file_count=len(updated_files),
432
+ full_repo_download_complete=download_complete
433
+ )
434
+
435
+ # If no files left, remove the entire directory
436
+ if len(updated_files) == 0:
437
+ try:
438
+ shutil.rmtree(target_model.local_path)
439
+ _clean_empty_owner_directory(target_model)
440
+ except OSError:
441
+ pass
442
+
443
+ return updated_model
444
+
445
+
446
+ def _remove_entire_repository(target_model: DownloadedModel, local_dir: str) -> DownloadedModel:
447
+ """Remove the entire repository and clean up."""
448
+ # Remove the directory and all its contents
449
+ try:
450
+ shutil.rmtree(target_model.local_path)
451
+ except OSError as e:
452
+ raise OSError(f"Failed to remove directory '{target_model.local_path}': {e}")
453
+
454
+ # Clean up associated resources
455
+ _clean_empty_owner_directory(target_model)
456
+
457
+ return target_model
458
+
459
+
460
+ def remove_model_or_file(
461
+ model_path: str,
462
+ local_dir: Optional[str] = None
463
+ ) -> DownloadedModel:
464
+ """
465
+ Remove a downloaded model or specific file by repository ID or file path.
466
+
467
+ This function supports two modes:
468
+ 1. Remove entire repository: "microsoft/DialoGPT-small"
469
+ 2. Remove specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"
470
+
471
+ For entire repository removal, it removes the directory and all files. For specific file removal, it only
472
+ removes that file and updates the repository metadata.
473
+
474
+ Args:
475
+ model_path (str): Required. Either:
476
+ - Repository ID (e.g., "microsoft/DialoGPT-small") - removes entire repo
477
+ - File path (e.g., "Qwen/Qwen3-4B-GGUF/model.gguf") - removes specific file
478
+ local_dir (str, optional): Directory to search for downloaded models.
479
+ If None, uses DEFAULT_MODEL_SAVING_PATH.
480
+
481
+ Returns:
482
+ DownloadedModel: The model object representing what was removed from disk.
483
+ For file removal, returns updated model info after file removal.
484
+
485
+ Raises:
486
+ ValueError: If model_path is invalid (empty or None)
487
+ FileNotFoundError: If the repository or file is not found in downloaded models
488
+ OSError: If there's an error removing files from disk
489
+ """
490
+ # Validate input and parse path
491
+ repo_id, file_name = _validate_and_parse_input(model_path)
492
+
493
+ # Set up local directory
494
+ if local_dir is None:
495
+ local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
496
+
497
+ local_dir = os.path.abspath(local_dir)
498
+
499
+ if not os.path.exists(local_dir):
500
+ raise FileNotFoundError(f"Local directory does not exist: {local_dir}")
501
+
502
+ # Find the target model
503
+ target_model = _find_target_model(repo_id, local_dir)
504
+
505
+ # Delegate to appropriate removal function
506
+ if file_name:
507
+ return _remove_specific_file(target_model, file_name, local_dir)
508
+ else:
509
+ return _remove_entire_repository(target_model, local_dir)
510
+
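A minimal usage sketch of the two removal modes described in the docstring above. The repository IDs are the placeholder examples from the docstring, and the sketch assumes the functions defined in this file are already in scope (the module's import path is not shown in this diff).

# Mode 2: remove a single file from a downloaded repository; the returned
# DownloadedModel reflects the repository after the file was deleted.
removed = remove_model_or_file("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf")
print(removed.file_count, "files remaining in", removed.local_path)

# Mode 1: remove an entire repository; an empty owner directory is cleaned up afterwards.
try:
    remove_model_or_file("microsoft/DialoGPT-small")
except FileNotFoundError as exc:
    # Raised when the repository is not among the downloaded models.
    print(exc)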
511
+
512
+ ##########################################################################
513
+ # Check model existence functions #
514
+ ##########################################################################
515
+
516
+
517
+ def check_model_existence(
518
+ model_path: str,
519
+ local_dir: Optional[str] = None
520
+ ) -> bool:
521
+ """
522
+ Check if a downloaded model or specific file exists locally.
523
+
524
+ This function supports two modes:
525
+ 1. Check entire repository: "microsoft/DialoGPT-small"
526
+ 2. Check specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"
527
+
528
+ Args:
529
+ model_path (str): Required. Either:
530
+ - Repository ID (e.g., "microsoft/DialoGPT-small") - checks entire repo
531
+ - File path (e.g., "Qwen/Qwen3-4B-GGUF/model.gguf") - checks specific file
532
+ local_dir (str, optional): Directory to search for downloaded models.
533
+ If None, uses DEFAULT_MODEL_SAVING_PATH.
534
+
535
+ Returns:
536
+ bool: True if the requested item exists, False otherwise
537
+
538
+ Raises:
539
+ ValueError: If model_path is invalid (empty or None)
540
+ """
541
+ # Validate input and parse path
542
+ repo_id, file_name = _validate_and_parse_input(model_path)
543
+
544
+ # Set up local directory
545
+ if local_dir is None:
546
+ local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
547
+
548
+ local_dir = os.path.abspath(local_dir)
549
+
550
+ # Return False if local directory doesn't exist
551
+ if not os.path.exists(local_dir):
552
+ return False
553
+
554
+ # Get all downloaded models
555
+ downloaded_models = list_downloaded_models(local_dir)
556
+
557
+ # Find the target model
558
+ for model in downloaded_models:
559
+ if model.repo_id == repo_id:
560
+ # If no specific file requested, repository existence is sufficient
561
+ if file_name is None:
562
+ return True
563
+ else:
564
+ # Check specific file existence
565
+ return file_name in model.files
566
+
567
+ return False
568
+
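A short existence-check sketch using the placeholder paths from the docstring above; per the docstring this is a purely local check against the downloaded-model directory.

# Repository-level check against the default model directory.
if not check_model_existence("microsoft/DialoGPT-small"):
    print("repository has not been downloaded yet")

# File-level check: True only if the repo is downloaded and lists this exact filename.
has_gguf = check_model_existence("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf")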
569
+
570
+ ##########################################################################
571
+ # HuggingFace Downloader Class #
572
+ ##########################################################################
573
+
574
+
575
+ class HuggingFaceDownloader:
576
+ """Class to handle downloads from HuggingFace Hub with unified API usage."""
577
+
578
+ def __init__(
579
+ self,
580
+ endpoint: Optional[str] = None,
581
+ token: Union[bool, str, None] = None,
582
+ enable_transfer: bool = True
583
+ ):
584
+ """
585
+ Initialize the downloader with HuggingFace API.
586
+
587
+ Args:
588
+ endpoint: Custom endpoint URL (e.g., "https://hf-mirror.com").
589
+ If None, uses default HuggingFace Hub.
590
+ token: Authentication token for private repositories.
591
+ enable_transfer: Whether to enable HF transfer for faster downloads.
592
+ """
593
+ # Always create an HfApi instance - either with custom endpoint or default
594
+ self.token = token if isinstance(token, str) else False # False means disable authentication
595
+ self.api = HfApi(endpoint=endpoint, token=self.token) if endpoint else HfApi(token=self.token)
596
+ self.enable_transfer = enable_transfer
597
+ self.original_hf_transfer = None
598
+ self.endpoint = endpoint # Store endpoint for avatar fetching
599
+
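An illustrative construction of the downloader, using the mirror endpoint mentioned in the docstrings; the token string is a placeholder. Note that, per the code above, any non-string token (True/False/None) is coerced to False, i.e. anonymous access.

downloader = HuggingFaceDownloader(
    endpoint="https://hf-mirror.com",   # or None for the default HuggingFace Hub
    token="hf_xxxxxxxx",                # placeholder; a string is used as-is for authentication
    enable_transfer=True,               # toggles HF_HUB_ENABLE_HF_TRANSFER during downloads
)
path = downloader.download("Qwen/Qwen3-4B-GGUF", file_name="Qwen3-4B-Q4_K_M.gguf")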
600
+ def _create_repo_directory(self, local_dir: str, repo_id: str) -> str:
601
+ """Create a directory structure for the repository following HF convention."""
602
+ if '/' in repo_id:
603
+ # Standard format: owner/repo
604
+ owner, repo = repo_id.split('/', 1)
605
+ repo_dir = os.path.join(local_dir, owner, repo)
606
+ else:
607
+ # Direct repo name without owner
608
+ repo_dir = os.path.join(local_dir, repo_id)
609
+
610
+ os.makedirs(repo_dir, exist_ok=True)
611
+ return repo_dir
612
+
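A small sketch of the directory layout this helper produces, with /models as a placeholder base directory.

d = HuggingFaceDownloader()
d._create_repo_directory("/models", "Qwen/Qwen3-4B-GGUF")   # -> /models/Qwen/Qwen3-4B-GGUF
d._create_repo_directory("/models", "standalone-repo")      # -> /models/standalone-repo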
613
+ def _created_dir_if_not_exists(self, local_dir: Optional[str]) -> str:
614
+ """Create directory if it doesn't exist and return the expanded path."""
615
+ if local_dir is None:
616
+ local_dir = DEFAULT_MODEL_SAVING_PATH
617
+
618
+ local_dir = os.path.expanduser(local_dir)
619
+ os.makedirs(local_dir, exist_ok=True)
620
+ return local_dir
621
+
622
+ def _get_repo_info_for_progress(
623
+ self,
624
+ repo_id: str,
625
+ file_name: Optional[Union[str, List[str]]] = None
626
+ ) -> tuple[int, int]:
627
+ """Get total repository size and file count for progress tracking."""
628
+ try:
629
+ info = self.api.model_info(repo_id, files_metadata=True, token=self.token)
630
+
631
+ total_size = 0
632
+ file_count = 0
633
+
634
+ if info.siblings:
635
+ for sibling in info.siblings:
636
+ # Handle different file_name types
637
+ if file_name is not None:
638
+ if isinstance(file_name, str):
639
+ # Single file - only count if it matches
640
+ if sibling.rfilename != file_name:
641
+ continue
642
+ elif isinstance(file_name, list):
643
+ # Multiple files - only count if in the list
644
+ if sibling.rfilename not in file_name:
645
+ continue
646
+
647
+ # For all matching files (or all files if file_name is None)
648
+ if hasattr(sibling, 'size') and sibling.size is not None:
649
+ total_size += sibling.size
650
+ file_count += 1
651
+ else:
652
+ # Count files without size info
653
+ file_count += 1
654
+
655
+ return total_size, file_count if file_count > 0 else 1
656
+ except Exception:
657
+ # If we can't get info, return defaults
658
+ return 0, 1
659
+
660
+ def _validate_and_setup_params(
661
+ self,
662
+ repo_id: str,
663
+ file_name: Optional[Union[str, List[str]]]
664
+ ) -> tuple[str, Optional[Union[str, List[str]]]]:
665
+ """Validate and normalize input parameters."""
666
+ if not repo_id:
667
+ raise ValueError("repo_id is required")
668
+
669
+ repo_id = repo_id.strip()
670
+
671
+ # Handle file_name parameter
672
+ if file_name is not None:
673
+ if isinstance(file_name, str):
674
+ file_name = file_name.strip()
675
+ if not file_name:
676
+ file_name = None
677
+ elif isinstance(file_name, list):
678
+ # Filter out empty strings and strip whitespace
679
+ file_name = [f.strip() for f in file_name if f and f.strip()]
680
+ if not file_name:
681
+ file_name = None
682
+ else:
683
+ raise ValueError("file_name must be a string, list of strings, or None")
684
+
685
+ return repo_id, file_name
686
+
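A quick sketch of the normalization applied here, with made-up inputs: whitespace is stripped and empty file names collapse to None.

d = HuggingFaceDownloader()
d._validate_and_setup_params(" Qwen/Qwen3-4B-GGUF ", " model.gguf ")   # -> ("Qwen/Qwen3-4B-GGUF", "model.gguf")
d._validate_and_setup_params("Qwen/Qwen3-4B-GGUF", ["", "  "])         # -> ("Qwen/Qwen3-4B-GGUF", None)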
687
+ def _setup_progress_tracker(
688
+ self,
689
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]],
690
+ show_progress: bool,
691
+ repo_id: str,
692
+ file_name: Optional[Union[str, List[str]]]
693
+ ) -> Optional[DownloadProgressTracker]:
694
+ """Initialize progress tracker if callback is provided."""
695
+ if not progress_callback:
696
+ return None
697
+
698
+ progress_tracker = DownloadProgressTracker(progress_callback, show_progress)
699
+ # Get repo info for progress tracking (single file, list of files, or entire repo)
700
+ total_size, file_count = self._get_repo_info_for_progress(repo_id, file_name)
701
+ progress_tracker.set_repo_info(total_size, file_count)
702
+ return progress_tracker
703
+
704
+ def _setup_hf_transfer_env(self) -> None:
705
+ """Set up HF transfer environment."""
706
+ self.original_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
707
+ if self.enable_transfer:
708
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
709
+
710
+ def _cleanup_hf_transfer_env(self) -> None:
711
+ """Restore original HF transfer environment."""
712
+ if self.original_hf_transfer is not None:
713
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = self.original_hf_transfer
714
+ else:
715
+ os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
716
+
717
+ def _validate_repository_and_get_info(
718
+ self,
719
+ repo_id: str,
720
+ progress_tracker: Optional[DownloadProgressTracker]
721
+ ):
722
+ """Validate repository exists and get info."""
723
+ try:
724
+ info = self.api.model_info(repo_id, token=self.token)
725
+ return info
726
+ except RepositoryNotFoundError:
727
+ error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
728
+ if progress_tracker:
729
+ progress_tracker.set_error(error_msg)
730
+ raise RepositoryNotFoundError(error_msg)
731
+ except HfHubHTTPError as e:
732
+ if e.response.status_code == 404:
733
+ error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
734
+ if progress_tracker:
735
+ progress_tracker.set_error(error_msg)
736
+ raise RepositoryNotFoundError(error_msg)
737
+ else:
738
+ error_msg = f"HTTP error while accessing repository '{repo_id}': {e}"
739
+ if progress_tracker:
740
+ progress_tracker.set_error(error_msg)
741
+ raise HfHubHTTPError(error_msg)
742
+
743
+ def _validate_file_exists_in_repo(
744
+ self,
745
+ file_name: str,
746
+ info,
747
+ repo_id: str,
748
+ progress_tracker: Optional[DownloadProgressTracker]
749
+ ) -> None:
750
+ """Validate that the file exists in the repository."""
751
+ file_exists = False
752
+ if info.siblings:
753
+ for sibling in info.siblings:
754
+ if sibling.rfilename == file_name:
755
+ file_exists = True
756
+ break
757
+
758
+ if not file_exists:
759
+ available_files = [sibling.rfilename for sibling in info.siblings] if info.siblings else []
760
+ error_msg = (
761
+ f"File '{file_name}' not found in repository '{repo_id}'. "
762
+ f"Available files: {available_files[:10]}{'...' if len(available_files) > 10 else ''}"
763
+ )
764
+ if progress_tracker:
765
+ progress_tracker.set_error(error_msg)
766
+ progress_tracker.stop_tracking()
767
+ raise ValueError(error_msg)
768
+
769
+ def _check_file_exists_and_valid(
770
+ self,
771
+ file_path: str,
772
+ expected_size: Optional[int] = None
773
+ ) -> bool:
774
+ """Check if a file exists and is valid (non-empty, correct size if known)."""
775
+ if not os.path.exists(file_path):
776
+ return False
777
+
778
+ # Check file is not empty
779
+ try:
780
+ file_size = os.path.getsize(file_path)
781
+ if file_size == 0:
782
+ return False
783
+ except (OSError, IOError):
784
+ return False
785
+
786
+ # If we have expected size, check it matches
787
+ if expected_size is not None and file_size != expected_size:
788
+ return False
789
+
790
+ # Size matches the expected size (or none was given) and the file is non-empty
791
+ return True
792
+
793
+ def _fetch_and_save_metadata(self, repo_id: str, local_dir: str, is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None) -> None:
794
+ """Fetch model info and save metadata after successful download."""
795
+ # Initialize metadata with defaults to ensure manifest is always created
796
+ old_metadata = {
797
+ 'pipeline_tag': "text-generation", # Default to text-generation pipeline-tag
798
+ 'download_time': datetime.now().isoformat(),
799
+ 'avatar_url': None
800
+ }
801
+
802
+ # Try to fetch additional metadata, but don't let failures prevent manifest creation
803
+ try:
804
+ # Fetch model info to get pipeline_tag
805
+ info = self.api.model_info(repo_id, token=self.token)
806
+ if hasattr(info, 'pipeline_tag') and info.pipeline_tag:
807
+ old_metadata['pipeline_tag'] = info.pipeline_tag
808
+ except Exception as e:
809
+ # Log the error but continue with manifest creation
810
+ print(f"Warning: Could not fetch model info for {repo_id}: {e}")
811
+
812
+ try:
813
+ # Get avatar URL
814
+ avatar_url = get_avatar_url_for_repo(repo_id, custom_endpoint=self.endpoint)
815
+ if avatar_url:
816
+ old_metadata['avatar_url'] = avatar_url
817
+ except Exception as e:
818
+ # Log the error but continue with manifest creation
819
+ print(f"Warning: Could not fetch avatar URL for {repo_id}: {e}")
820
+
821
+ # CRITICAL: Always create the manifest file, regardless of metadata fetch failures
822
+ try:
823
+ save_manifest_with_files_metadata(repo_id, local_dir, old_metadata, is_mmproj, file_name)
824
+ print(f"[OK] Successfully created nexa.manifest for {repo_id}")
825
+ except Exception as e:
826
+ # This is critical - if manifest creation fails, we should know about it
827
+ print(f"ERROR: Failed to create nexa.manifest for {repo_id}: {e}")
828
+ # Try a fallback approach - create a minimal manifest
829
+ try:
830
+ minimal_manifest = {
831
+ "Name": repo_id,
832
+ "ModelType": "other",
833
+ "PluginId": "unknown",
834
+ "ModelFile": {},
835
+ "MMProjFile": {"Name": "", "Downloaded": False, "Size": 0},
836
+ "TokenizerFile": {"Name": "", "Downloaded": False, "Size": 0},
837
+ "ExtraFiles": None,
838
+ "pipeline_tag": old_metadata.get('pipeline_tag'),
839
+ "download_time": old_metadata.get('download_time'),
840
+ "avatar_url": old_metadata.get('avatar_url')
841
+ }
842
+ save_download_metadata(local_dir, minimal_manifest)
843
+ print(f"[OK] Created minimal nexa.manifest for {repo_id} as fallback")
844
+ except Exception as fallback_error:
845
+ print(f"CRITICAL ERROR: Could not create even minimal manifest for {repo_id}: {fallback_error}")
846
+
847
+ def _download_single_file(
848
+ self,
849
+ repo_id: str,
850
+ file_name: str,
851
+ local_dir: str,
852
+ progress_tracker: Optional[DownloadProgressTracker],
853
+ force_download: bool = False
854
+ ) -> str:
855
+ """Download a single file from the repository using HuggingFace Hub API."""
856
+ # Create repo-specific directory for the single file
857
+ file_local_dir = self._create_repo_directory(local_dir, repo_id)
858
+
859
+ # Check if file already exists
860
+ local_file_path = os.path.join(file_local_dir, file_name)
861
+ if not force_download and self._check_file_exists_and_valid(local_file_path):
862
+ print(f"[SKIP] File already exists: {file_name}")
863
+ # Stop progress tracking
864
+ if progress_tracker:
865
+ progress_tracker.stop_tracking()
866
+ return local_file_path
867
+
868
+ try:
869
+ # Note: hf_hub_download doesn't support tqdm_class parameter
870
+ # Progress tracking works through the global tqdm monkey patching
871
+ downloaded_path = self.api.hf_hub_download(
872
+ repo_id=repo_id,
873
+ filename=file_name,
874
+ local_dir=file_local_dir,
875
+ local_dir_use_symlinks=False,
876
+ token=self.token,
877
+ force_download=force_download
878
+ )
879
+
880
+ # Stop progress tracking
881
+ if progress_tracker:
882
+ progress_tracker.stop_tracking()
883
+
884
+ # Save metadata after successful download
885
+ self._fetch_and_save_metadata(repo_id, file_local_dir, self._current_is_mmproj, self._current_file_name)
886
+
887
+ return downloaded_path
888
+
889
+ except HfHubHTTPError as e:
890
+ error_msg = f"Error downloading file '{file_name}': {e}"
891
+ if progress_tracker:
892
+ progress_tracker.set_error(error_msg)
893
+ progress_tracker.stop_tracking()
894
+ if e.response.status_code == 404:
895
+ raise ValueError(f"File '{file_name}' not found in repository '{repo_id}'")
896
+ else:
897
+ raise HfHubHTTPError(error_msg)
898
+
899
+ def _download_entire_repository(
900
+ self,
901
+ repo_id: str,
902
+ local_dir: str,
903
+ progress_tracker: Optional[DownloadProgressTracker],
904
+ force_download: bool = False
905
+ ) -> str:
906
+ """Download the entire repository."""
907
+ # Create a subdirectory for this specific repo
908
+ repo_local_dir = self._create_repo_directory(local_dir, repo_id)
909
+
910
+ try:
911
+ download_kwargs = {
912
+ 'repo_id': repo_id,
913
+ 'local_dir': repo_local_dir,
914
+ 'local_dir_use_symlinks': False,
915
+ 'token': self.token,
916
+ 'force_download': force_download
917
+ }
918
+
919
+ # Add tqdm_class if progress tracking is enabled
920
+ if progress_tracker:
921
+ download_kwargs['tqdm_class'] = CustomProgressTqdm
922
+
923
+ downloaded_path = self.api.snapshot_download(**download_kwargs)
924
+
925
+ # Stop progress tracking
926
+ if progress_tracker:
927
+ progress_tracker.stop_tracking()
928
+
929
+ # Save metadata after successful download
930
+ self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name)
931
+
932
+ return downloaded_path
933
+
934
+ except HfHubHTTPError as e:
935
+ error_msg = f"Error downloading repository '{repo_id}': {e}"
936
+ if progress_tracker:
937
+ progress_tracker.set_error(error_msg)
938
+ progress_tracker.stop_tracking()
939
+ raise HfHubHTTPError(error_msg)
940
+
941
+ def _download_multiple_files_from_hf(
942
+ self,
943
+ repo_id: str,
944
+ file_names: List[str],
945
+ local_dir: str,
946
+ progress_tracker: Optional[DownloadProgressTracker],
947
+ force_download: bool = False
948
+ ) -> str:
949
+ """Download multiple specific files from HuggingFace Hub."""
950
+ # Create repo-specific directory
951
+ repo_local_dir = self._create_repo_directory(local_dir, repo_id)
952
+
953
+ # Create overall progress bar for multiple files
954
+ overall_progress = tqdm(
955
+ total=len(file_names),
956
+ unit='file',
957
+ desc=f"Downloading {len(file_names)} files from {repo_id}",
958
+ position=0,
959
+ leave=True
960
+ )
961
+
962
+ try:
963
+ for file_name in file_names:
964
+ overall_progress.set_postfix_str(f"Current: {os.path.basename(file_name)}")
965
+
966
+ # Check if file already exists
967
+ local_file_path = os.path.join(repo_local_dir, file_name)
968
+ if not force_download and self._check_file_exists_and_valid(local_file_path):
969
+ print(f"[SKIP] File already exists: {file_name}")
970
+ overall_progress.update(1)
971
+ continue
972
+
973
+ # Download each file using hf_hub_download
974
+ self.api.hf_hub_download(
975
+ repo_id=repo_id,
976
+ filename=file_name,
977
+ local_dir=repo_local_dir,
978
+ local_dir_use_symlinks=False,
979
+ token=self.token,
980
+ force_download=force_download
981
+ )
982
+
983
+ overall_progress.update(1)
984
+
985
+ overall_progress.close()
986
+
987
+ # Stop progress tracking
988
+ if progress_tracker:
989
+ progress_tracker.stop_tracking()
990
+
991
+ # Save metadata after successful download
992
+ self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name)
993
+
994
+ return repo_local_dir
995
+
996
+ except HfHubHTTPError as e:
997
+ overall_progress.close()
998
+ error_msg = f"Error downloading files from '{repo_id}': {e}"
999
+ if progress_tracker:
1000
+ progress_tracker.set_error(error_msg)
1001
+ progress_tracker.stop_tracking()
1002
+ raise HfHubHTTPError(error_msg)
1003
+ except Exception as e:
1004
+ overall_progress.close()
1005
+ if progress_tracker:
1006
+ progress_tracker.set_error(str(e))
1007
+ progress_tracker.stop_tracking()
1008
+ raise
1009
+
1010
+ def download(
1011
+ self,
1012
+ repo_id: str,
1013
+ file_name: Optional[Union[str, List[str]]] = None,
1014
+ local_dir: Optional[str] = None,
1015
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
1016
+ show_progress: bool = True,
1017
+ force_download: bool = False,
1018
+ is_mmproj: bool = False
1019
+ ) -> str:
1020
+ """
1021
+ Main download method that handles all download scenarios.
1022
+
1023
+ Args:
1024
+ repo_id: Repository ID to download from
1025
+ file_name: Optional file name(s) to download
1026
+ local_dir: Local directory to save files
1027
+ progress_callback: Callback for progress updates
1028
+ show_progress: Whether to show progress bar
1029
+ force_download: Force re-download even if files exist
+ is_mmproj: Whether the downloaded file(s) should be recorded as an mmproj file in the saved metadata
1030
+
1031
+ Returns:
1032
+ Path to downloaded file or directory
1033
+ """
1034
+ # Validate and normalize parameters
1035
+ repo_id, file_name = self._validate_and_setup_params(repo_id, file_name)
1036
+
1037
+ # Store parameters as instance variables for use in _fetch_and_save_metadata
1038
+ self._current_is_mmproj = is_mmproj
1039
+ self._current_file_name = file_name
1040
+
1041
+ # Set up local directory
1042
+ local_dir = self._created_dir_if_not_exists(local_dir)
1043
+
1044
+ # Set up progress tracker
1045
+ file_name_for_progress = file_name  # a single name, a list of names, or None are all handled by _get_repo_info_for_progress
1046
+ progress_tracker = self._setup_progress_tracker(
1047
+ progress_callback, show_progress, repo_id, file_name_for_progress
1048
+ )
1049
+
1050
+ # Set up HF transfer environment
1051
+ self._setup_hf_transfer_env()
1052
+
1053
+ try:
1054
+ # Validate repository and get info
1055
+ info = self._validate_repository_and_get_info(repo_id, progress_tracker)
1056
+
1057
+ # Start progress tracking
1058
+ if progress_tracker:
1059
+ progress_tracker.start_tracking()
1060
+
1061
+ # Choose download strategy based on file_name
1062
+ if file_name is None:
1063
+ # Download entire repository
1064
+ return self._download_entire_repository(
1065
+ repo_id, local_dir, progress_tracker, force_download
1066
+ )
1067
+ elif isinstance(file_name, str):
1068
+ # Download specific single file
1069
+ self._validate_file_exists_in_repo(file_name, info, repo_id, progress_tracker)
1070
+ return self._download_single_file(
1071
+ repo_id, file_name, local_dir, progress_tracker, force_download
1072
+ )
1073
+ else: # file_name is a list
1074
+ # Download multiple specific files
1075
+ # Validate all files exist
1076
+ for fname in file_name:
1077
+ self._validate_file_exists_in_repo(fname, info, repo_id, progress_tracker)
1078
+
1079
+ return self._download_multiple_files_from_hf(
1080
+ repo_id, file_name, local_dir, progress_tracker, force_download
1081
+ )
1082
+
1083
+ except Exception as e:
1084
+ # Handle any unexpected errors
1085
+ if progress_tracker and progress_tracker.download_status != "error":
1086
+ progress_tracker.set_error(str(e))
1087
+ progress_tracker.stop_tracking()
1088
+ raise
1089
+
1090
+ finally:
1091
+ # Restore original HF transfer setting
1092
+ self._cleanup_hf_transfer_env()
1093
+
1094
+
1095
+ ##########################################################################
1096
+ # Public Download Function #
1097
+ ##########################################################################
1098
+
1099
+
1100
+ def download_from_huggingface(
1101
+ repo_id: str,
1102
+ file_name: Optional[Union[str, List[str]]] = None,
1103
+ local_dir: Optional[str] = None,
1104
+ enable_transfer: bool = True,
1105
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
1106
+ show_progress: bool = True,
1107
+ token: Union[bool, str, None] = None,
1108
+ custom_endpoint: Optional[str] = None,
1109
+ force_download: bool = False,
1110
+ is_mmproj: Optional[bool] = None
1111
+ ) -> str:
1112
+ """
1113
+ Download models or files from HuggingFace Hub or custom mirror endpoints.
1114
+
1115
+ Args:
1116
+ repo_id (str): Required. The repository ID to download from (e.g., "microsoft/DialoGPT-medium")
1117
+ file_name (Union[str, List[str]], optional): Single filename or list of filenames to download.
1118
+ If None, downloads entire repo.
1119
+ local_dir (str, optional): Local directory to save files. If None, uses DEFAULT_MODEL_SAVING_PATH.
1120
+ enable_transfer (bool, optional): Whether to enable HF transfer for faster downloads. Default True.
1121
+ progress_callback (Callable, optional): Callback function to receive progress updates.
1122
+ Function receives a dict with progress information.
1123
+ show_progress (bool, optional): Whether to show a unified progress bar in the terminal. Default True.
1124
+ Only works when progress_callback is provided.
1125
+ token (Union[bool, str, None], optional): A token to be used for the download.
1126
+ - If True, the token is read from the HuggingFace config folder.
1127
+ - If a string, it's used as the authentication token.
1128
+ - If None, uses default behavior.
1129
+ custom_endpoint (str, optional): A custom HuggingFace-compatible endpoint URL.
1130
+ Should be ONLY the base endpoint without any paths.
1131
+ Examples:
1132
+ - "https://hf-mirror.com"
1133
+ - "https://huggingface.co" (default)
1134
+ The endpoint will be used to initialize HfApi for all downloads.
1135
+ force_download (bool, optional): If True, download files even if they already exist locally.
1136
+ Default False (skip existing files).
1137
+ is_mmproj (bool, optional): Whether the file being downloaded is an mmproj file. Only used when
1138
+ file_name is not None. If None, defaults to True if 'mmproj' is in
1139
+ the filename, False otherwise.
1140
+
1141
+ Returns:
1142
+ str: Path to the downloaded file or directory
1143
+
1144
+ Raises:
1145
+ ValueError: If repo_id is invalid or file_name doesn't exist in the repo
1146
+ RepositoryNotFoundError: If the repository doesn't exist
1147
+ HfHubHTTPError: If there's an HTTP error during download
1148
+
1149
+ Progress Callback Data Format:
1150
+ {
1151
+ 'status': str, # 'idle', 'downloading', 'completed', 'error'
1152
+ 'error_message': str, # Only present if status is 'error'
1153
+ 'progress': {
1154
+ 'total_downloaded': int, # Bytes downloaded
1155
+ 'total_size': int, # Total bytes to download
1156
+ 'percentage': float, # Progress percentage (0-100)
1157
+ 'files_active': int, # Number of files currently downloading
1158
+ 'files_total': int, # Total number of files
1159
+ 'known_total': bool # Whether total size is known
1160
+ },
1161
+ 'speed': {
1162
+ 'bytes_per_second': float, # Download speed in bytes/sec
1163
+ 'formatted': str # Human readable speed (e.g., "1.2 MB/s")
1164
+ },
1165
+ 'formatting': {
1166
+ 'downloaded': str, # Human readable downloaded size
1167
+ 'total_size': str # Human readable total size
1168
+ },
1169
+ 'timing': {
1170
+ 'elapsed_seconds': float, # Time since download started
1171
+ 'eta_seconds': float, # Estimated time remaining
1172
+ 'start_time': float # Download start timestamp
1173
+ }
1174
+ }
1175
+ """
1176
+ # Set default value for is_mmproj based on filename if not explicitly provided
1177
+ if is_mmproj is None and file_name is not None:
1178
+ # Check if any filename contains 'mmproj'
1179
+ filenames_to_check = file_name if isinstance(file_name, list) else [file_name]
1180
+ is_mmproj = any('mmproj' in filename.lower() for filename in filenames_to_check)
1181
+ elif is_mmproj is None:
1182
+ # Default to False if no file_name is provided
1183
+ is_mmproj = False
1184
+
1185
+ # Create downloader instance with custom endpoint if provided
1186
+ downloader = HuggingFaceDownloader(
1187
+ endpoint=custom_endpoint,
1188
+ token=token,
1189
+ enable_transfer=enable_transfer
1190
+ )
1191
+
1192
+ # Use the downloader to perform the download
1193
+ return downloader.download(
1194
+ repo_id=repo_id,
1195
+ file_name=file_name,
1196
+ local_dir=local_dir,
1197
+ progress_callback=progress_callback,
1198
+ show_progress=show_progress,
1199
+ force_download=force_download,
1200
+ is_mmproj=is_mmproj
1201
+ )
1202
+
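A hedged end-to-end sketch of the public download helper, including a callback that consumes the progress dict documented above. Repository and file names are placeholders taken from the docstrings, and the mirror endpoint is the example URL given there; the functions are assumed to be in scope.

def on_progress(update):
    # Keys follow the "Progress Callback Data Format" documented above.
    if update.get("status") == "downloading":
        pct = update["progress"]["percentage"]
        speed = update["speed"]["formatted"]
        print(f"\r{pct:5.1f}%  {speed}", end="")
    elif update.get("status") == "error":
        print("\ndownload failed:", update.get("error_message"))

# Entire repository into DEFAULT_MODEL_SAVING_PATH.
repo_path = download_from_huggingface("microsoft/DialoGPT-small", progress_callback=on_progress)

# A single GGUF file; existing valid files are skipped unless force_download=True.
file_path = download_from_huggingface("Qwen/Qwen3-4B-GGUF", file_name="Qwen3-4B-Q4_K_M.gguf")

# Several specific files through a mirror endpoint.
dir_path = download_from_huggingface(
    "Qwen/Qwen3-4B-GGUF",
    file_name=["Qwen3-4B-Q4_K_M.gguf", "README.md"],
    custom_endpoint="https://hf-mirror.com",
)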
1203
+
1204
+ ##########################################################################
1205
+ # Auto-download decorator #
1206
+ ##########################################################################
1207
+
1208
+
1209
+ def _download_model_if_needed(
1210
+ model_path: str,
1211
+ param_name: str,
1212
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
1213
+ token: Union[bool, str, None] = None,
1214
+ is_mmproj: bool = False
1215
+ ) -> str:
1216
+ """
1217
+ Helper function to download a model from HuggingFace if it doesn't exist locally.
1218
+
1219
+ Args:
1220
+ model_path: The model path that may be local or remote
1221
+ param_name: Name of the parameter (for error messages)
1222
+ progress_callback: Callback function for download progress updates
1223
+ token: HuggingFace authentication token for private repositories
+ is_mmproj: Whether the downloaded file should be recorded as an mmproj file in the saved metadata
1224
+
1225
+ Returns:
1226
+ str: Local path to the model (either existing or downloaded)
1227
+
1228
+ Raises:
1229
+ RuntimeError: If download fails
1230
+ """
1231
+ # Check if model_path exists locally (file or directory)
1232
+ if os.path.exists(model_path):
1233
+ # Local path exists, return as-is
1234
+ return model_path
1235
+
1236
+ # Model path doesn't exist locally, try to download from HuggingFace
1237
+ try:
1238
+ # Parse model_path to extract repo_id and filename
1239
+ repo_id, file_name = _parse_model_path(model_path)
1240
+
1241
+ # Download the model
1242
+ downloaded_path = download_from_huggingface(
1243
+ repo_id=repo_id,
1244
+ file_name=file_name,
1245
+ local_dir=None, # Use default cache directory
1246
+ enable_transfer=True,
1247
+ progress_callback=progress_callback,
1248
+ show_progress=True,
1249
+ token=token,
1250
+ is_mmproj=is_mmproj
1251
+ )
1252
+
1253
+ return downloaded_path
1254
+
1255
+ except Exception as e:
1256
+ # Only handle download-related errors
1257
+ raise RuntimeError(f"Could not load model from '{param_name}={model_path}': {e}")
1258
+
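A small sketch of the resolve-or-download behavior: an existing local path is returned unchanged, anything else is parsed as "<owner>/<repo>[/<file>]" and fetched into the default model directory. The paths are placeholders.

# Existing local file or directory: returned as-is, no network access.
local = _download_model_if_needed("/path/to/local/model.gguf", "name_or_path")

# Hub-style path: split into repo_id and filename, then downloaded on demand.
resolved = _download_model_if_needed("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf", "name_or_path")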
1259
+
1260
+ def auto_download_model(func: Callable) -> Callable:
1261
+ """
1262
+ Decorator that automatically downloads models from HuggingFace if they don't exist locally.
1263
+
1264
+ This decorator should be applied to __init__ methods that take a name_or_path parameter
1265
+ and optionally an mmproj_path parameter. If these paths don't exist as local files/directories,
1266
+ it will attempt to download them from HuggingFace Hub using the download_from_huggingface function.
1267
+
1268
+ The name_or_path and mmproj_path can be in formats like:
1269
+ - "microsoft/DialoGPT-small" (downloads entire repo)
1270
+ - "microsoft/DialoGPT-small/pytorch_model.bin" (downloads specific file)
1271
+ - "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" (downloads specific file)
1272
+
1273
+ Optional kwargs that are extracted and passed to download_from_huggingface:
1274
+ - progress_callback: Callback function for download progress updates
1275
+ - token: HuggingFace authentication token for private repositories
1276
+
1277
+ Args:
1278
+ func: The __init__ method to wrap
1279
+
1280
+ Returns:
1281
+ Wrapped function that handles automatic model downloading
1282
+ """
1283
+ @functools.wraps(func)
1284
+ def wrapper(*args, **kwargs):
1285
+ # Extract progress_callback and token from arguments
1286
+ progress_callback = None
1287
+ if 'progress_callback' in kwargs:
1288
+ progress_callback = kwargs.pop('progress_callback') # Remove from kwargs to avoid passing to original func
1289
+
1290
+ token = None
1291
+ if 'token' in kwargs:
1292
+ token = kwargs.pop('token') # Remove from kwargs to avoid passing to original func
1293
+
1294
+ # Handle name_or_path parameter
1295
+ name_or_path = None
1296
+ name_path_index = None
1297
+ is_name_positional = False
1298
+
1299
+ # Find name_or_path in arguments
1300
+ # Assuming name_or_path is the first argument after self
1301
+ if len(args) >= 2:
1302
+ name_or_path = args[1]
1303
+ args_list = list(args)
1304
+ name_path_index = 1
1305
+ is_name_positional = True
1306
+ elif 'name_or_path' in kwargs:
1307
+ name_or_path = kwargs['name_or_path']
1308
+ is_name_positional = False
1309
+
1310
+ # Handle mmproj_path parameter
1311
+ mmproj_path = None
1312
+ if 'mmproj_path' in kwargs:
1313
+ mmproj_path = kwargs['mmproj_path']
1314
+
1315
+ # If neither parameter is found, call original function
1316
+ if name_or_path is None and mmproj_path is None:
1317
+ return func(*args, **kwargs)
1318
+
1319
+ # Download name_or_path if needed
1320
+ if name_or_path is not None:
1321
+ try:
1322
+ downloaded_name_path = _download_model_if_needed(
1323
+ name_or_path, 'name_or_path', progress_callback, token
1324
+ )
1325
+
1326
+ # Replace name_or_path with downloaded path
1327
+ if is_name_positional:
1328
+ if name_path_index is not None:
1329
+ args_list[name_path_index] = downloaded_name_path
1330
+ args = tuple(args_list)
1331
+ else:
1332
+ kwargs['name_or_path'] = downloaded_name_path
1333
+
1334
+ except Exception as e:
1335
+ raise e # Re-raise the error from _download_model_if_needed
1336
+
1337
+ # Download mmproj_path if needed
1338
+ if mmproj_path is not None:
1339
+ try:
1340
+ downloaded_mmproj_path = _download_model_if_needed(
1341
+ mmproj_path, 'mmproj_path', progress_callback, token, is_mmproj=True
1342
+ )
1343
+
1344
+ # Replace mmproj_path with downloaded path
1345
+ kwargs['mmproj_path'] = downloaded_mmproj_path
1346
+
1347
+ except Exception as e:
1348
+ raise e # Re-raise the error from _download_model_if_needed
1349
+
1350
+ # Call original function with updated paths (outside the try/except so model-creation errors bubble up)
1351
+ return func(*args, **kwargs)
1352
+
1353
+ return wrapper
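A hedged sketch of the decorator in use on a hypothetical class (not part of this package), showing the call pattern it is written for: name_or_path as the first argument after self, with an optional mmproj_path keyword.

class ExampleModel:  # hypothetical class, for illustration only
    @auto_download_model
    def __init__(self, name_or_path, mmproj_path=None, **kwargs):
        # By the time this body runs, both paths refer to local files or directories.
        self.model_path = name_or_path
        self.mmproj_path = mmproj_path

# A non-local path is parsed as "<owner>/<repo>/<file>" and downloaded first;
# progress_callback= and token= kwargs would be consumed by the decorator, not forwarded here.
model = ExampleModel("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf")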