nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1745 @@
1
+ import os
2
+ import shutil
3
+ import json
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Callable, Dict, Any, List, Union
7
+ import functools
8
+ from enum import Enum
9
+ from tqdm.auto import tqdm
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
12
+
13
+ from .progress_tracker import CustomProgressTqdm, DownloadProgressTracker
14
+ from .manifest_utils import (
15
+ load_download_metadata,
16
+ save_download_metadata,
17
+ save_manifest_with_files_metadata,
18
+ )
19
+
20
# Default path for model storage; the leading "~" is expanded with
# os.path.expanduser at the call sites that consume this constant.
DEFAULT_MODEL_SAVING_PATH = "~/.cache/nexa.ai/nexa_sdk/models/"
22
+
23
+
24
@dataclass
class MMProjInfo:
    """Location and size of a repository's multimodal-projector (mmproj) file."""
    # Absolute path to the mmproj file on disk; None when not present.
    mmproj_path: Optional[str] = None
    # File size in bytes (0 when unknown).
    size: int = 0


@dataclass
class DownloadedModel:
    """A locally downloaded model repository together with its metadata."""
    repo_id: str       # e.g. "owner/repo", or a bare repo name
    files: List[str]   # relative paths of every file inside the repo folder
    folder_type: str   # 'owner_repo' or 'direct_repo'
    local_path: str    # absolute path of the repository folder
    size_bytes: int    # total size of all files, in bytes
    file_count: int    # number of files in the repository
    full_repo_download_complete: bool = True   # False when incomplete downloads were detected
    pipeline_tag: Optional[str] = None         # pipeline tag from HuggingFace model info
    download_time: Optional[str] = None        # ISO-format timestamp of the download
    avatar_url: Optional[str] = None           # avatar URL for the model author
    mmproj_info: Optional[MMProjInfo] = None   # mmproj file information, if any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (shape kept for backward compatibility)."""
        mmproj = None
        if self.mmproj_info is not None:
            mmproj = {
                'mmproj_path': self.mmproj_info.mmproj_path,
                'size': self.mmproj_info.size,
            }
        return {
            'repo_id': self.repo_id,
            'files': self.files,
            'folder_type': self.folder_type,
            'local_path': self.local_path,
            'size_bytes': self.size_bytes,
            'file_count': self.file_count,
            'full_repo_download_complete': self.full_repo_download_complete,
            'pipeline_tag': self.pipeline_tag,
            'download_time': self.download_time,
            'avatar_url': self.avatar_url,
            'mmproj_info': mmproj,
        }
64
+
65
+
66
+ ##########################################################################
67
+ # List downloaded models #
68
+ ##########################################################################
69
+
70
+
71
+ def _check_for_incomplete_downloads(directory_path: str) -> bool:
72
+ """
73
+ Check if there are incomplete downloads in the model directory.
74
+
75
+ This function checks for the presence of .incomplete or .lock files
76
+ in the .cache/huggingface/download directory within the model folder,
77
+ which indicates that the model download has not completed.
78
+
79
+ Args:
80
+ directory_path: Path to the model directory
81
+
82
+ Returns:
83
+ bool: True if download is complete (no incomplete files found),
84
+ False if incomplete downloads are detected
85
+ """
86
+ # Check for .cache/huggingface/download directory
87
+ cache_dir = os.path.join(directory_path, '.cache', 'huggingface', 'download')
88
+
89
+ # If the cache directory doesn't exist, assume download is complete
90
+ if not os.path.exists(cache_dir):
91
+ return True
92
+
93
+ try:
94
+ # Walk through the cache directory to find incomplete or lock files
95
+ for root, dirs, files in os.walk(cache_dir):
96
+ for filename in files:
97
+ # Check for .incomplete or .lock files
98
+ if filename.endswith('.incomplete'):
99
+ return False # Found incomplete download
100
+
101
+ # No incomplete files found
102
+ return True
103
+ except (OSError, IOError):
104
+ # If we can't access the directory, assume download is complete
105
+ return True
106
+
107
+ def _get_directory_size_and_files(directory_path: str) -> tuple[int, List[str]]:
108
+ """Get total size and list of files in a directory."""
109
+ total_size = 0
110
+ files = []
111
+
112
+ try:
113
+ for root, dirs, filenames in os.walk(directory_path):
114
+ for filename in filenames:
115
+ file_path = os.path.join(root, filename)
116
+ try:
117
+ file_size = os.path.getsize(file_path)
118
+ total_size += file_size
119
+ # Store relative path from the directory
120
+ rel_path = os.path.relpath(file_path, directory_path)
121
+ files.append(rel_path)
122
+ except (OSError, IOError):
123
+ # Skip files that can't be accessed
124
+ continue
125
+ except (OSError, IOError):
126
+ # Skip directories that can't be accessed
127
+ pass
128
+
129
+ return total_size, files
130
+
131
+
132
+ def _has_valid_metadata(directory_path: str) -> bool:
133
+ """Check if directory has either nexa.manifest or download_metadata.json (for backward compatibility)."""
134
+ manifest_path = os.path.join(directory_path, 'nexa.manifest')
135
+ old_metadata_path = os.path.join(directory_path, 'download_metadata.json')
136
+ return os.path.exists(manifest_path) or os.path.exists(old_metadata_path)
137
+
138
+
139
def _extract_mmproj_info(manifest: Dict[str, Any], local_path: str) -> Optional[MMProjInfo]:
    """Build an MMProjInfo record from manifest data, if an mmproj file is present.

    Args:
        manifest: Parsed manifest dictionary.
        local_path: Model directory the mmproj file should live in.

    Returns:
        MMProjInfo when the manifest declares a downloaded mmproj file that
        actually exists on disk; None otherwise.
    """
    entry = manifest.get('MMProjFile')
    if not entry or not entry.get('Downloaded') or not entry.get('Name'):
        return None

    filename = entry.get('Name', '')
    if not filename:
        return None

    full_path = os.path.join(local_path, filename)
    if not os.path.exists(full_path):
        # Declared in the manifest but missing on disk: report nothing.
        return None

    size = entry.get('Size', 0)
    try:
        # Prefer the real on-disk size over the (possibly stale) manifest value.
        size = os.path.getsize(full_path)
    except (OSError, IOError):
        pass  # keep the manifest-reported size when the file can't be stat'ed

    return MMProjInfo(mmproj_path=full_path, size=size)
