nexaai 1.0.21__cp311-cp311-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (154)
  1. nexaai/__init__.py +95 -0
  2. nexaai/_stub.cp311-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +6 -0
  10. nexaai/binds/asr_bind.cp311-win_arm64.pyd +0 -0
  11. nexaai/binds/common_bind.cp311-win_arm64.pyd +0 -0
  12. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  15. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  16. nexaai/binds/cpu_gpu/libomp140.aarch64.dll +0 -0
  17. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  18. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  19. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  20. nexaai/binds/embedder_bind.cp311-win_arm64.pyd +0 -0
  21. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  22. nexaai/binds/libssl-3-arm64.dll +0 -0
  23. nexaai/binds/llm_bind.cp311-win_arm64.pyd +0 -0
  24. nexaai/binds/nexa_bridge.dll +0 -0
  25. nexaai/binds/npu/FLAC.dll +0 -0
  26. nexaai/binds/npu/convnext-sdk.dll +0 -0
  27. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  28. nexaai/binds/npu/fftw3.dll +0 -0
  29. nexaai/binds/npu/fftw3f.dll +0 -0
  30. nexaai/binds/npu/ggml-base.dll +0 -0
  31. nexaai/binds/npu/ggml-cpu.dll +0 -0
  32. nexaai/binds/npu/ggml-opencl.dll +0 -0
  33. nexaai/binds/npu/ggml.dll +0 -0
  34. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  35. nexaai/binds/npu/granite4-sdk.dll +0 -0
  36. nexaai/binds/npu/htp-files/Genie.dll +0 -0
  37. nexaai/binds/npu/htp-files/PlatformValidatorShared.dll +0 -0
  38. nexaai/binds/npu/htp-files/QnnChrometraceProfilingReader.dll +0 -0
  39. nexaai/binds/npu/htp-files/QnnCpu.dll +0 -0
  40. nexaai/binds/npu/htp-files/QnnCpuNetRunExtensions.dll +0 -0
  41. nexaai/binds/npu/htp-files/QnnDsp.dll +0 -0
  42. nexaai/binds/npu/htp-files/QnnDspNetRunExtensions.dll +0 -0
  43. nexaai/binds/npu/htp-files/QnnDspV66CalculatorStub.dll +0 -0
  44. nexaai/binds/npu/htp-files/QnnDspV66Stub.dll +0 -0
  45. nexaai/binds/npu/htp-files/QnnGenAiTransformer.dll +0 -0
  46. nexaai/binds/npu/htp-files/QnnGenAiTransformerCpuOpPkg.dll +0 -0
  47. nexaai/binds/npu/htp-files/QnnGenAiTransformerModel.dll +0 -0
  48. nexaai/binds/npu/htp-files/QnnGpu.dll +0 -0
  49. nexaai/binds/npu/htp-files/QnnGpuNetRunExtensions.dll +0 -0
  50. nexaai/binds/npu/htp-files/QnnGpuProfilingReader.dll +0 -0
  51. nexaai/binds/npu/htp-files/QnnHtp.dll +0 -0
  52. nexaai/binds/npu/htp-files/QnnHtpNetRunExtensions.dll +0 -0
  53. nexaai/binds/npu/htp-files/QnnHtpOptraceProfilingReader.dll +0 -0
  54. nexaai/binds/npu/htp-files/QnnHtpPrepare.dll +0 -0
  55. nexaai/binds/npu/htp-files/QnnHtpProfilingReader.dll +0 -0
  56. nexaai/binds/npu/htp-files/QnnHtpV68CalculatorStub.dll +0 -0
  57. nexaai/binds/npu/htp-files/QnnHtpV68Stub.dll +0 -0
  58. nexaai/binds/npu/htp-files/QnnHtpV73CalculatorStub.dll +0 -0
  59. nexaai/binds/npu/htp-files/QnnHtpV73Stub.dll +0 -0
  60. nexaai/binds/npu/htp-files/QnnIr.dll +0 -0
  61. nexaai/binds/npu/htp-files/QnnJsonProfilingReader.dll +0 -0
  62. nexaai/binds/npu/htp-files/QnnModelDlc.dll +0 -0
  63. nexaai/binds/npu/htp-files/QnnSaver.dll +0 -0
  64. nexaai/binds/npu/htp-files/QnnSystem.dll +0 -0
  65. nexaai/binds/npu/htp-files/SNPE.dll +0 -0
  66. nexaai/binds/npu/htp-files/SnpeDspV66Stub.dll +0 -0
  67. nexaai/binds/npu/htp-files/SnpeHtpPrepare.dll +0 -0
  68. nexaai/binds/npu/htp-files/SnpeHtpV68Stub.dll +0 -0
  69. nexaai/binds/npu/htp-files/SnpeHtpV73Stub.dll +0 -0
  70. nexaai/binds/npu/htp-files/calculator.dll +0 -0
  71. nexaai/binds/npu/htp-files/calculator_htp.dll +0 -0
  72. nexaai/binds/npu/htp-files/libCalculator_skel.so +0 -0
  73. nexaai/binds/npu/htp-files/libQnnHtpV73.so +0 -0
  74. nexaai/binds/npu/htp-files/libQnnHtpV73QemuDriver.so +0 -0
  75. nexaai/binds/npu/htp-files/libQnnHtpV73Skel.so +0 -0
  76. nexaai/binds/npu/htp-files/libQnnSaver.so +0 -0
  77. nexaai/binds/npu/htp-files/libQnnSystem.so +0 -0
  78. nexaai/binds/npu/htp-files/libSnpeHtpV73Skel.so +0 -0
  79. nexaai/binds/npu/htp-files/libqnnhtpv73.cat +0 -0
  80. nexaai/binds/npu/htp-files/libsnpehtpv73.cat +0 -0
  81. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  82. nexaai/binds/npu/libcrypto-3-arm64.dll +0 -0
  83. nexaai/binds/npu/libmp3lame.DLL +0 -0
  84. nexaai/binds/npu/libomp140.aarch64.dll +0 -0
  85. nexaai/binds/npu/libssl-3-arm64.dll +0 -0
  86. nexaai/binds/npu/liquid-sdk.dll +0 -0
  87. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  88. nexaai/binds/npu/mpg123.dll +0 -0
  89. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  90. nexaai/binds/npu/nexa-sampling.dll +0 -0
  91. nexaai/binds/npu/nexa_plugin.dll +0 -0
  92. nexaai/binds/npu/nexaproc.dll +0 -0
  93. nexaai/binds/npu/ogg.dll +0 -0
  94. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  95. nexaai/binds/npu/openblas.dll +0 -0
  96. nexaai/binds/npu/opus.dll +0 -0
  97. nexaai/binds/npu/paddle-ocr-proc-lib.dll +0 -0
  98. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  99. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  100. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  101. nexaai/binds/npu/phi4-sdk.dll +0 -0
  102. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  103. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  104. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  105. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  106. nexaai/binds/npu/rtaudio.dll +0 -0
  107. nexaai/binds/npu/vorbis.dll +0 -0
  108. nexaai/binds/npu/vorbisenc.dll +0 -0
  109. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  110. nexaai/binds/npu/zlib1.dll +0 -0
  111. nexaai/binds/rerank_bind.cp311-win_arm64.pyd +0 -0
  112. nexaai/binds/vlm_bind.cp311-win_arm64.pyd +0 -0
  113. nexaai/common.py +105 -0
  114. nexaai/cv.py +93 -0
  115. nexaai/cv_impl/__init__.py +0 -0
  116. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  117. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  118. nexaai/embedder.py +73 -0
  119. nexaai/embedder_impl/__init__.py +0 -0
  120. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  121. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  122. nexaai/image_gen.py +141 -0
  123. nexaai/image_gen_impl/__init__.py +0 -0
  124. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  125. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  126. nexaai/llm.py +98 -0
  127. nexaai/llm_impl/__init__.py +0 -0
  128. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  129. nexaai/llm_impl/pybind_llm_impl.py +220 -0
  130. nexaai/log.py +92 -0
  131. nexaai/rerank.py +57 -0
  132. nexaai/rerank_impl/__init__.py +0 -0
  133. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  134. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  135. nexaai/runtime.py +68 -0
  136. nexaai/runtime_error.py +24 -0
  137. nexaai/tts.py +75 -0
  138. nexaai/tts_impl/__init__.py +0 -0
  139. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  140. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  141. nexaai/utils/decode.py +18 -0
  142. nexaai/utils/manifest_utils.py +531 -0
  143. nexaai/utils/model_manager.py +1562 -0
  144. nexaai/utils/model_types.py +49 -0
  145. nexaai/utils/progress_tracker.py +385 -0
  146. nexaai/utils/quantization_utils.py +245 -0
  147. nexaai/vlm.py +130 -0
  148. nexaai/vlm_impl/__init__.py +0 -0
  149. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  150. nexaai/vlm_impl/pybind_vlm_impl.py +256 -0
  151. nexaai-1.0.21.dist-info/METADATA +31 -0
  152. nexaai-1.0.21.dist-info/RECORD +154 -0
  153. nexaai-1.0.21.dist-info/WHEEL +5 -0
  154. nexaai-1.0.21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1562 @@
1
+ import os
2
+ import shutil
3
+ import json
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Callable, Dict, Any, List, Union
7
+ import functools
8
+ from enum import Enum
9
+ from tqdm.auto import tqdm
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
12
+
13
+ from .progress_tracker import CustomProgressTqdm, DownloadProgressTracker
14
+ from .manifest_utils import (
15
+ load_download_metadata,
16
+ save_download_metadata,
17
+ save_manifest_with_files_metadata,
18
+ )
19
+
20
# Default path for model storage.
# NOTE: stored un-expanded; callers run it through os.path.expanduser before use.
DEFAULT_MODEL_SAVING_PATH = "~/.cache/nexa.ai/nexa_sdk/models/"
22
+
23
+
24
@dataclass
class MMProjInfo:
    """Data class for mmproj file information."""
    # Full path to the mmproj file on disk (None when absent).
    mmproj_path: Optional[str] = None
    # File size in bytes (0 when unknown).
    size: int = 0
29
+
30
@dataclass
class DownloadedModel:
    """Data class representing a downloaded model with all its metadata."""
    # Repository ID, e.g. "owner/repo" (or a bare repo name for direct folders).
    repo_id: str
    # Relative paths of all files under local_path.
    files: List[str]
    folder_type: str  # 'owner_repo' or 'direct_repo'
    # Full path to the model directory on disk.
    local_path: str
    # Total size of all files in bytes.
    size_bytes: int
    # Number of files in the repository.
    file_count: int
    full_repo_download_complete: bool = True  # True if no incomplete downloads detected
    pipeline_tag: Optional[str] = None  # Pipeline tag from HuggingFace model info
    download_time: Optional[str] = None  # ISO format timestamp of download
    avatar_url: Optional[str] = None  # Avatar URL for the model author
    mmproj_info: Optional[MMProjInfo] = None  # mmproj file information

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format for backward compatibility.

        The nested MMProjInfo is flattened into a plain dict (or None when
        absent) so the result is directly JSON-serializable.
        """
        result = {
            'repo_id': self.repo_id,
            'files': self.files,
            'folder_type': self.folder_type,
            'local_path': self.local_path,
            'size_bytes': self.size_bytes,
            'file_count': self.file_count,
            'full_repo_download_complete': self.full_repo_download_complete,
            'pipeline_tag': self.pipeline_tag,
            'download_time': self.download_time,
            'avatar_url': self.avatar_url,
            'mmproj_info': {
                'mmproj_path': self.mmproj_info.mmproj_path,
                'size': self.mmproj_info.size
            } if self.mmproj_info else None
        }
        return result
64
+
65
+
66
+ ##########################################################################
67
+ # List downloaded models #
68
+ ##########################################################################
69
+
70
+
71
+ def _check_for_incomplete_downloads(directory_path: str) -> bool:
72
+ """
73
+ Check if there are incomplete downloads in the model directory.
74
+
75
+ This function checks for the presence of .incomplete or .lock files
76
+ in the .cache/huggingface/download directory within the model folder,
77
+ which indicates that the model download has not completed.
78
+
79
+ Args:
80
+ directory_path: Path to the model directory
81
+
82
+ Returns:
83
+ bool: True if download is complete (no incomplete files found),
84
+ False if incomplete downloads are detected
85
+ """
86
+ # Check for .cache/huggingface/download directory
87
+ cache_dir = os.path.join(directory_path, '.cache', 'huggingface', 'download')
88
+
89
+ # If the cache directory doesn't exist, assume download is complete
90
+ if not os.path.exists(cache_dir):
91
+ return True
92
+
93
+ try:
94
+ # Walk through the cache directory to find incomplete or lock files
95
+ for root, dirs, files in os.walk(cache_dir):
96
+ for filename in files:
97
+ # Check for .incomplete or .lock files
98
+ if filename.endswith('.incomplete'):
99
+ return False # Found incomplete download
100
+
101
+ # No incomplete files found
102
+ return True
103
+ except (OSError, IOError):
104
+ # If we can't access the directory, assume download is complete
105
+ return True
106
+
107
+ def _get_directory_size_and_files(directory_path: str) -> tuple[int, List[str]]:
108
+ """Get total size and list of files in a directory."""
109
+ total_size = 0
110
+ files = []
111
+
112
+ try:
113
+ for root, dirs, filenames in os.walk(directory_path):
114
+ for filename in filenames:
115
+ file_path = os.path.join(root, filename)
116
+ try:
117
+ file_size = os.path.getsize(file_path)
118
+ total_size += file_size
119
+ # Store relative path from the directory
120
+ rel_path = os.path.relpath(file_path, directory_path)
121
+ files.append(rel_path)
122
+ except (OSError, IOError):
123
+ # Skip files that can't be accessed
124
+ continue
125
+ except (OSError, IOError):
126
+ # Skip directories that can't be accessed
127
+ pass
128
+
129
+ return total_size, files
130
+
131
+
132
+ def _has_valid_metadata(directory_path: str) -> bool:
133
+ """Check if directory has either nexa.manifest or download_metadata.json (for backward compatibility)."""
134
+ manifest_path = os.path.join(directory_path, 'nexa.manifest')
135
+ old_metadata_path = os.path.join(directory_path, 'download_metadata.json')
136
+ return os.path.exists(manifest_path) or os.path.exists(old_metadata_path)
137
+
138
+
139
+ def _extract_mmproj_info(manifest: Dict[str, Any], local_path: str) -> Optional[MMProjInfo]:
140
+ """
141
+ Extract mmproj information from manifest data.
142
+
143
+ Args:
144
+ manifest: Dictionary containing manifest data
145
+ local_path: Local path to the model directory
146
+
147
+ Returns:
148
+ MMProjInfo object if mmproj file exists, None otherwise
149
+ """
150
+ # Check if manifest has MMProjFile information
151
+ mmproj_file_info = manifest.get('MMProjFile')
152
+ if not mmproj_file_info or not mmproj_file_info.get('Downloaded') or not mmproj_file_info.get('Name'):
153
+ return None
154
+
155
+ mmproj_filename = mmproj_file_info.get('Name', '')
156
+ if not mmproj_filename:
157
+ return None
158
+
159
+ # Construct full path to mmproj file
160
+ mmproj_path = os.path.join(local_path, mmproj_filename)
161
+
162
+ # Get size from manifest, but verify file exists
163
+ mmproj_size = mmproj_file_info.get('Size', 0)
164
+ if os.path.exists(mmproj_path):
165
+ try:
166
+ # Verify size matches actual file size
167
+ actual_size = os.path.getsize(mmproj_path)
168
+ mmproj_size = actual_size # Use actual size if different
169
+ except (OSError, IOError):
170
+ # If we can't get actual size, use size from manifest
171
+ pass
172
+ else:
173
+ # File doesn't exist, don't include mmproj info
174
+ return None
175
+
176
+ return MMProjInfo(mmproj_path=mmproj_path, size=mmproj_size)
177
+
178
+
179
def _scan_for_repo_folders(base_path: str) -> List[DownloadedModel]:
    """Scan a directory for repository folders and return model information.

    Two layouts are recognized:
      * owner/repo: a first-level folder containing subdirectories; each
        subdirectory with valid metadata becomes one model.
      * direct repo: a first-level folder containing only files.

    Folders without a nexa.manifest / download_metadata.json, or without any
    files, are silently skipped.
    """
    models = []

    try:
        if not os.path.exists(base_path):
            return models

        for item in os.listdir(base_path):
            item_path = os.path.join(base_path, item)

            # Skip non-directory items
            if not os.path.isdir(item_path):
                continue

            # Check if this might be an owner folder by looking for subdirectories
            has_subdirs = False
            direct_files = []

            try:
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        has_subdirs = True
                        # This looks like owner/repo structure
                        # Only include if nexa.manifest or download_metadata.json exists (backward compatibility)
                        if _has_valid_metadata(subitem_path):
                            size_bytes, files = _get_directory_size_and_files(subitem_path)
                            if files:  # Only include if there are files
                                # Check if the download is complete
                                download_complete = _check_for_incomplete_downloads(subitem_path)
                                # Load metadata if it exists
                                repo_id = f"{item}/{subitem}"
                                metadata = load_download_metadata(subitem_path, repo_id)

                                # Extract mmproj information
                                mmproj_info = _extract_mmproj_info(metadata, subitem_path)

                                models.append(DownloadedModel(
                                    repo_id=repo_id,
                                    files=files,
                                    folder_type='owner_repo',
                                    local_path=subitem_path,
                                    size_bytes=size_bytes,
                                    file_count=len(files),
                                    full_repo_download_complete=download_complete,
                                    pipeline_tag=metadata.get('pipeline_tag'),
                                    download_time=metadata.get('download_time'),
                                    avatar_url=metadata.get('avatar_url'),
                                    mmproj_info=mmproj_info
                                ))
                    else:
                        direct_files.append(subitem)
            except (OSError, IOError):
                # Skip directories that can't be accessed.
                # NOTE(review): this `continue` also skips the direct-repo
                # handling below for this item — confirm that is intended.
                continue

            # Direct repo folder (no owner structure)
            if not has_subdirs and direct_files:
                # Only include if nexa.manifest or download_metadata.json exists (backward compatibility)
                if _has_valid_metadata(item_path):
                    size_bytes, files = _get_directory_size_and_files(item_path)
                    if files:  # Only include if there are files
                        # Check if the download is complete
                        download_complete = _check_for_incomplete_downloads(item_path)
                        # Load metadata if it exists
                        repo_id = item
                        metadata = load_download_metadata(item_path, repo_id)

                        # Extract mmproj information
                        mmproj_info = _extract_mmproj_info(metadata, item_path)

                        models.append(DownloadedModel(
                            repo_id=repo_id,
                            files=files,
                            folder_type='direct_repo',
                            local_path=item_path,
                            size_bytes=size_bytes,
                            file_count=len(files),
                            full_repo_download_complete=download_complete,
                            pipeline_tag=metadata.get('pipeline_tag'),
                            download_time=metadata.get('download_time'),
                            avatar_url=metadata.get('avatar_url'),
                            mmproj_info=mmproj_info
                        ))

    except (OSError, IOError):
        # Skip if base path can't be accessed
        pass

    return models
270
+
271
+
272
def list_downloaded_models(local_dir: Optional[str] = None) -> List[DownloadedModel]:
    """
    List all downloaded models in the specified directory.

    Scans the local directory for downloaded models and reports each
    repository's files, total size, and folder structure. Both folder
    naming conventions are handled:
    - Owner/repo structure (e.g., "microsoft/DialoGPT-small")
    - Direct repo folders (repos without owner prefix)

    Args:
        local_dir (str, optional): Directory to scan for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        List[DownloadedModel]: One entry per repository, sorted by repo_id,
            carrying repo_id, files, folder_type, local_path, size_bytes,
            file_count, full_repo_download_complete, pipeline_tag,
            download_time, avatar_url and mmproj_info.
    """
    # Resolve the scan root, defaulting to the standard cache location.
    scan_root = (
        os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)
        if local_dir is None
        else local_dir
    )
    scan_root = os.path.abspath(scan_root)

    if not os.path.exists(scan_root):
        return []

    # Delegate the actual walking, then sort for deterministic output.
    discovered = _scan_for_repo_folders(scan_root)
    return sorted(discovered, key=lambda model: model.repo_id)
319
+
320
+
321
+ ##########################################################################
322
+ # Remove model functions #
323
+ ##########################################################################
324
+
325
+
326
+ def _parse_model_path(model_path: str) -> tuple[str, str | None]:
327
+ """
328
+ Parse model_path to extract repo_id and optional filename.
329
+
330
+ Examples:
331
+ "microsoft/DialoGPT-small" -> ("microsoft/DialoGPT-small", None)
332
+ "microsoft/DialoGPT-small/pytorch_model.bin" -> ("microsoft/DialoGPT-small", "pytorch_model.bin")
333
+ "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" -> ("Qwen/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf")
334
+
335
+ Args:
336
+ model_path: The model path string
337
+
338
+ Returns:
339
+ Tuple of (repo_id, filename) where filename can be None
340
+ """
341
+ parts = model_path.strip().split('/')
342
+
343
+ if len(parts) < 2:
344
+ # Invalid format, assume it's just a repo name without owner
345
+ return model_path, None
346
+ elif len(parts) == 2:
347
+ # Format: "owner/repo"
348
+ return model_path, None
349
+ else:
350
+ # Format: "owner/repo/file" or "owner/repo/subdir/file"
351
+ repo_id = f"{parts[0]}/{parts[1]}"
352
+ filename = '/'.join(parts[2:])
353
+ return repo_id, filename
354
+
355
+
356
+ def _validate_and_parse_input(model_path: str) -> tuple[str, Optional[str]]:
357
+ """Validate input and parse model path."""
358
+ if not model_path or not isinstance(model_path, str) or not model_path.strip():
359
+ raise ValueError("model_path is required and must be a non-empty string")
360
+
361
+ model_path = model_path.strip()
362
+ return _parse_model_path(model_path)
363
+
364
+
365
def _find_target_model(repo_id: str, local_dir: str) -> DownloadedModel:
    """Find and validate the target model exists."""
    known_models = list_downloaded_models(local_dir)

    # Return the first (and only) entry whose repo_id matches.
    match = next((m for m in known_models if m.repo_id == repo_id), None)
    if match is not None:
        return match

    available_repos = [m.repo_id for m in known_models]
    raise FileNotFoundError(
        f"Repository '{repo_id}' not found in downloaded models. "
        f"Available repositories: {available_repos}"
    )
378
+
379
+
380
+ def _clean_empty_owner_directory(target_model: DownloadedModel) -> None:
381
+ """Remove empty owner directory if applicable."""
382
+ if target_model.folder_type != 'owner_repo':
383
+ return
384
+
385
+ parent_dir = os.path.dirname(target_model.local_path)
386
+ try:
387
+ if os.path.exists(parent_dir) and not os.listdir(parent_dir):
388
+ os.rmdir(parent_dir)
389
+ except OSError:
390
+ pass
391
+
392
+
393
def _remove_specific_file(target_model: DownloadedModel, file_name: str, local_dir: str) -> DownloadedModel:
    """Remove a specific file from the repository.

    If removing a .gguf file would leave no non-mmproj .gguf files behind,
    the whole repository is deleted instead (the mmproj file alone is
    useless). Likewise, if the removal empties the repository, the
    directory itself is cleaned up.

    Args:
        target_model: The model the file belongs to.
        file_name: Relative path of the file inside the repository.
        local_dir: Base models directory (passed through to full-repo removal).

    Returns:
        DownloadedModel describing the repository state after removal.

    Raises:
        FileNotFoundError: If the file is not listed in the model or missing on disk.
        OSError: If the file cannot be deleted.
    """
    # Validate file exists in model
    if file_name not in target_model.files:
        raise FileNotFoundError(
            f"File '{file_name}' not found in repository '{target_model.repo_id}'. "
            f"Available files: {target_model.files[:10]}{'...' if len(target_model.files) > 10 else ''}"
        )

    # Construct full file path and validate it exists on disk
    file_path = os.path.join(target_model.local_path, file_name)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File does not exist on disk: {file_path}")

    # Get file size before removal (needed to adjust size_bytes afterwards)
    try:
        file_size = os.path.getsize(file_path)
    except OSError:
        file_size = 0

    # Check if we should remove entire folder instead (for .gguf files)
    # If removing a .gguf file and no other non-mmproj .gguf files remain, remove entire folder
    if file_name.endswith('.gguf'):
        updated_files = [f for f in target_model.files if f != file_name]
        # Find remaining .gguf files that don't contain "mmproj" in filename
        remaining_non_mmproj_gguf = [
            f for f in updated_files
            if f.endswith('.gguf') and 'mmproj' not in f.lower()
        ]

        # If no non-mmproj .gguf files remain, remove entire repository
        if len(remaining_non_mmproj_gguf) == 0:
            return _remove_entire_repository(target_model, local_dir)

    # Remove the file
    try:
        os.remove(file_path)
    except OSError as e:
        raise OSError(f"Failed to remove file '{file_path}': {e}")

    # Create updated model object.
    # NOTE(review): pipeline_tag / download_time / avatar_url / mmproj_info are
    # not carried over and fall back to their defaults — confirm intended.
    updated_files = [f for f in target_model.files if f != file_name]
    updated_size = target_model.size_bytes - file_size
    # Re-check download completeness after file removal
    download_complete = _check_for_incomplete_downloads(target_model.local_path)
    updated_model = DownloadedModel(
        repo_id=target_model.repo_id,
        files=updated_files,
        folder_type=target_model.folder_type,
        local_path=target_model.local_path,
        size_bytes=updated_size,
        file_count=len(updated_files),
        full_repo_download_complete=download_complete
    )

    # If no files left, remove the entire directory (best effort)
    if len(updated_files) == 0:
        try:
            shutil.rmtree(target_model.local_path)
            _clean_empty_owner_directory(target_model)
        except OSError:
            pass

    return updated_model
457
+
458
+
459
def _remove_entire_repository(target_model: DownloadedModel, local_dir: str) -> DownloadedModel:
    """Remove the entire repository and clean up."""
    repo_path = target_model.local_path

    # Delete the directory tree; surface a descriptive error on failure.
    try:
        shutil.rmtree(repo_path)
    except OSError as err:
        raise OSError(f"Failed to remove directory '{repo_path}': {err}")

    # Drop the now-possibly-empty owner folder as well.
    _clean_empty_owner_directory(target_model)

    return target_model
471
+
472
+
473
def remove_model_or_file(
    model_path: str,
    local_dir: Optional[str] = None
) -> DownloadedModel:
    """
    Remove a downloaded model or specific file by repository ID or file path.

    Two modes are supported:
    1. Remove entire repository: "microsoft/DialoGPT-small"
    2. Remove specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"

    Whole-repository removal deletes the directory with all its files;
    single-file removal deletes only that file and updates the repository
    metadata.

    Args:
        model_path (str): Required. Either a repository ID (removes the whole
            repo) or a "owner/repo/file" path (removes one file).
        local_dir (str, optional): Directory to search for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        DownloadedModel: What was removed from disk; for file removal, the
            updated model info after the file is gone.

    Raises:
        ValueError: If model_path is invalid (empty or None)
        FileNotFoundError: If the repository or file is not found in downloaded models
        OSError: If there's an error removing files from disk
    """
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Resolve the models directory, defaulting to the standard cache location.
    base_dir = os.path.abspath(
        os.path.expanduser(DEFAULT_MODEL_SAVING_PATH) if local_dir is None else local_dir
    )
    if not os.path.exists(base_dir):
        raise FileNotFoundError(f"Local directory does not exist: {base_dir}")

    target = _find_target_model(repo_id, base_dir)

    # Delegate: a file component means single-file removal, otherwise the
    # whole repository goes.
    if file_name:
        return _remove_specific_file(target, file_name, base_dir)
    return _remove_entire_repository(target, base_dir)
523
+
524
+
525
+ ##########################################################################
526
+ # Check model existence functions #
527
+ ##########################################################################
528
+
529
+
530
def check_model_existence(
    model_path: str,
    local_dir: Optional[str] = None
) -> bool:
    """
    Check if a downloaded model or specific file exists locally.

    Two modes are supported:
    1. Check entire repository: "microsoft/DialoGPT-small"
    2. Check specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"

    Args:
        model_path (str): Required. Either a repository ID (checks the whole
            repo) or a "owner/repo/file" path (checks one file).
        local_dir (str, optional): Directory to search for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        bool: True if the requested item exists, False otherwise

    Raises:
        ValueError: If model_path is invalid (empty or None)
    """
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Resolve the search directory, defaulting to the standard cache location.
    search_dir = os.path.abspath(
        os.path.expanduser(DEFAULT_MODEL_SAVING_PATH) if local_dir is None else local_dir
    )
    if not os.path.exists(search_dir):
        return False

    for model in list_downloaded_models(search_dir):
        if model.repo_id != repo_id:
            continue
        # Repository found; with no file requested its existence suffices,
        # otherwise the file must be listed in the repository.
        return True if file_name is None else file_name in model.files

    return False
581
+
582
+
583
+ ##########################################################################
584
+ # HuggingFace Downloader Class #
585
+ ##########################################################################
586
+
587
+
588
+ class HuggingFaceDownloader:
589
+ """Class to handle downloads from HuggingFace Hub with unified API usage."""
590
+
591
    def __init__(
        self,
        endpoint: Optional[str] = None,
        token: Union[bool, str, None] = None,
        enable_transfer: bool = True
    ):
        """
        Initialize the downloader with HuggingFace API.

        Args:
            endpoint: Custom endpoint URL (e.g., "https://hf-mirror.com").
                If None, uses default HuggingFace Hub.
            token: Authentication token for private repositories.
            enable_transfer: Whether to enable HF transfer for faster downloads.
        """
        # Always create an HfApi instance - either with custom endpoint or default.
        # NOTE(review): any non-string token (including True and None) collapses
        # to False, disabling authentication — confirm that token=True ("use the
        # locally stored credential") is intentionally unsupported.
        self.token = token if isinstance(token, str) else False  # False means disable authentication
        self.api = HfApi(endpoint=endpoint, token=self.token) if endpoint else HfApi(token=self.token)
        self.enable_transfer = enable_transfer
        # Placeholder for a saved hf-transfer setting; not populated here —
        # presumably set/restored by other methods, verify at the call sites.
        self.original_hf_transfer = None
        self.endpoint = endpoint  # Store endpoint for avatar fetching
        self._model_info_cache: Dict[str, Any] = {}  # Cache for model_info results
613
+
614
+ def _create_repo_directory(self, local_dir: str, repo_id: str) -> str:
615
+ """Create a directory structure for the repository following HF convention."""
616
+ if '/' in repo_id:
617
+ # Standard format: owner/repo
618
+ owner, repo = repo_id.split('/', 1)
619
+ repo_dir = os.path.join(local_dir, owner, repo)
620
+ else:
621
+ # Direct repo name without owner
622
+ repo_dir = os.path.join(local_dir, repo_id)
623
+
624
+ os.makedirs(repo_dir, exist_ok=True)
625
+ return repo_dir
626
+
627
+ def _created_dir_if_not_exists(self, local_dir: Optional[str]) -> str:
628
+ """Create directory if it doesn't exist and return the expanded path."""
629
+ if local_dir is None:
630
+ local_dir = DEFAULT_MODEL_SAVING_PATH
631
+
632
+ local_dir = os.path.expanduser(local_dir)
633
+ os.makedirs(local_dir, exist_ok=True)
634
+ return local_dir
635
+
636
+ def _get_model_info_cached(self, repo_id: str, files_metadata: bool = False):
637
+ """Get model info with caching to avoid rate limiting.
638
+
639
+ Args:
640
+ repo_id: Repository ID
641
+ files_metadata: Whether to include files metadata
642
+
643
+ Returns:
644
+ Model info object from HuggingFace API
645
+ """
646
+ # Create cache key based on repo_id and files_metadata flag
647
+ cache_key = f"{repo_id}:files={files_metadata}"
648
+
649
+ # Return cached result if available
650
+ if cache_key in self._model_info_cache:
651
+ return self._model_info_cache[cache_key]
652
+
653
+ # Fetch from API and cache the result
654
+ try:
655
+ info = self.api.model_info(repo_id, files_metadata=files_metadata, token=self.token)
656
+ self._model_info_cache[cache_key] = info
657
+ return info
658
+ except Exception:
659
+ # Don't cache errors, re-raise
660
+ raise
661
+
662
+ def _get_repo_info_for_progress(
663
+ self,
664
+ repo_id: str,
665
+ file_name: Optional[Union[str, List[str]]] = None
666
+ ) -> tuple[int, int]:
667
+ """Get total repository size and file count for progress tracking."""
668
+ try:
669
+ info = self._get_model_info_cached(repo_id, files_metadata=True)
670
+
671
+ total_size = 0
672
+ file_count = 0
673
+
674
+ if info.siblings:
675
+ for sibling in info.siblings:
676
+ # Handle different file_name types
677
+ if file_name is not None:
678
+ if isinstance(file_name, str):
679
+ # Single file - only count if it matches
680
+ if sibling.rfilename != file_name:
681
+ continue
682
+ elif isinstance(file_name, list):
683
+ # Multiple files - only count if in the list
684
+ if sibling.rfilename not in file_name:
685
+ continue
686
+
687
+ # For all matching files (or all files if file_name is None)
688
+ if hasattr(sibling, 'size') and sibling.size is not None:
689
+ total_size += sibling.size
690
+ file_count += 1
691
+ else:
692
+ # Count files without size info
693
+ file_count += 1
694
+
695
+ return total_size, file_count if file_count > 0 else 1
696
+ except Exception:
697
+ # If we can't get info, return defaults
698
+ return 0, 1
699
+
700
+ def _validate_and_setup_params(
701
+ self,
702
+ repo_id: str,
703
+ file_name: Optional[Union[str, List[str]]]
704
+ ) -> tuple[str, Optional[Union[str, List[str]]]]:
705
+ """Validate and normalize input parameters."""
706
+ if not repo_id:
707
+ raise ValueError("repo_id is required")
708
+
709
+ repo_id = repo_id.strip()
710
+
711
+ # Handle file_name parameter
712
+ if file_name is not None:
713
+ if isinstance(file_name, str):
714
+ file_name = file_name.strip()
715
+ if not file_name:
716
+ file_name = None
717
+ elif isinstance(file_name, list):
718
+ # Filter out empty strings and strip whitespace
719
+ file_name = [f.strip() for f in file_name if f and f.strip()]
720
+ if not file_name:
721
+ file_name = None
722
+ else:
723
+ raise ValueError("file_name must be a string, list of strings, or None")
724
+
725
+ return repo_id, file_name
726
+
727
def _setup_progress_tracker(
    self,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]],
    show_progress: bool,
    repo_id: str,
    file_name: Optional[Union[str, List[str]]]
) -> Optional[DownloadProgressTracker]:
    """Create a DownloadProgressTracker primed with repo size and file count.

    Returns None when no callback was supplied, in which case no
    tracking happens at all.
    """
    if not progress_callback:
        return None
    tracker = DownloadProgressTracker(progress_callback, show_progress)
    # Seed the tracker with total bytes / file count so percentages work.
    tracker.set_repo_info(*self._get_repo_info_for_progress(repo_id, file_name))
    return tracker