177
+
178
+
179
def _scan_for_repo_folders(base_path: str) -> List[DownloadedModel]:
    """Scan *base_path* for repository folders and return model information.

    Two layouts are recognized:
      - owner/repo: ``base_path/<owner>/<repo>`` (each repo is a subdirectory)
      - direct repo: ``base_path/<repo>`` (files live directly in the folder)

    A folder is only reported when it carries valid metadata
    (``nexa.manifest`` or the legacy ``download_metadata.json``) and
    contains at least one file.
    """
    models = []

    try:
        if not os.path.exists(base_path):
            return models

        for item in os.listdir(base_path):
            item_path = os.path.join(base_path, item)

            # Skip non-directory items
            if not os.path.isdir(item_path):
                continue

            # Check if this might be an owner folder by looking for subdirectories
            has_subdirs = False
            direct_files = []

            try:
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        has_subdirs = True
                        # This looks like owner/repo structure.
                        # Only include if nexa.manifest or download_metadata.json
                        # exists (the latter for backward compatibility).
                        if _has_valid_metadata(subitem_path):
                            size_bytes, files = _get_directory_size_and_files(subitem_path)
                            if files:  # Only include if there are files
                                # Check if the download is complete
                                download_complete = _check_for_incomplete_downloads(subitem_path)
                                # Load metadata if it exists
                                repo_id = f"{item}/{subitem}"
                                metadata = load_download_metadata(subitem_path, repo_id)

                                # Extract mmproj information
                                mmproj_info = _extract_mmproj_info(metadata, subitem_path)

                                models.append(DownloadedModel(
                                    repo_id=repo_id,
                                    files=files,
                                    folder_type='owner_repo',
                                    local_path=subitem_path,
                                    size_bytes=size_bytes,
                                    file_count=len(files),
                                    full_repo_download_complete=download_complete,
                                    pipeline_tag=metadata.get('pipeline_tag'),
                                    download_time=metadata.get('download_time'),
                                    avatar_url=metadata.get('avatar_url'),
                                    mmproj_info=mmproj_info
                                ))
                    else:
                        direct_files.append(subitem)
            except (OSError, IOError):
                # Skip directories that can't be accessed.
                # NOTE(review): this `continue` also skips the direct-repo check
                # below for this item, moving straight to the next top-level
                # entry — confirm that is intended for partially readable dirs.
                continue

            # Direct repo folder (no owner structure)
            if not has_subdirs and direct_files:
                # Only include if nexa.manifest or download_metadata.json exists
                # (backward compatibility)
                if _has_valid_metadata(item_path):
                    size_bytes, files = _get_directory_size_and_files(item_path)
                    if files:  # Only include if there are files
                        # Check if the download is complete
                        download_complete = _check_for_incomplete_downloads(item_path)
                        # Load metadata if it exists
                        repo_id = item
                        metadata = load_download_metadata(item_path, repo_id)

                        # Extract mmproj information
                        mmproj_info = _extract_mmproj_info(metadata, item_path)

                        models.append(DownloadedModel(
                            repo_id=repo_id,
                            files=files,
                            folder_type='direct_repo',
                            local_path=item_path,
                            size_bytes=size_bytes,
                            file_count=len(files),
                            full_repo_download_complete=download_complete,
                            pipeline_tag=metadata.get('pipeline_tag'),
                            download_time=metadata.get('download_time'),
                            avatar_url=metadata.get('avatar_url'),
                            mmproj_info=mmproj_info
                        ))

    except (OSError, IOError):
        # Skip if base path can't be accessed
        pass

    return models
270
+
271
+
272
def list_downloaded_models(local_dir: Optional[str] = None) -> List[DownloadedModel]:
    """List every downloaded model found under *local_dir*.

    Both folder conventions are handled:
      - owner/repo structure (e.g. "microsoft/DialoGPT-small")
      - direct repo folders (repos without an owner prefix)

    Args:
        local_dir: Directory to scan for downloaded models. Defaults to
            DEFAULT_MODEL_SAVING_PATH when None.

    Returns:
        List[DownloadedModel]: One record per repository — repo_id, files,
        folder_type, local_path, size_bytes, file_count,
        full_repo_download_complete, pipeline_tag, download_time,
        avatar_url, mmproj_info — sorted by repo_id. Empty when the
        directory does not exist.
    """
    scan_root = (os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
                 if local_dir is None else local_dir)
    scan_root = os.path.abspath(scan_root)

    if not os.path.exists(scan_root):
        return []

    # Sorted for deterministic output across calls.
    return sorted(_scan_for_repo_folders(scan_root), key=lambda m: m.repo_id)
319
+
320
+
321
+ ##########################################################################
322
+ # Remove model functions #
323
+ ##########################################################################
324
+
325
+
326
+ def _parse_model_path(model_path: str) -> tuple[str, str | None]:
327
+ """
328
+ Parse model_path to extract repo_id and optional filename.
329
+
330
+ Examples:
331
+ "microsoft/DialoGPT-small" -> ("microsoft/DialoGPT-small", None)
332
+ "microsoft/DialoGPT-small/pytorch_model.bin" -> ("microsoft/DialoGPT-small", "pytorch_model.bin")
333
+ "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" -> ("Qwen/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf")
334
+
335
+ Args:
336
+ model_path: The model path string
337
+
338
+ Returns:
339
+ Tuple of (repo_id, filename) where filename can be None
340
+ """
341
+ parts = model_path.strip().split('/')
342
+
343
+ if len(parts) < 2:
344
+ # Invalid format, assume it's just a repo name without owner
345
+ return model_path, None
346
+ elif len(parts) == 2:
347
+ # Format: "owner/repo"
348
+ return model_path, None
349
+ else:
350
+ # Format: "owner/repo/file" or "owner/repo/subdir/file"
351
+ repo_id = f"{parts[0]}/{parts[1]}"
352
+ filename = '/'.join(parts[2:])
353
+ return repo_id, filename
354
+
355
+
356
+ def _validate_and_parse_input(model_path: str) -> tuple[str, Optional[str]]:
357
+ """Validate input and parse model path."""
358
+ if not model_path or not isinstance(model_path, str) or not model_path.strip():
359
+ raise ValueError("model_path is required and must be a non-empty string")
360
+
361
+ model_path = model_path.strip()
362
+ return _parse_model_path(model_path)
363
+
364
+
365
def _find_target_model(repo_id: str, local_dir: str) -> DownloadedModel:
    """Locate *repo_id* among downloaded models or raise FileNotFoundError."""
    known_models = list_downloaded_models(local_dir)

    match = next((m for m in known_models if m.repo_id == repo_id), None)
    if match is not None:
        return match

    available_repos = [m.repo_id for m in known_models]
    raise FileNotFoundError(
        f"Repository '{repo_id}' not found in downloaded models. "
        f"Available repositories: {available_repos}"
    )
378
+
379
+
380
def _clean_empty_owner_directory(target_model: DownloadedModel) -> None:
    """Best-effort removal of a now-empty owner directory after a repo delete."""
    # Only the owner/repo layout has a parent owner folder to tidy up.
    if target_model.folder_type != 'owner_repo':
        return

    owner_dir = os.path.dirname(target_model.local_path)
    try:
        if os.path.exists(owner_dir) and not os.listdir(owner_dir):
            os.rmdir(owner_dir)
    except OSError:
        # Leave the directory alone if it cannot be inspected or removed.
        pass
391
+
392
+
393
def _remove_specific_file(target_model: DownloadedModel, file_name: str, local_dir: str) -> DownloadedModel:
    """Remove a single file from a downloaded repository.

    Special case: when *file_name* is a ``.gguf`` weight file and no other
    non-mmproj ``.gguf`` file would remain, the whole repository is removed
    instead (an mmproj file alone is useless without model weights).

    Args:
        target_model: Model record the file belongs to.
        file_name: Relative path of the file inside the repository.
        local_dir: Root models directory (kept for symmetry with
            _remove_entire_repository).

    Returns:
        DownloadedModel: Updated record reflecting the removal.

    Raises:
        FileNotFoundError: If the file is not tracked by the model or is
            missing on disk.
        OSError: If the file cannot be deleted.
    """
    # Validate the file is tracked by the model's metadata
    if file_name not in target_model.files:
        raise FileNotFoundError(
            f"File '{file_name}' not found in repository '{target_model.repo_id}'. "
            f"Available files: {target_model.files[:10]}{'...' if len(target_model.files) > 10 else ''}"
        )

    # Construct full file path and validate it exists on disk
    file_path = os.path.join(target_model.local_path, file_name)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File does not exist on disk: {file_path}")

    # Get file size before removal (best effort; 0 when stat fails)
    try:
        file_size = os.path.getsize(file_path)
    except OSError:
        file_size = 0

    updated_files = [f for f in target_model.files if f != file_name]

    # If removing this .gguf leaves no non-mmproj .gguf weights behind,
    # drop the entire repository rather than keep an unusable remainder.
    if file_name.endswith('.gguf'):
        remaining_non_mmproj_gguf = [
            f for f in updated_files
            if f.endswith('.gguf') and 'mmproj' not in f.lower()
        ]
        if not remaining_non_mmproj_gguf:
            return _remove_entire_repository(target_model, local_dir)

    # Remove the file
    try:
        os.remove(file_path)
    except OSError as e:
        raise OSError(f"Failed to remove file '{file_path}': {e}")

    # Drop the mmproj record if the removed file was the mmproj file itself.
    mmproj_info = target_model.mmproj_info
    if mmproj_info is not None and mmproj_info.mmproj_path == file_path:
        mmproj_info = None

    # Re-check download completeness after file removal
    download_complete = _check_for_incomplete_downloads(target_model.local_path)

    # Build the updated record, preserving repository-level metadata
    # (pipeline_tag/download_time/avatar_url/mmproj_info were previously
    # dropped here, losing information on every file removal).
    updated_model = DownloadedModel(
        repo_id=target_model.repo_id,
        files=updated_files,
        folder_type=target_model.folder_type,
        local_path=target_model.local_path,
        size_bytes=target_model.size_bytes - file_size,
        file_count=len(updated_files),
        full_repo_download_complete=download_complete,
        pipeline_tag=target_model.pipeline_tag,
        download_time=target_model.download_time,
        avatar_url=target_model.avatar_url,
        mmproj_info=mmproj_info,
    )

    # If no files are left, remove the now-empty directory (best effort).
    if not updated_files:
        try:
            shutil.rmtree(target_model.local_path)
            _clean_empty_owner_directory(target_model)
        except OSError:
            pass

    return updated_model
457
+
458
+
459
def _remove_entire_repository(target_model: DownloadedModel, local_dir: str) -> DownloadedModel:
    """Delete a repository folder from disk and tidy up its owner directory.

    Raises:
        OSError: If the directory tree cannot be removed.
    """
    try:
        shutil.rmtree(target_model.local_path)
    except OSError as e:
        raise OSError(f"Failed to remove directory '{target_model.local_path}': {e}")

    # Drop the owner folder too when this was its last repository.
    _clean_empty_owner_directory(target_model)

    return target_model
471
+
472
+
473
def remove_model_or_file(
    model_path: str,
    local_dir: Optional[str] = None
) -> DownloadedModel:
    """Remove a downloaded repository, or one file inside it.

    Two forms of *model_path* are accepted:
      1. "owner/repo"          -> the whole repository is removed
      2. "owner/repo/file.ext" -> only that file is removed and the
         repository metadata is updated

    Args:
        model_path: Repository ID or repository-relative file path.
        local_dir: Directory holding downloaded models; defaults to
            DEFAULT_MODEL_SAVING_PATH when None.

    Returns:
        DownloadedModel: Record of what was removed (for a single-file
        removal, the updated record after the removal).

    Raises:
        ValueError: If model_path is empty or None.
        FileNotFoundError: If the repository or file is not present locally.
        OSError: If deletion from disk fails.
    """
    repo_id, file_name = _validate_and_parse_input(model_path)

    search_dir = (os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
                  if local_dir is None else local_dir)
    search_dir = os.path.abspath(search_dir)

    if not os.path.exists(search_dir):
        raise FileNotFoundError(f"Local directory does not exist: {search_dir}")

    target_model = _find_target_model(repo_id, search_dir)

    # A trailing file path selects single-file removal; a bare repo id
    # removes the whole repository.
    if file_name:
        return _remove_specific_file(target_model, file_name, search_dir)
    return _remove_entire_repository(target_model, search_dir)
523
+
524
+
525
+ ##########################################################################
526
+ # Check model existence functions #
527
+ ##########################################################################
528
+
529
+
530
def check_model_existence(
    model_path: str,
    local_dir: Optional[str] = None
) -> bool:
    """Check whether a repository — or one file inside it — exists locally.

    Args:
        model_path: "owner/repo" to check a whole repository, or
            "owner/repo/file" to check a specific file within it.
        local_dir: Directory holding downloaded models; defaults to
            DEFAULT_MODEL_SAVING_PATH when None.

    Returns:
        bool: True when the requested repository/file is present.

    Raises:
        ValueError: If model_path is empty or None.
    """
    repo_id, file_name = _validate_and_parse_input(model_path)

    search_dir = (os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
                  if local_dir is None else local_dir)
    search_dir = os.path.abspath(search_dir)

    if not os.path.exists(search_dir):
        return False

    for model in list_downloaded_models(search_dir):
        if model.repo_id != repo_id:
            continue
        # Repository found: with no file requested, that alone is enough;
        # otherwise the specific file must be tracked.
        return True if file_name is None else file_name in model.files

    return False