743
+
744
+ def _setup_hf_transfer_env(self) -> None:
745
+ """Set up HF transfer environment."""
746
+ self.original_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
747
+ if self.enable_transfer:
748
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
749
+
750
+ def _cleanup_hf_transfer_env(self) -> None:
751
+ """Restore original HF transfer environment."""
752
+ if self.original_hf_transfer is not None:
753
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = self.original_hf_transfer
754
+ else:
755
+ os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
756
+
757
def _validate_repository_and_get_info(
    self,
    repo_id: str,
    progress_tracker: Optional[DownloadProgressTracker]
):
    """Validate that the repository exists and return its (cached) model info.

    A 404 — whether surfaced as RepositoryNotFoundError or as an
    HfHubHTTPError with status 404 — is reported to *progress_tracker*
    and re-raised as RepositoryNotFoundError with a friendlier message;
    any other HTTP error is re-raised as HfHubHTTPError.
    """
    try:
        info = self._get_model_info_cached(repo_id, files_metadata=False)
        return info
    except RepositoryNotFoundError:
        error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
        if progress_tracker:
            progress_tracker.set_error(error_msg)
        # NOTE(review): raising a fresh exception drops the original
        # traceback chain; consider `raise ... from e` if the cause matters.
        raise RepositoryNotFoundError(error_msg)
    except HfHubHTTPError as e:
        # assumes e.response is always populated — TODO confirm for
        # connection-level failures where no HTTP response exists
        if e.response.status_code == 404:
            error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
            if progress_tracker:
                progress_tracker.set_error(error_msg)
            raise RepositoryNotFoundError(error_msg)
        else:
            error_msg = f"HTTP error while accessing repository '{repo_id}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
            raise HfHubHTTPError(error_msg)