581
+
582
+
583
+ ##########################################################################
584
+ # HuggingFace Downloader Class #
585
+ ##########################################################################
586
+
587
+
588
+ class HuggingFaceDownloader:
589
+ """Class to handle downloads from HuggingFace Hub with unified API usage."""
590
+
591
+ def __init__(
592
+ self,
593
+ endpoint: Optional[str] = None,
594
+ token: Union[bool, str, None] = None,
595
+ enable_transfer: bool = True
596
+ ):
597
+ """
598
+ Initialize the downloader with HuggingFace API.
599
+
600
+ Args:
601
+ endpoint: Custom endpoint URL (e.g., "https://hf-mirror.com").
602
+ If None, uses default HuggingFace Hub.
603
+ token: Authentication token for private repositories.
604
+ enable_transfer: Whether to enable HF transfer for faster downloads.
605
+ """
606
+ # Always create an HfApi instance - either with custom endpoint or default
607
+ self.token = token if isinstance(token, str) else False # False means disable authentication
608
+ self.api = HfApi(endpoint=endpoint, token=self.token) if endpoint else HfApi(token=self.token)
609
+ self.enable_transfer = enable_transfer
610
+ self.original_hf_transfer = None
611
+ self.endpoint = endpoint # Store endpoint for avatar fetching
612
+ self._model_info_cache: Dict[str, Any] = {} # Cache for model_info results
613
+
614
+ def _create_repo_directory(self, local_dir: str, repo_id: str) -> str:
615
+ """Create a directory structure for the repository following HF convention."""
616
+ if '/' in repo_id:
617
+ # Standard format: owner/repo
618
+ owner, repo = repo_id.split('/', 1)
619
+ repo_dir = os.path.join(local_dir, owner, repo)
620
+ else:
621
+ # Direct repo name without owner
622
+ repo_dir = os.path.join(local_dir, repo_id)
623
+
624
+ os.makedirs(repo_dir, exist_ok=True)
625
+ return repo_dir
626
+
627
+ def _created_dir_if_not_exists(self, local_dir: Optional[str]) -> str:
628
+ """Create directory if it doesn't exist and return the expanded path."""
629
+ if local_dir is None:
630
+ local_dir = DEFAULT_MODEL_SAVING_PATH
631
+
632
+ local_dir = os.path.expanduser(local_dir)
633
+ os.makedirs(local_dir, exist_ok=True)
634
+ return local_dir
635
+
636
+ def _get_model_info_cached(self, repo_id: str, files_metadata: bool = False):
637
+ """Get model info with caching to avoid rate limiting.
638
+
639
+ Args:
640
+ repo_id: Repository ID
641
+ files_metadata: Whether to include files metadata
642
+
643
+ Returns:
644
+ Model info object from HuggingFace API
645
+ """
646
+ # Create cache key based on repo_id and files_metadata flag
647
+ cache_key = f"{repo_id}:files={files_metadata}"
648
+
649
+ # Return cached result if available
650
+ if cache_key in self._model_info_cache:
651
+ return self._model_info_cache[cache_key]
652
+
653
+ # Fetch from API and cache the result
654
+ try:
655
+ info = self.api.model_info(repo_id, files_metadata=files_metadata, token=self.token)
656
+ self._model_info_cache[cache_key] = info
657
+ return info
658
+ except Exception:
659
+ # Don't cache errors, re-raise
660
+ raise
661
+
662
+ def _get_repo_info_for_progress(
663
+ self,
664
+ repo_id: str,
665
+ file_name: Optional[Union[str, List[str]]] = None
666
+ ) -> tuple[int, int]:
667
+ """Get total repository size and file count for progress tracking."""
668
+ try:
669
+ info = self._get_model_info_cached(repo_id, files_metadata=True)
670
+
671
+ total_size = 0
672
+ file_count = 0
673
+
674
+ if info.siblings:
675
+ for sibling in info.siblings:
676
+ # Handle different file_name types
677
+ if file_name is not None:
678
+ if isinstance(file_name, str):
679
+ # Single file - only count if it matches
680
+ if sibling.rfilename != file_name:
681
+ continue
682
+ elif isinstance(file_name, list):
683
+ # Multiple files - only count if in the list
684
+ if sibling.rfilename not in file_name:
685
+ continue
686
+
687
+ # For all matching files (or all files if file_name is None)
688
+ if hasattr(sibling, 'size') and sibling.size is not None:
689
+ total_size += sibling.size
690
+ file_count += 1
691
+ else:
692
+ # Count files without size info
693
+ file_count += 1
694
+
695
+ return total_size, file_count if file_count > 0 else 1
696
+ except Exception:
697
+ # If we can't get info, return defaults
698
+ return 0, 1
699
+
700
+ def _validate_and_setup_params(
701
+ self,
702
+ repo_id: str,
703
+ file_name: Optional[Union[str, List[str]]]
704
+ ) -> tuple[str, Optional[Union[str, List[str]]]]:
705
+ """Validate and normalize input parameters."""
706
+ if not repo_id:
707
+ raise ValueError("repo_id is required")
708
+
709
+ repo_id = repo_id.strip()
710
+
711
+ # Handle file_name parameter
712
+ if file_name is not None:
713
+ if isinstance(file_name, str):
714
+ file_name = file_name.strip()
715
+ if not file_name:
716
+ file_name = None
717
+ elif isinstance(file_name, list):
718
+ # Filter out empty strings and strip whitespace
719
+ file_name = [f.strip() for f in file_name if f and f.strip()]
720
+ if not file_name:
721
+ file_name = None
722
+ else:
723
+ raise ValueError("file_name must be a string, list of strings, or None")
724
+
725
+ return repo_id, file_name
726
+
727
+ def _setup_progress_tracker(
728
+ self,
729
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]],
730
+ show_progress: bool,
731
+ repo_id: str,
732
+ file_name: Optional[Union[str, List[str]]]
733
+ ) -> Optional[DownloadProgressTracker]:
734
+ """Initialize progress tracker if callback is provided."""
735
+ if not progress_callback:
736
+ return None
737
+
738
+ progress_tracker = DownloadProgressTracker(progress_callback, show_progress)
739
+ # Get repo info for progress tracking - now handles all cases
740
+ total_size, file_count = self._get_repo_info_for_progress(repo_id, file_name)
741
+ progress_tracker.set_repo_info(total_size, file_count)
742
+ return progress_tracker
743
+
744
+ def _setup_hf_transfer_env(self) -> None:
745
+ """Set up HF transfer environment."""
746
+ self.original_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
747
+ if self.enable_transfer:
748
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
749
+
750
+ def _cleanup_hf_transfer_env(self) -> None:
751
+ """Restore original HF transfer environment."""
752
+ if self.original_hf_transfer is not None:
753
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = self.original_hf_transfer
754
+ else:
755
+ os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
756
+
757
    def _validate_repository_and_get_info(
        self,
        repo_id: str,
        progress_tracker: Optional[DownloadProgressTracker]
    ):
        """Validate that *repo_id* exists and return its (cached) model info.

        404 responses are normalized to RepositoryNotFoundError; the error
        message is recorded on the progress tracker before re-raising so the
        caller's callback sees a consistent error surface.
        """
        try:
            info = self._get_model_info_cached(repo_id, files_metadata=False)
            return info
        except RepositoryNotFoundError:
            error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
            if progress_tracker:
                progress_tracker.set_error(error_msg)
            raise RepositoryNotFoundError(error_msg)
        except HfHubHTTPError as e:
            if e.response.status_code == 404:
                # Some endpoints report a missing repo as a plain 404 instead
                # of RepositoryNotFoundError — normalize it here.
                error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                raise RepositoryNotFoundError(error_msg)
            else:
                error_msg = f"HTTP error while accessing repository '{repo_id}': {e}"
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                # NOTE(review): raising a fresh HfHubHTTPError drops the
                # original `response` attribute; downstream handlers that
                # inspect e.response would fail on this path — confirm.
                raise HfHubHTTPError(error_msg)
783
+ def _validate_file_exists_in_repo(
784
+ self,
785
+ file_name: str,
786
+ info,
787
+ repo_id: str,
788
+ progress_tracker: Optional[DownloadProgressTracker]
789
+ ) -> None:
790
+ """Validate that the file exists in the repository."""
791
+ file_exists = False
792
+ if info.siblings:
793
+ for sibling in info.siblings:
794
+ if sibling.rfilename == file_name:
795
+ file_exists = True
796
+ break
797
+
798
+ if not file_exists:
799
+ available_files = [sibling.rfilename for sibling in info.siblings] if info.siblings else []
800
+ error_msg = (
801
+ f"File '{file_name}' not found in repository '{repo_id}'. "
802
+ f"Available files: {available_files[:10]}{'...' if len(available_files) > 10 else ''}"
803
+ )
804
+ if progress_tracker:
805
+ progress_tracker.set_error(error_msg)
806
+ progress_tracker.stop_tracking()
807
+ raise ValueError(error_msg)
808
+
809
+ def _check_file_exists_and_valid(
810
+ self,
811
+ file_path: str,
812
+ expected_size: Optional[int] = None
813
+ ) -> bool:
814
+ """Check if a file exists and is valid (non-empty, correct size if known)."""
815
+ if not os.path.exists(file_path):
816
+ return False
817
+
818
+ # Check file is not empty
819
+ try:
820
+ file_size = os.path.getsize(file_path)
821
+ if file_size == 0:
822
+ return False
823
+ except (OSError, IOError):
824
+ return False
825
+
826
+ # If we have expected size, check it matches
827
+ if expected_size is not None and file_size != expected_size:
828
+ return False
829
+
830
+ # If no expected size, just check that file is not empty
831
+ return os.path.getsize(file_path) > 0
832
+
833
+ def _extract_model_file_type_from_tags(self, repo_id: str) -> Optional[str]:
834
+ """Extract model file type from repo tags with priority: NPU > MLX > GGUF."""
835
+ try:
836
+ info = self._get_model_info_cached(repo_id, files_metadata=False)
837
+ if hasattr(info, 'tags') and info.tags:
838
+ # Convert tags to lowercase for case-insensitive matching
839
+ tags_lower = [tag.lower() for tag in info.tags]
840
+
841
+ # Check with priority: NPU > MLX > GGUF
842
+ if 'npu' in tags_lower:
843
+ return 'npu'
844
+ elif 'mlx' in tags_lower:
845
+ return 'mlx'
846
+ elif 'gguf' in tags_lower:
847
+ return 'gguf'
848
+ except Exception:
849
+ pass
850
+ return None
851
+
852
+ def _load_downloaded_manifest(self, local_dir: str) -> Dict[str, Any]:
853
+ """Load nexa.manifest from the downloaded repository if it exists."""
854
+ manifest_path = os.path.join(local_dir, 'nexa.manifest')
855
+ if os.path.exists(manifest_path):
856
+ try:
857
+ with open(manifest_path, 'r', encoding='utf-8') as f:
858
+ return json.load(f)
859
+ except (json.JSONDecodeError, IOError):
860
+ pass
861
+ return {}
862
+
863
    def _download_manifest_if_needed(self, repo_id: str, local_dir: str) -> bool:
        """
        Download nexa.manifest from the repository if it doesn't exist locally.

        Args:
            repo_id: Repository ID
            local_dir: Local directory where the manifest should be saved

        Returns:
            bool: True if manifest was downloaded or already exists, False if not found in repo
        """
        manifest_path = os.path.join(local_dir, 'nexa.manifest')

        # A locally present manifest wins; never overwrite it from the repo.
        if os.path.exists(manifest_path):
            return True

        # Try to download nexa.manifest from the repository
        try:
            print(f"[INFO] Attempting to download nexa.manifest from {repo_id}...")
            self.api.hf_hub_download(
                repo_id=repo_id,
                filename='nexa.manifest',
                local_dir=local_dir,
                local_dir_use_symlinks=False,
                token=self.token,
                force_download=False
            )
            print(f"[OK] Successfully downloaded nexa.manifest from {repo_id}")
            return True
        except Exception as e:
            # Manifest doesn't exist in repo or other error — this is fine,
            # _fetch_and_save_metadata will create one locally afterwards.
            print(f"[INFO] nexa.manifest not found in {repo_id}, will create locally")
            return False
898
    def _fetch_and_save_metadata(self, repo_id: str, local_dir: str, is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> None:
        """Fetch model info and save metadata after successful download.

        Best-effort: metadata lookups may fail (offline, rate-limited), but a
        nexa.manifest is always written — falling back to a minimal manifest
        if the full save fails.

        Args:
            repo_id: Repository the files came from.
            local_dir: Directory the manifest is written into.
            is_mmproj: Whether the downloaded file is an mmproj file.
            file_name: File(s) that were downloaded, or None for a full repo.
            **kwargs: Manifest overrides (model_name, model_type, plugin_id,
                device_id, min_sdk_version, avatar_url, ...).
        """
        # Initialize metadata with defaults to ensure manifest is always created
        old_metadata = {
            'pipeline_tag': "text-generation",  # Default to text-generation pipeline-tag
            'download_time': datetime.now().isoformat(),
            'avatar_url': None
        }

        # Try to fetch additional metadata, but don't let failures prevent manifest creation
        try:
            # Fetch model info to get pipeline_tag (using cache)
            info = self._get_model_info_cached(repo_id, files_metadata=False)
            if hasattr(info, 'pipeline_tag') and info.pipeline_tag:
                old_metadata['pipeline_tag'] = info.pipeline_tag
        except Exception as e:
            # Log the error but continue with manifest creation
            print(f"Warning: Could not fetch model info for {repo_id}: {e}")

        # Use caller-supplied avatar URL if provided (overwrites the default)
        old_metadata['avatar_url'] = kwargs.get('avatar_url')

        # Extract model file type from tags (priority NPU > MLX > GGUF)
        model_file_type = self._extract_model_file_type_from_tags(repo_id)
        if model_file_type:
            old_metadata['model_file_type'] = model_file_type

        # Load existing nexa.manifest from downloaded repo (if exists)
        downloaded_manifest = self._load_downloaded_manifest(local_dir)
        if downloaded_manifest:
            old_metadata['downloaded_manifest'] = downloaded_manifest


        # CRITICAL: Always create the manifest file, regardless of metadata fetch failures
        try:
            save_manifest_with_files_metadata(repo_id, local_dir, old_metadata, is_mmproj, file_name, **kwargs)
            print(f"[OK] Successfully created nexa.manifest for {repo_id}")
        except Exception as e:
            # This is critical - if manifest creation fails, we should know about it
            print(f"ERROR: Failed to create nexa.manifest for {repo_id}: {e}")
            # Try a fallback approach - create a minimal manifest
            try:
                minimal_manifest = {
                    "Name": repo_id,
                    "ModelName": kwargs.get('model_name', ''),
                    "ModelType": kwargs.get('model_type', 'other'),
                    "PluginId": kwargs.get('plugin_id', 'unknown'),
                    "DeviceId": kwargs.get('device_id', ''),
                    "MinSDKVersion": kwargs.get('min_sdk_version', ''),
                    "ModelFile": {},
                    "MMProjFile": {"Name": "", "Downloaded": False, "Size": 0},
                    "TokenizerFile": {"Name": "", "Downloaded": False, "Size": 0},
                    "ExtraFiles": None,
                    "pipeline_tag": old_metadata.get('pipeline_tag'),
                    "download_time": old_metadata.get('download_time'),
                    "avatar_url": old_metadata.get('avatar_url')
                }
                save_download_metadata(local_dir, minimal_manifest)
                print(f"[OK] Created minimal nexa.manifest for {repo_id} as fallback")
            except Exception as fallback_error:
                print(f"CRITICAL ERROR: Could not create even minimal manifest for {repo_id}: {fallback_error}")
960
    def _download_single_file(
        self,
        repo_id: str,
        file_name: str,
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False,
        **kwargs
    ) -> str:
        """Download a single file from the repository using HuggingFace Hub API.

        Skips the download when a valid local copy exists (unless
        *force_download*), then ensures a nexa.manifest is present and
        metadata is saved.

        Returns:
            str: Local path of the (downloaded or pre-existing) file.

        Raises:
            ValueError: if the file 404s on the hub.
            HfHubHTTPError: for other HTTP failures.
        """
        # Create repo-specific directory for the single file
        file_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Check if file already exists
        local_file_path = os.path.join(file_local_dir, file_name)
        if not force_download and self._check_file_exists_and_valid(local_file_path):
            print(f"[SKIP] File already exists: {file_name}")
            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()
            return local_file_path

        try:
            # Note: hf_hub_download doesn't support tqdm_class parameter
            # Progress tracking works through the global tqdm monkey patching
            downloaded_path = self.api.hf_hub_download(
                repo_id=repo_id,
                filename=file_name,
                local_dir=file_local_dir,
                local_dir_use_symlinks=False,
                token=self.token,
                force_download=force_download
            )

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Download nexa.manifest from repo if it doesn't exist locally
            self._download_manifest_if_needed(repo_id, file_local_dir)

            # Save metadata after successful download.
            # _current_is_mmproj / _current_file_name are set by download()
            # before this method is called.
            self._fetch_and_save_metadata(repo_id, file_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

            return downloaded_path

        except HfHubHTTPError as e:
            error_msg = f"Error downloading file '{file_name}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            if e.response.status_code == 404:
                raise ValueError(f"File '{file_name}' not found in repository '{repo_id}'")
            else:
                raise HfHubHTTPError(error_msg)
1016
    def _download_entire_repository(
        self,
        repo_id: str,
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False,
        **kwargs
    ) -> str:
        """Download the entire repository snapshot into a repo-specific directory.

        Returns:
            str: Path of the downloaded snapshot directory.

        Raises:
            HfHubHTTPError: on HTTP failure (error also recorded on tracker).
        """
        # Create a subdirectory for this specific repo
        repo_local_dir = self._create_repo_directory(local_dir, repo_id)

        try:
            download_kwargs = {
                'repo_id': repo_id,
                'local_dir': repo_local_dir,
                'local_dir_use_symlinks': False,
                'token': self.token,
                'force_download': force_download
            }

            # Add tqdm_class if progress tracking is enabled
            if progress_tracker:
                download_kwargs['tqdm_class'] = CustomProgressTqdm

            downloaded_path = self.api.snapshot_download(**download_kwargs)

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Save metadata after successful download
            # (_current_is_mmproj / _current_file_name are set by download())
            self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

            return downloaded_path

        except HfHubHTTPError as e:
            error_msg = f"Error downloading repository '{repo_id}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            raise HfHubHTTPError(error_msg)
1059
    def _download_multiple_files_from_hf(
        self,
        repo_id: str,
        file_names: List[str],
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False,
        **kwargs
    ) -> str:
        """Download multiple specific files from HuggingFace Hub.

        Shows one overall tqdm bar (unit=file), skipping files that already
        exist locally unless *force_download* is set.

        Returns:
            str: The repo-specific directory the files were written into.

        Raises:
            HfHubHTTPError: on HTTP failure; other exceptions propagate after
            the progress bar/tracker are shut down.
        """
        # Create repo-specific directory
        repo_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Create overall progress bar for multiple files
        overall_progress = tqdm(
            total=len(file_names),
            unit='file',
            desc=f"Downloading {len(file_names)} files from {repo_id}",
            position=0,
            leave=True
        )

        try:
            for file_name in file_names:
                overall_progress.set_postfix_str(f"Current: {os.path.basename(file_name)}")

                # Check if file already exists
                local_file_path = os.path.join(repo_local_dir, file_name)
                if not force_download and self._check_file_exists_and_valid(local_file_path):
                    print(f"[SKIP] File already exists: {file_name}")
                    overall_progress.update(1)
                    continue

                # Download each file using hf_hub_download
                self.api.hf_hub_download(
                    repo_id=repo_id,
                    filename=file_name,
                    local_dir=repo_local_dir,
                    local_dir_use_symlinks=False,
                    token=self.token,
                    force_download=force_download
                )

                overall_progress.update(1)

            overall_progress.close()

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Download nexa.manifest from repo if it doesn't exist locally
            self._download_manifest_if_needed(repo_id, repo_local_dir)

            # Save metadata after successful download
            # (_current_is_mmproj / _current_file_name are set by download())
            self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

            return repo_local_dir

        except HfHubHTTPError as e:
            overall_progress.close()
            error_msg = f"Error downloading files from '{repo_id}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            raise HfHubHTTPError(error_msg)
        except Exception as e:
            # Ensure the bar is closed and the tracker reports the failure
            # before re-raising the original exception unchanged.
            overall_progress.close()
            if progress_tracker:
                progress_tracker.set_error(str(e))
                progress_tracker.stop_tracking()
            raise
1132
    def download(
        self,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]] = None,
        local_dir: Optional[str] = None,
        progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        show_progress: bool = True,
        force_download: bool = False,
        is_mmproj: bool = False,
        **kwargs
    ) -> str:
        """
        Main download method that handles all download scenarios.

        Args:
            repo_id: Repository ID to download from
            file_name: Optional file name(s) to download; None downloads
                the entire repository
            local_dir: Local directory to save files
            progress_callback: Callback for progress updates
            show_progress: Whether to show progress bar
            force_download: Force re-download even if files exist
            is_mmproj: Whether the requested file is an mmproj file
            **kwargs: Forwarded to the strategy methods / manifest writer

        Returns:
            Path to downloaded file or directory
        """
        # Validate and normalize parameters
        repo_id, file_name = self._validate_and_setup_params(repo_id, file_name)

        # Store parameters as instance variables for use in _fetch_and_save_metadata
        # (the strategy methods read them back after the transfer finishes)
        self._current_is_mmproj = is_mmproj
        self._current_file_name = file_name

        # Set up local directory
        local_dir = self._created_dir_if_not_exists(local_dir)

        # Set up progress tracker (repo totals only make sense for a single
        # named file or the whole repo, hence the str-only narrowing here)
        file_name_for_progress = file_name if isinstance(file_name, str) else None
        progress_tracker = self._setup_progress_tracker(
            progress_callback, show_progress, repo_id, file_name_for_progress
        )

        # Set up HF transfer environment
        self._setup_hf_transfer_env()

        try:
            # Validate repository and get info
            info = self._validate_repository_and_get_info(repo_id, progress_tracker)

            # Start progress tracking
            if progress_tracker:
                progress_tracker.start_tracking()

            # Choose download strategy based on file_name
            if file_name is None:
                # Download entire repository
                return self._download_entire_repository(
                    repo_id, local_dir, progress_tracker, force_download, **kwargs
                )
            elif isinstance(file_name, str):
                # Download specific single file
                self._validate_file_exists_in_repo(file_name, info, repo_id, progress_tracker)
                return self._download_single_file(
                    repo_id, file_name, local_dir, progress_tracker, force_download, **kwargs
                )
            else:  # file_name is a list
                # Download multiple specific files
                # Validate all files exist before transferring anything
                for fname in file_name:
                    self._validate_file_exists_in_repo(fname, info, repo_id, progress_tracker)

                return self._download_multiple_files_from_hf(
                    repo_id, file_name, local_dir, progress_tracker, force_download, **kwargs
                )

        except Exception as e:
            # Handle any unexpected errors; skip if a more specific error was
            # already recorded by a strategy method
            if progress_tracker and progress_tracker.download_status != "error":
                progress_tracker.set_error(str(e))
                progress_tracker.stop_tracking()
            raise

        finally:
            # Restore original HF transfer setting
            self._cleanup_hf_transfer_env()