782
+
783
def _validate_file_exists_in_repo(
    self,
    file_name: str,
    info,
    repo_id: str,
    progress_tracker: Optional[DownloadProgressTracker]
) -> None:
    """Raise ValueError unless *file_name* appears among the repo's siblings.

    On failure the error is first reported to *progress_tracker*
    (set_error + stop_tracking) and the message lists up to ten of the
    files that ARE available.
    """
    siblings = info.siblings or []
    if any(s.rfilename == file_name for s in siblings):
        return

    available_files = [s.rfilename for s in siblings]
    ellipsis = '...' if len(available_files) > 10 else ''
    error_msg = (
        f"File '{file_name}' not found in repository '{repo_id}'. "
        f"Available files: {available_files[:10]}{ellipsis}"
    )
    if progress_tracker:
        progress_tracker.set_error(error_msg)
        progress_tracker.stop_tracking()
    raise ValueError(error_msg)
808
+
809
+ def _check_file_exists_and_valid(
810
+ self,
811
+ file_path: str,
812
+ expected_size: Optional[int] = None
813
+ ) -> bool:
814
+ """Check if a file exists and is valid (non-empty, correct size if known)."""
815
+ if not os.path.exists(file_path):
816
+ return False
817
+
818
+ # Check file is not empty
819
+ try:
820
+ file_size = os.path.getsize(file_path)
821
+ if file_size == 0:
822
+ return False
823
+ except (OSError, IOError):
824
+ return False
825
+
826
+ # If we have expected size, check it matches
827
+ if expected_size is not None and file_size != expected_size:
828
+ return False
829
+
830
+ # If no expected size, just check that file is not empty
831
+ return os.path.getsize(file_path) > 0
832
+
833
+ def _extract_model_file_type_from_tags(self, repo_id: str) -> Optional[str]:
834
+ """Extract model file type from repo tags with priority: NPU > MLX > GGUF."""
835
+ try:
836
+ info = self._get_model_info_cached(repo_id, files_metadata=False)
837
+ if hasattr(info, 'tags') and info.tags:
838
+ # Convert tags to lowercase for case-insensitive matching
839
+ tags_lower = [tag.lower() for tag in info.tags]
840
+
841
+ # Check with priority: NPU > MLX > GGUF
842
+ if 'npu' in tags_lower:
843
+ return 'npu'
844
+ elif 'mlx' in tags_lower:
845
+ return 'mlx'
846
+ elif 'gguf' in tags_lower:
847
+ return 'gguf'
848
+ except Exception:
849
+ pass
850
+ return None
851
+
852
+ def _load_downloaded_manifest(self, local_dir: str) -> Dict[str, Any]:
853
+ """Load nexa.manifest from the downloaded repository if it exists."""
854
+ manifest_path = os.path.join(local_dir, 'nexa.manifest')
855
+ if os.path.exists(manifest_path):
856
+ try:
857
+ with open(manifest_path, 'r', encoding='utf-8') as f:
858
+ return json.load(f)
859
+ except (json.JSONDecodeError, IOError):
860
+ pass
861
+ return {}
862
+
863
+ def _download_manifest_if_needed(self, repo_id: str, local_dir: str) -> bool:
864
+ """
865
+ Download nexa.manifest from the repository if it doesn't exist locally.
866
+
867
+ Args:
868
+ repo_id: Repository ID
869
+ local_dir: Local directory where the manifest should be saved
870
+
871
+ Returns:
872
+ bool: True if manifest was downloaded or already exists, False if not found in repo
873
+ """
874
+ manifest_path = os.path.join(local_dir, 'nexa.manifest')
875
+
876
+ # Check if manifest already exists locally
877
+ if os.path.exists(manifest_path):
878
+ return True
879
+
880
+ # Try to download nexa.manifest from the repository
881
+ try:
882
+ print(f"[INFO] Attempting to download nexa.manifest from {repo_id}...")
883
+ self.api.hf_hub_download(
884
+ repo_id=repo_id,
885
+ filename='nexa.manifest',
886
+ local_dir=local_dir,
887
+ local_dir_use_symlinks=False,
888
+ token=self.token,
889
+ force_download=False
890
+ )
891
+ print(f"[OK] Successfully downloaded nexa.manifest from {repo_id}")
892
+ return True
893
+ except Exception as e:
894
+ # Manifest doesn't exist in repo or other error - this is fine, we'll create it
895
+ print(f"[INFO] nexa.manifest not found in {repo_id}, will create locally")
896
+ return False
897
+
898
def _fetch_and_save_metadata(self, repo_id: str, local_dir: str, is_mmproj: bool = False, file_name: Optional[Union[str, List[str]]] = None, **kwargs) -> None:
    """Fetch model info and save metadata after successful download.

    Assembles a metadata dict (pipeline tag, download timestamp, avatar
    URL, model file type, and any manifest shipped inside the repo) and
    writes it via save_manifest_with_files_metadata.  If that fails, a
    minimal manifest is written as a fallback.  This method never
    raises: every failure path prints a message and continues.
    """
    # Initialize metadata with defaults to ensure manifest is always created
    old_metadata = {
        'pipeline_tag': "text-generation",  # Default to text-generation pipeline-tag
        'download_time': datetime.now().isoformat(),
        'avatar_url': None
    }

    # Try to fetch additional metadata, but don't let failures prevent manifest creation
    try:
        # Fetch model info to get pipeline_tag (using cache)
        info = self._get_model_info_cached(repo_id, files_metadata=False)
        if hasattr(info, 'pipeline_tag') and info.pipeline_tag:
            old_metadata['pipeline_tag'] = info.pipeline_tag
    except Exception as e:
        # Log the error but continue with manifest creation
        print(f"Warning: Could not fetch model info for {repo_id}: {e}")

    # Use the caller-supplied avatar URL if provided (overwrites the default)
    old_metadata['avatar_url'] = kwargs.get('avatar_url')

    # Extract model file type from tags (priority: npu > mlx > gguf)
    model_file_type = self._extract_model_file_type_from_tags(repo_id)
    if model_file_type:
        old_metadata['model_file_type'] = model_file_type

    # Load existing nexa.manifest from downloaded repo (if exists)
    downloaded_manifest = self._load_downloaded_manifest(local_dir)
    if downloaded_manifest:
        old_metadata['downloaded_manifest'] = downloaded_manifest

    # CRITICAL: Always create the manifest file, regardless of metadata fetch failures
    try:
        save_manifest_with_files_metadata(repo_id, local_dir, old_metadata, is_mmproj, file_name, **kwargs)
        print(f"[OK] Successfully created nexa.manifest for {repo_id}")
    except Exception as e:
        # This is critical - if manifest creation fails, we should know about it
        print(f"ERROR: Failed to create nexa.manifest for {repo_id}: {e}")
        # Fallback: write a minimal manifest so downstream code can still load the model
        try:
            minimal_manifest = {
                "Name": repo_id,
                "ModelName": kwargs.get('model_name', ''),
                "ModelType": kwargs.get('model_type', 'other'),
                "PluginId": kwargs.get('plugin_id', 'unknown'),
                "DeviceId": kwargs.get('device_id', ''),
                "MinSDKVersion": kwargs.get('min_sdk_version', ''),
                "ModelFile": {},
                "MMProjFile": {"Name": "", "Downloaded": False, "Size": 0},
                "TokenizerFile": {"Name": "", "Downloaded": False, "Size": 0},
                "ExtraFiles": None,
                "pipeline_tag": old_metadata.get('pipeline_tag'),
                "download_time": old_metadata.get('download_time'),
                "avatar_url": old_metadata.get('avatar_url')
            }
            save_download_metadata(local_dir, minimal_manifest)
            print(f"[OK] Created minimal nexa.manifest for {repo_id} as fallback")
        except Exception as fallback_error:
            print(f"CRITICAL ERROR: Could not create even minimal manifest for {repo_id}: {fallback_error}")