+
1217
+
1218
+ ##########################################################################
1219
+ # Public Download Function #
1220
+ ##########################################################################
1221
+
1222
+
1223
def download_from_huggingface(
    repo_id: str,
    file_name: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    enable_transfer: bool = True,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    show_progress: bool = True,
    token: Union[bool, str, None] = None,
    custom_endpoint: Optional[str] = None,
    force_download: bool = False,
    is_mmproj: Optional[bool] = None,
    **kwargs
) -> str:
    """
    Download models or files from HuggingFace Hub or custom mirror endpoints.

    Args:
        repo_id (str): Required. Repository ID (e.g., "microsoft/DialoGPT-medium").
        file_name (Union[str, List[str]], optional): Single filename or list of
            filenames to download. None downloads the entire repository.
        local_dir (str, optional): Destination directory; defaults to
            DEFAULT_MODEL_SAVING_PATH.
        enable_transfer (bool, optional): Enable hf_transfer acceleration.
        progress_callback (Callable, optional): Receives progress dicts with
            'status' ('idle'/'downloading'/'completed'/'error', plus
            'error_message' on failure) and 'progress' (total_downloaded,
            total_size, percentage, files_active, files_total, known_total),
            'speed' (bytes_per_second, formatted), 'formatting' (downloaded,
            total_size) and 'timing' (elapsed_seconds, eta_seconds,
            start_time) sections.
        show_progress (bool, optional): Show a unified terminal progress bar
            (only effective when progress_callback is provided).
        token (Union[bool, str, None], optional): True reads the token from
            the HuggingFace config folder; a string is used directly; None
            uses default behavior.
        custom_endpoint (str, optional): Base URL of a HuggingFace-compatible
            endpoint, without any path components (e.g. "https://hf-mirror.com").
        force_download (bool, optional): Re-download even when files already
            exist locally. Default False.
        is_mmproj (bool, optional): Whether the file is an mmproj file. When
            None, inferred from whether any requested filename contains
            'mmproj'.
        **kwargs: Manifest overrides: plugin_id, model_name, model_type,
            device_id, min_sdk_version (all take highest priority).

    Returns:
        str: Path to the downloaded file or directory.

    Raises:
        ValueError: If repo_id is invalid or file_name doesn't exist in the repo.
        RepositoryNotFoundError: If the repository doesn't exist.
        HfHubHTTPError: If there's an HTTP error during download.
    """
    # Infer is_mmproj from the filename(s) when not set explicitly
    if is_mmproj is None:
        if file_name is None:
            is_mmproj = False
        else:
            names = file_name if isinstance(file_name, list) else [file_name]
            is_mmproj = any('mmproj' in name.lower() for name in names)

    # Build a downloader bound to the requested endpoint/token
    downloader = HuggingFaceDownloader(
        endpoint=custom_endpoint,
        token=token,
        enable_transfer=enable_transfer
    )

    # Delegate all strategy selection to the downloader
    return downloader.download(
        repo_id=repo_id,
        file_name=file_name,
        local_dir=local_dir,
        progress_callback=progress_callback,
        show_progress=show_progress,
        force_download=force_download,
        is_mmproj=is_mmproj,
        **kwargs
    )
1334
+
1335
+ ##########################################################################
1336
+ # Auto-download decorator #
1337
+ ##########################################################################
1338
+
1339
+
1340
+ def _find_q4_0_file_in_repo(repo_id: str, token: Union[bool, str, None] = None) -> Optional[str]:
1341
+ """
1342
+ Find a GGUF file with 'q4_0' in its name (case-insensitive) in a HuggingFace repository.
1343
+
1344
+ Args:
1345
+ repo_id: The repository ID (e.g., "owner/repo")
1346
+ token: HuggingFace authentication token
1347
+
1348
+ Returns:
1349
+ The filename with 'q4_0' if found, None otherwise.
1350
+ Prioritizes files with exact 'Q4_0' match in name.
1351
+ """
1352
+ try:
1353
+ # Create HfApi instance to list files
1354
+ api = HfApi(token=token if isinstance(token, str) else None)
1355
+ info = api.model_info(repo_id, files_metadata=True, token=token if isinstance(token, str) else None)
1356
+
1357
+ if not info.siblings:
1358
+ return None
1359
+
1360
+ # Look for GGUF files with 'q4_0' in the name
1361
+ q4_0_files = []
1362
+ for sibling in info.siblings:
1363
+ filename = sibling.rfilename
1364
+ filename_lower = filename.lower()
1365
+
1366
+ # Check if it's a GGUF file with 'q4_0' in the name
1367
+ if filename_lower.endswith('.gguf') and 'q4_0' in filename_lower:
1368
+ q4_0_files.append(filename)
1369
+
1370
+ if not q4_0_files:
1371
+ return None
1372
+
1373
+ # If multiple files found, prioritize exact Q4_0 match, then return first
1374
+ for filename in q4_0_files:
1375
+ if 'Q4_0' in filename: # Exact case match
1376
+ return filename
1377
+
1378
+ # Return the first q4_0 file found
1379
+ return q4_0_files[0]
1380
+
1381
+ except Exception:
1382
+ # If we can't list files (e.g., network error, auth error), return None
1383
+ return None
1384
+
1385
+
1386
+ def _find_best_mmproj_file_in_repo(repo_id: str, token: Union[bool, str, None] = None) -> Optional[str]:
1387
+ """
1388
+ Find the best mmproj GGUF file in a HuggingFace repository.
1389
+
1390
+ Selection criteria:
1391
+ 1. Must end with .gguf
1392
+ 2. Must contain "mmproj" in filename (case-insensitive)
1393
+ 3. Priority order: f16 > bf16 > f32
1394
+
1395
+ Args:
1396
+ repo_id: The repository ID (e.g., "owner/repo")
1397
+ token: HuggingFace authentication token
1398
+
1399
+ Returns:
1400
+ The best mmproj filename if found, None otherwise.
1401
+ """
1402
+ try:
1403
+ # Create HfApi instance to list files
1404
+ api = HfApi(token=token if isinstance(token, str) else None)
1405
+ info = api.model_info(repo_id, files_metadata=True, token=token if isinstance(token, str) else None)
1406
+
1407
+ if not info.siblings:
1408
+ return None
1409
+
1410
+ # Look for mmproj GGUF files
1411
+ mmproj_files = []
1412
+ for sibling in info.siblings:
1413
+ filename = sibling.rfilename
1414
+ filename_lower = filename.lower()
1415
+
1416
+ # Check if it's a GGUF file with 'mmproj' in the name
1417
+ if filename_lower.endswith('.gguf') and 'mmproj' in filename_lower:
1418
+ mmproj_files.append(filename)
1419
+
1420
+ if not mmproj_files:
1421
+ return None
1422
+
1423
+ # Prioritize by precision: f16 > bf16 > f32
1424
+ # First try to find f16
1425
+ for filename in mmproj_files:
1426
+ filename_lower = filename.lower()
1427
+ if 'f16' in filename_lower and 'bf16' not in filename_lower:
1428
+ return filename
1429
+
1430
+ # Then try bf16
1431
+ for filename in mmproj_files:
1432
+ if 'bf16' in filename.lower():
1433
+ return filename
1434
+
1435
+ # Then try f32
1436
+ for filename in mmproj_files:
1437
+ if 'f32' in filename.lower():
1438
+ return filename
1439
+
1440
+ # If no specific precision found, return the first mmproj file
1441
+ return mmproj_files[0]
1442
+
1443
+ except Exception:
1444
+ # If we can't list files (e.g., network error, auth error), return None
1445
+ return None
1446
+
1447
+
1448
def _check_if_model_needs_mmproj(repo_id: str, token: Union[bool, str, None] = None) -> bool:
    """
    Determine whether a model is tagged 'image-text-to-text', i.e. needs an mmproj file.

    Args:
        repo_id: The repository ID (e.g., "owner/repo")
        token: HuggingFace authentication token

    Returns:
        True if the model carries the 'image-text-to-text' tag, False otherwise
        (including when repo metadata cannot be fetched at all).
    """
    # HfApi only accepts a string token; booleans/None are normalized away.
    hf_token = token if isinstance(token, str) else None
    try:
        api = HfApi(token=hf_token)
        info = api.model_info(repo_id, files_metadata=False, token=hf_token)
        tags = getattr(info, 'tags', None)
        return bool(tags) and 'image-text-to-text' in tags
    except Exception:
        # Network/auth failures are treated as "no mmproj needed".
        return False