959
+
960
def _download_single_file(
    self,
    repo_id: str,
    file_name: str,
    local_dir: str,
    progress_tracker: Optional[DownloadProgressTracker],
    force_download: bool = False,
    **kwargs
) -> str:
    """Download a single file from the repository using HuggingFace Hub API.

    Skips the download when a valid local copy already exists (unless
    *force_download*).  After a successful download the repo manifest is
    fetched/created.  Reads self._current_is_mmproj and
    self._current_file_name, which download() sets beforehand.

    Returns:
        Path to the local file.

    Raises:
        ValueError: the file is missing in the repository (HTTP 404).
        HfHubHTTPError: any other HTTP failure.
    """
    # Create repo-specific directory for the single file
    file_local_dir = self._create_repo_directory(local_dir, repo_id)

    # Check if file already exists
    local_file_path = os.path.join(file_local_dir, file_name)
    if not force_download and self._check_file_exists_and_valid(local_file_path):
        print(f"[SKIP] File already exists: {file_name}")
        # Stop progress tracking
        if progress_tracker:
            progress_tracker.stop_tracking()
        return local_file_path

    try:
        # Note: hf_hub_download doesn't support tqdm_class parameter
        # Progress tracking works through the global tqdm monkey patching
        downloaded_path = self.api.hf_hub_download(
            repo_id=repo_id,
            filename=file_name,
            local_dir=file_local_dir,
            local_dir_use_symlinks=False,
            token=self.token,
            force_download=force_download
        )

        # Stop progress tracking
        if progress_tracker:
            progress_tracker.stop_tracking()

        # Download nexa.manifest from repo if it doesn't exist locally
        self._download_manifest_if_needed(repo_id, file_local_dir)

        # Save metadata after successful download
        self._fetch_and_save_metadata(repo_id, file_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

        return downloaded_path

    except HfHubHTTPError as e:
        error_msg = f"Error downloading file '{file_name}': {e}"
        if progress_tracker:
            progress_tracker.set_error(error_msg)
            progress_tracker.stop_tracking()
        # assumes e.response is populated — TODO confirm for non-HTTP failures
        if e.response.status_code == 404:
            raise ValueError(f"File '{file_name}' not found in repository '{repo_id}'")
        else:
            raise HfHubHTTPError(error_msg)
1015
+
1016
def _download_entire_repository(
    self,
    repo_id: str,
    local_dir: str,
    progress_tracker: Optional[DownloadProgressTracker],
    force_download: bool = False,
    **kwargs
) -> str:
    """Download the entire repository via snapshot_download.

    Reads self._current_is_mmproj and self._current_file_name (set by
    download()) when writing post-download metadata.

    Returns:
        Path to the local snapshot directory.

    Raises:
        HfHubHTTPError: on any HTTP failure (tracker is flagged first).
    """
    # Create a subdirectory for this specific repo
    repo_local_dir = self._create_repo_directory(local_dir, repo_id)

    try:
        download_kwargs = {
            'repo_id': repo_id,
            'local_dir': repo_local_dir,
            'local_dir_use_symlinks': False,
            'token': self.token,
            'force_download': force_download
        }

        # Add tqdm_class if progress tracking is enabled
        # (snapshot_download, unlike hf_hub_download, supports it)
        if progress_tracker:
            download_kwargs['tqdm_class'] = CustomProgressTqdm

        downloaded_path = self.api.snapshot_download(**download_kwargs)

        # Stop progress tracking
        if progress_tracker:
            progress_tracker.stop_tracking()

        # Save metadata after successful download
        self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

        return downloaded_path

    except HfHubHTTPError as e:
        error_msg = f"Error downloading repository '{repo_id}': {e}"
        if progress_tracker:
            progress_tracker.set_error(error_msg)
            progress_tracker.stop_tracking()
        raise HfHubHTTPError(error_msg)
1058
+
1059
def _download_multiple_files_from_hf(
    self,
    repo_id: str,
    file_names: List[str],
    local_dir: str,
    progress_tracker: Optional[DownloadProgressTracker],
    force_download: bool = False,
    **kwargs
) -> str:
    """Download multiple specific files from HuggingFace Hub.

    Files are fetched sequentially with a per-file skip check; an
    overall tqdm bar counts files (one tick per file, skipped or not).
    Reads self._current_is_mmproj and self._current_file_name (set by
    download()) for the metadata step.

    Returns:
        Path to the repo-local directory containing the files.

    Raises:
        HfHubHTTPError: on an HTTP failure (tracker flagged first).
        Exception: any other error is flagged on the tracker and re-raised.
    """
    # Create repo-specific directory
    repo_local_dir = self._create_repo_directory(local_dir, repo_id)

    # Create overall progress bar for multiple files
    overall_progress = tqdm(
        total=len(file_names),
        unit='file',
        desc=f"Downloading {len(file_names)} files from {repo_id}",
        position=0,
        leave=True
    )

    try:
        for file_name in file_names:
            overall_progress.set_postfix_str(f"Current: {os.path.basename(file_name)}")

            # Check if file already exists
            local_file_path = os.path.join(repo_local_dir, file_name)
            if not force_download and self._check_file_exists_and_valid(local_file_path):
                print(f"[SKIP] File already exists: {file_name}")
                overall_progress.update(1)
                continue

            # Download each file using hf_hub_download
            self.api.hf_hub_download(
                repo_id=repo_id,
                filename=file_name,
                local_dir=repo_local_dir,
                local_dir_use_symlinks=False,
                token=self.token,
                force_download=force_download
            )

            overall_progress.update(1)

        overall_progress.close()

        # Stop progress tracking
        if progress_tracker:
            progress_tracker.stop_tracking()

        # Download nexa.manifest from repo if it doesn't exist locally
        self._download_manifest_if_needed(repo_id, repo_local_dir)

        # Save metadata after successful download
        self._fetch_and_save_metadata(repo_id, repo_local_dir, self._current_is_mmproj, self._current_file_name, **kwargs)

        return repo_local_dir

    except HfHubHTTPError as e:
        overall_progress.close()
        error_msg = f"Error downloading files from '{repo_id}': {e}"
        if progress_tracker:
            progress_tracker.set_error(error_msg)
            progress_tracker.stop_tracking()
        raise HfHubHTTPError(error_msg)
    except Exception as e:
        overall_progress.close()
        if progress_tracker:
            progress_tracker.set_error(str(e))
            progress_tracker.stop_tracking()
        raise
1131
+
1132
def download(
    self,
    repo_id: str,
    file_name: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    show_progress: bool = True,
    force_download: bool = False,
    is_mmproj: bool = False,
    **kwargs
) -> str:
    """
    Main download method that handles all download scenarios.

    Dispatches to entire-repo, single-file, or multi-file strategies
    depending on *file_name*.  The HF transfer environment variable is
    set up before and restored after the download, even on failure.

    Args:
        repo_id: Repository ID to download from.
        file_name: Optional file name(s) to download (None = whole repo).
        local_dir: Local directory to save files.
        progress_callback: Callback for progress updates.
        show_progress: Whether to show a progress bar.
        force_download: Force re-download even if files exist.
        is_mmproj: Whether the target file is an mmproj file.

    Returns:
        Path to the downloaded file or directory.
    """
    # Validate and normalize parameters
    repo_id, file_name = self._validate_and_setup_params(repo_id, file_name)

    # Store parameters as instance variables for use in _fetch_and_save_metadata
    # (side channel read by the _download_* helpers)
    self._current_is_mmproj = is_mmproj
    self._current_file_name = file_name

    # Set up local directory
    local_dir = self._created_dir_if_not_exists(local_dir)

    # Set up progress tracker (size estimate uses only single-file names)
    file_name_for_progress = file_name if isinstance(file_name, str) else None
    progress_tracker = self._setup_progress_tracker(
        progress_callback, show_progress, repo_id, file_name_for_progress
    )

    # Set up HF transfer environment
    self._setup_hf_transfer_env()

    try:
        # Validate repository and get info
        info = self._validate_repository_and_get_info(repo_id, progress_tracker)

        # Start progress tracking
        if progress_tracker:
            progress_tracker.start_tracking()

        # Choose download strategy based on file_name
        if file_name is None:
            # Download entire repository
            return self._download_entire_repository(
                repo_id, local_dir, progress_tracker, force_download, **kwargs
            )
        elif isinstance(file_name, str):
            # Download specific single file
            self._validate_file_exists_in_repo(file_name, info, repo_id, progress_tracker)
            return self._download_single_file(
                repo_id, file_name, local_dir, progress_tracker, force_download, **kwargs
            )
        else:  # file_name is a list
            # Download multiple specific files; validate all of them first
            for fname in file_name:
                self._validate_file_exists_in_repo(fname, info, repo_id, progress_tracker)

            return self._download_multiple_files_from_hf(
                repo_id, file_name, local_dir, progress_tracker, force_download, **kwargs
            )

    except Exception as e:
        # Flag any error not already reported by a helper, then re-raise
        if progress_tracker and progress_tracker.download_status != "error":
            progress_tracker.set_error(str(e))
            progress_tracker.stop_tracking()
        raise

    finally:
        # Restore original HF transfer setting
        self._cleanup_hf_transfer_env()
1216
+
1217
+
1218
+ ##########################################################################
1219
+ # Public Download Function #
1220
+ ##########################################################################
1221
+
1222
+
1223
def download_from_huggingface(
    repo_id: str,
    file_name: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    enable_transfer: bool = True,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    show_progress: bool = True,
    token: Union[bool, str, None] = None,
    custom_endpoint: Optional[str] = None,
    force_download: bool = False,
    is_mmproj: Optional[bool] = None,
    **kwargs
) -> str:
    """
    Download models or files from HuggingFace Hub or a compatible mirror.

    Args:
        repo_id: Required repository ID (e.g. "microsoft/DialoGPT-medium").
        file_name: Single filename, list of filenames, or None to fetch
            the entire repository.
        local_dir: Destination directory; DEFAULT_MODEL_SAVING_PATH when None.
        enable_transfer: Enable hf_transfer acceleration. Default True.
        progress_callback: Receives progress dicts with 'status',
            'progress' (total_downloaded / total_size / percentage /
            files_active / files_total / known_total), 'speed',
            'formatting', and 'timing' sections; status is one of
            'idle', 'downloading', 'completed', 'error'.
        show_progress: Show a terminal progress bar (effective only
            together with progress_callback). Default True.
        token: True to read the token from the HuggingFace config
            folder, a string to use directly, or None for defaults.
        custom_endpoint: Base URL of a HuggingFace-compatible endpoint,
            without any path component (e.g. "https://hf-mirror.com").
        force_download: Re-download files even when present locally.
        is_mmproj: Whether the requested file is an mmproj file; when
            None it is inferred from 'mmproj' appearing in any filename.
        **kwargs: Manifest overrides forwarded downstream — plugin_id,
            model_name, model_type, device_id, min_sdk_version.

    Returns:
        str: Path to the downloaded file or directory.

    Raises:
        ValueError: Invalid repo_id or requested file absent from the repo.
        RepositoryNotFoundError: The repository does not exist.
        HfHubHTTPError: HTTP error during download.
    """
    # Infer is_mmproj from the filename(s) unless the caller decided it.
    if is_mmproj is None:
        if file_name is None:
            is_mmproj = False
        else:
            candidates = file_name if isinstance(file_name, list) else [file_name]
            is_mmproj = any('mmproj' in candidate.lower() for candidate in candidates)

    # One downloader per call, honoring any custom mirror endpoint.
    downloader = HuggingFaceDownloader(
        endpoint=custom_endpoint,
        token=token,
        enable_transfer=enable_transfer
    )

    return downloader.download(
        repo_id=repo_id,
        file_name=file_name,
        local_dir=local_dir,
        progress_callback=progress_callback,
        show_progress=show_progress,
        force_download=force_download,
        is_mmproj=is_mmproj,
        **kwargs
    )
1333
+
1334
+
1335
+ ##########################################################################
1336
+ # Auto-download decorator #
1337
+ ##########################################################################
1338
+
1339
+
1340
def _download_model_if_needed(
    model_path: str,
    param_name: str,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    token: Union[bool, str, None] = None,
    is_mmproj: bool = False,
    **kwargs
) -> tuple[str, Optional[str], Optional[str]]:
    """
    Helper function to download a model from HuggingFace if it doesn't exist locally.

    Args:
        model_path: The model path that may be local or remote
        param_name: Name of the parameter (for error messages)
        progress_callback: Callback function for download progress updates
        token: HuggingFace authentication token for private repositories
        is_mmproj: Whether the target file is an mmproj file

    Returns:
        tuple[str, Optional[str], Optional[str]]: Tuple of (local_path, model_name, plugin_id)
        - local_path: Local path to the model (either existing or downloaded)
        - model_name: ModelName from nexa.manifest if available, None otherwise
        - plugin_id: PluginId from nexa.manifest if available, None otherwise

    Raises:
        RuntimeError: If download fails
    """
    # Helper function to extract model info from manifest
    def _extract_info_from_manifest(path: str) -> tuple[Optional[str], Optional[str], Optional[dict]]:
        """Extract ModelName, PluginId, and full manifest from nexa.manifest if it exists."""
        # If path is a file, check its parent directory for manifest
        if os.path.isfile(path):
            manifest_dir = os.path.dirname(path)
        else:
            manifest_dir = path

        manifest_path = os.path.join(manifest_dir, 'nexa.manifest')
        if not os.path.exists(manifest_path):
            return None, None, None

        try:
            with open(manifest_path, 'r', encoding='utf-8') as f:
                manifest = json.load(f)
            return manifest.get('ModelName'), manifest.get('PluginId'), manifest
        except (json.JSONDecodeError, IOError):
            # Unreadable/corrupt manifest is treated the same as a missing one
            return None, None, None

    # Helper function to get a model file path from manifest
    # Note: This is for NPU only, because when downloading, it is a directory; when passing local path to inference, it needs to be a file.
    def _get_model_file_from_manifest(manifest: dict, base_dir: str) -> Optional[str]:
        """Extract a model file path from manifest's ModelFile section."""
        if not manifest or 'ModelFile' not in manifest:
            return None

        model_files = manifest['ModelFile']
        # Find the first valid model file (skip N/A entries and metadata files)
        for key, file_info in model_files.items():
            if key == 'N/A':
                continue
            if isinstance(file_info, dict) and 'Name' in file_info:
                file_name = file_info['Name']
                # Skip common non-model files; only *.nexa files qualify
                if file_name and not file_name.startswith('.') and file_name.endswith('.nexa'):
                    file_path = os.path.join(base_dir, file_name)
                    if os.path.exists(file_path):
                        return file_path

        # If no .nexa files found, try ExtraFiles for .nexa files
        if 'ExtraFiles' in manifest:
            for file_info in manifest['ExtraFiles']:
                if isinstance(file_info, dict) and 'Name' in file_info:
                    file_name = file_info['Name']
                    if file_name and file_name.endswith('.nexa') and not file_name.startswith('.cache'):
                        file_path = os.path.join(base_dir, file_name)
                        if os.path.exists(file_path):
                            return file_path

        return None

    # Check if model_path exists locally (file or directory)
    if os.path.exists(model_path):
        # Local path exists, try to extract model info
        model_name, plugin_id, manifest = _extract_info_from_manifest(model_path)

        # If PluginId is "npu" and path is a directory, convert to file path
        # (NPU inference needs a concrete .nexa file, not a directory)
        if plugin_id == "npu" and os.path.isdir(model_path):
            model_file_path = _get_model_file_from_manifest(manifest, model_path)
            if model_file_path:
                model_path = model_file_path

        return model_path, model_name, plugin_id

    # Model path doesn't exist locally, try to download from HuggingFace
    try:
        # Parse model_path to extract repo_id and filename
        repo_id, file_name = _parse_model_path(model_path)

        # Download the model
        downloaded_path = download_from_huggingface(
            repo_id=repo_id,
            file_name=file_name,
            local_dir=None,  # Use default cache directory
            enable_transfer=True,
            progress_callback=progress_callback,
            show_progress=True,
            token=token,
            is_mmproj=is_mmproj,
            **kwargs
        )

        # Extract model info from the downloaded manifest
        model_name, plugin_id, manifest = _extract_info_from_manifest(downloaded_path)

        # If PluginId is "npu" and path is a directory, convert to file path
        if plugin_id == "npu" and os.path.isdir(downloaded_path):
            model_file_path = _get_model_file_from_manifest(manifest, downloaded_path)
            if model_file_path:
                downloaded_path = model_file_path

        return downloaded_path, model_name, plugin_id

    except Exception as e:
        # Only handle download-related errors
        raise RuntimeError(f"Could not load model from '{param_name}={model_path}': {e}")
1463
+
1464
+
1465
def auto_download_model(func: Callable) -> Callable:
    """
    Decorator that automatically downloads models from HuggingFace if they don't exist locally.

    Apply to ``__init__`` methods that take a ``name_or_path`` parameter and
    optionally an ``mmproj_path`` parameter. If these paths don't exist as
    local files/directories, they are resolved (and downloaded if necessary)
    via ``_download_model_if_needed`` before the wrapped function runs.

    ``name_or_path`` / ``mmproj_path`` may be in formats like:
    - "microsoft/DialoGPT-small" (downloads entire repo)
    - "microsoft/DialoGPT-small/pytorch_model.bin" (downloads specific file)
    - "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" (downloads specific file)

    Optional kwargs consumed here (removed before calling the wrapped function)
    and forwarded to the download machinery:
    - progress_callback: Callback function for download progress updates
    - token: HuggingFace authentication token for private repositories

    Args:
        func: The ``__init__`` method to wrap.

    Returns:
        Wrapped function that handles automatic model downloading.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Pop download-only options so they are not forwarded to the wrapped
        # __init__, which does not accept them.
        progress_callback = kwargs.pop('progress_callback', None)
        token = kwargs.pop('token', None)

        # Locate name_or_path: by convention it is the first positional
        # argument after self; otherwise look for it as a keyword.
        name_or_path = None
        name_path_index = None
        is_name_positional = False
        if len(args) >= 2:
            name_or_path = args[1]
            name_path_index = 1
            is_name_positional = True
        elif 'name_or_path' in kwargs:
            name_or_path = kwargs['name_or_path']

        mmproj_path = kwargs.get('mmproj_path')

        # Neither parameter present -- nothing to resolve, call through.
        if name_or_path is None and mmproj_path is None:
            return func(*args, **kwargs)

        # Resolve/download the main model path if needed. Errors from
        # _download_model_if_needed propagate unchanged to the caller.
        if name_or_path is not None:
            downloaded_name_path, model_name, _plugin_id = _download_model_if_needed(
                name_or_path, 'name_or_path', progress_callback, token, **kwargs
            )

            # Substitute the resolved local path for the original reference.
            if is_name_positional:
                args = (
                    args[:name_path_index]
                    + (downloaded_name_path,)
                    + args[name_path_index + 1:]
                )
            else:
                kwargs['name_or_path'] = downloaded_name_path

            # Surface the manifest's model name unless the caller set one.
            if model_name is not None and 'model_name' not in kwargs:
                kwargs['model_name'] = model_name

        # Resolve/download the multimodal projector path if needed.
        if mmproj_path is not None:
            downloaded_mmproj_path, _, _ = _download_model_if_needed(
                mmproj_path, 'mmproj_path', progress_callback, token, is_mmproj=True, **kwargs
            )
            kwargs['mmproj_path'] = downloaded_mmproj_path

        # Call outside any try/except so model-construction errors bubble up
        # with their original tracebacks.
        return func(*args, **kwargs)

    return wrapper