1473
+
1474
+
1475
+ def _download_model_if_needed(
1476
+ model_path: str,
1477
+ param_name: str,
1478
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
1479
+ token: Union[bool, str, None] = None,
1480
+ is_mmproj: bool = False,
1481
+ **kwargs
1482
+ ) -> tuple[str, Optional[str], Optional[str]]:
1483
+ """
1484
+ Helper function to download a model from HuggingFace if it doesn't exist locally.
1485
+
1486
+ Args:
1487
+ model_path: The model path that may be local or remote
1488
+ param_name: Name of the parameter (for error messages)
1489
+ progress_callback: Callback function for download progress updates
1490
+ token: HuggingFace authentication token for private repositories
1491
+
1492
+ Returns:
1493
+ tuple[str, Optional[str], Optional[str]]: Tuple of (local_path, model_name, plugin_id)
1494
+ - local_path: Local path to the model (either existing or downloaded)
1495
+ - model_name: ModelName from nexa.manifest if available, None otherwise
1496
+ - plugin_id: PluginId from nexa.manifest if available, None otherwise
1497
+
1498
+ Raises:
1499
+ RuntimeError: If download fails
1500
+ """
1501
+ # Helper function to extract model info from manifest
1502
+ def _extract_info_from_manifest(path: str) -> tuple[Optional[str], Optional[str], Optional[dict]]:
1503
+ """Extract ModelName, PluginId, and full manifest from nexa.manifest if it exists."""
1504
+ # If path is a file, check its parent directory for manifest
1505
+ if os.path.isfile(path):
1506
+ manifest_dir = os.path.dirname(path)
1507
+ else:
1508
+ manifest_dir = path
1509
+
1510
+ manifest_path = os.path.join(manifest_dir, 'nexa.manifest')
1511
+ if not os.path.exists(manifest_path):
1512
+ return None, None, None
1513
+
1514
+ try:
1515
+ with open(manifest_path, 'r', encoding='utf-8') as f:
1516
+ manifest = json.load(f)
1517
+ return manifest.get('ModelName'), manifest.get('PluginId'), manifest
1518
+ except (json.JSONDecodeError, IOError):
1519
+ return None, None, None
1520
+
1521
+ # Helper function to get a model file path from manifest
1522
+ # Note: Tnis is for NPU only, because when downloading, it is a directory; when passing local path to inference, it needs to be a file.
1523
+ def _get_model_file_from_manifest(manifest: dict, base_dir: str) -> Optional[str]:
1524
+ """Extract a model file path from manifest's ModelFile section."""
1525
+ if not manifest or 'ModelFile' not in manifest:
1526
+ return None
1527
+
1528
+ model_files = manifest['ModelFile']
1529
+ # Find the first valid model file (skip N/A entries and metadata files)
1530
+ for key, file_info in model_files.items():
1531
+ if key == 'N/A':
1532
+ continue
1533
+ if isinstance(file_info, dict) and 'Name' in file_info:
1534
+ file_name = file_info['Name']
1535
+ # Skip common non-model files
1536
+ if file_name and not file_name.startswith('.') and file_name.endswith('.nexa'):
1537
+ file_path = os.path.join(base_dir, file_name)
1538
+ if os.path.exists(file_path):
1539
+ return file_path
1540
+
1541
+ # If no .nexa files found, try ExtraFiles for .nexa files
1542
+ if 'ExtraFiles' in manifest:
1543
+ for file_info in manifest['ExtraFiles']:
1544
+ if isinstance(file_info, dict) and 'Name' in file_info:
1545
+ file_name = file_info['Name']
1546
+ if file_name and file_name.endswith('.nexa') and not file_name.startswith('.cache'):
1547
+ file_path = os.path.join(base_dir, file_name)
1548
+ if os.path.exists(file_path):
1549
+ return file_path
1550
+
1551
+ return None
1552
+
1553
+ # Check if model_path exists locally (file or directory)
1554
+ if os.path.exists(model_path):
1555
+ # Local path exists, try to extract model info
1556
+ model_name, plugin_id, manifest = _extract_info_from_manifest(model_path)
1557
+
1558
+ # If PluginId is "npu" and path is a directory, convert to file path
1559
+ if plugin_id == "npu" and os.path.isdir(model_path):
1560
+ model_file_path = _get_model_file_from_manifest(manifest, model_path)
1561
+ if model_file_path:
1562
+ model_path = model_file_path
1563
+
1564
+ return model_path, model_name, plugin_id
1565
+
1566
+ # Model path doesn't exist locally, try to download from HuggingFace
1567
+ try:
1568
+ # Parse model_path to extract repo_id and filename
1569
+ repo_id, file_name = _parse_model_path(model_path)
1570
+
1571
+ # Smart file selection for llama_cpp/cpu_gpu models
1572
+ # If no specific file is requested (owner/repo format) and not an mmproj file,
1573
+ # try to find and download a Q4_0 GGUF file instead of the entire repo
1574
+ if file_name is None and not is_mmproj:
1575
+ # Count slashes to determine if it's owner/repo format (1 slash)
1576
+ slash_count = model_path.count('/')
1577
+ if slash_count == 1:
1578
+ # Try to find a Q4_0 file in the repo
1579
+ q4_0_file = _find_q4_0_file_in_repo(repo_id, token)
1580
+ if q4_0_file:
1581
+ # Found a Q4_0 file, download it instead of the whole repo
1582
+ file_name = q4_0_file
1583
+ print(f"Found Q4_0 model file: {q4_0_file}, downloading it instead of the full repository.")
1584
+ # If no Q4_0 file found, file_name remains None and will download entire repo
1585
+
1586
+ # Download the model
1587
+ downloaded_path = download_from_huggingface(
1588
+ repo_id=repo_id,
1589
+ file_name=file_name,
1590
+ local_dir=None, # Use default cache directory
1591
+ enable_transfer=True,
1592
+ progress_callback=progress_callback,
1593
+ show_progress=True,
1594
+ token=token,
1595
+ is_mmproj=is_mmproj,
1596
+ **kwargs
1597
+ )
1598
+
1599
+ # Extract model info from the downloaded manifest
1600
+ model_name, plugin_id, manifest = _extract_info_from_manifest(downloaded_path)
1601
+
1602
+ # If PluginId is "npu" and path is a directory, convert to file path
1603
+ if plugin_id == "npu" and os.path.isdir(downloaded_path):
1604
+ model_file_path = _get_model_file_from_manifest(manifest, downloaded_path)
1605
+ if model_file_path:
1606
+ downloaded_path = model_file_path
1607
+
1608
+ return downloaded_path, model_name, plugin_id
1609
+
1610
+ except Exception as e:
1611
+ # Only handle download-related errors
1612
+ raise RuntimeError(f"Could not load model from '{param_name}={model_path}': {e}")
1613
+
1614
+
1615
def auto_download_model(func: Callable) -> Callable:
    """
    Decorator that automatically downloads models from HuggingFace if they don't exist locally.

    This decorator should be applied to __init__ methods that take a name_or_path parameter
    and optionally an mmproj_path parameter. If these paths don't exist as local files/directories,
    it will attempt to download them from HuggingFace Hub using the download_from_huggingface function.

    The name_or_path and mmproj_path can be in formats like:
    - "microsoft/DialoGPT-small" (downloads entire repo)
    - "microsoft/DialoGPT-small/pytorch_model.bin" (downloads specific file)
    - "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" (downloads specific file)

    Optional kwargs that are extracted (and NOT forwarded to the wrapped function):
    - progress_callback: Callback function for download progress updates
    - token: HuggingFace authentication token for private repositories.
      If omitted, None, or empty, the HF_TOKEN / NEXA_HFTOKEN environment
      variables are consulted. token=False (disable auth) is preserved as-is.

    Args:
        func: The __init__ method to wrap

    Returns:
        Wrapped function that handles automatic model downloading
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Consume decorator-only kwargs so they are not passed to the wrapped function.
        progress_callback = kwargs.pop('progress_callback', None)
        token = kwargs.pop('token', None)
        # Fall back to environment tokens whenever no usable token was supplied.
        # (Previously the fallback only fired when 'token' was explicitly passed
        # as None/empty; omitting the kwarg skipped it entirely.)
        if token is None or token == '':
            token = os.getenv('HF_TOKEN') or os.getenv('NEXA_HFTOKEN')

        # Locate name_or_path: assumed to be the first argument after self when
        # passed positionally, otherwise looked up by keyword.
        name_or_path = None
        name_path_index = None
        is_name_positional = False
        if len(args) >= 2:
            name_or_path = args[1]
            name_path_index = 1
            is_name_positional = True
        elif 'name_or_path' in kwargs:
            name_or_path = kwargs['name_or_path']

        # mmproj_path is only ever passed by keyword.
        mmproj_path = kwargs.get('mmproj_path')

        # Nothing to resolve: delegate straight to the wrapped function.
        if name_or_path is None and mmproj_path is None:
            return func(*args, **kwargs)

        # Resolve (and download if necessary) the main model path.
        if name_or_path is not None:
            downloaded_name_path, model_name, plugin_id = _download_model_if_needed(
                name_or_path, 'name_or_path', progress_callback, token, **kwargs
            )

            # Substitute the resolved local path back into the call.
            if is_name_positional:
                args = args[:name_path_index] + (downloaded_name_path,) + args[name_path_index + 1:]
            else:
                kwargs['name_or_path'] = downloaded_name_path

            # Surface the manifest's ModelName unless the caller already set one.
            if model_name is not None and 'model_name' not in kwargs:
                kwargs['model_name'] = model_name

            # Auto-download mmproj for image-text-to-text models, but only when the
            # caller did not provide mmproj_path and the plugin is llama_cpp/cpu_gpu.
            if mmproj_path is None and plugin_id in ('llama_cpp', 'cpu_gpu'):
                # Parse the ORIGINAL name_or_path (not the resolved local path) for the repo id.
                repo_id, _ = _parse_model_path(name_or_path)

                if _check_if_model_needs_mmproj(repo_id, token):
                    best_mmproj = _find_best_mmproj_file_in_repo(repo_id, token)
                    if best_mmproj:
                        print(f"Detected image-text-to-text model. Auto-downloading mmproj file: {best_mmproj}")
                        try:
                            downloaded_mmproj_path, _, _ = _download_model_if_needed(
                                f"{repo_id}/{best_mmproj}",
                                'mmproj_path',
                                progress_callback,
                                token,
                                is_mmproj=True,
                                **kwargs
                            )
                            kwargs['mmproj_path'] = downloaded_mmproj_path
                            # The branch below re-runs on this path; it now exists
                            # locally, so that pass is a cheap no-op resolution.
                            mmproj_path = downloaded_mmproj_path
                        except Exception as e:
                            # Best-effort: continue without mmproj and let model
                            # initialization surface the error if it matters.
                            print(f"Warning: Failed to auto-download mmproj file: {e}")

        # Resolve (and download if necessary) an explicitly provided mmproj_path.
        if mmproj_path is not None:
            downloaded_mmproj_path, _, _ = _download_model_if_needed(
                mmproj_path, 'mmproj_path', progress_callback, token, is_mmproj=True, **kwargs
            )
            kwargs['mmproj_path'] = downloaded_mmproj_path

        # Call the wrapped function with updated paths; model-creation errors bubble up.
        return func(*args, **kwargs)

    return wrapper