nexaai-1.0.4rc13-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/__init__.py +71 -0
- nexaai/_stub.cp310-win_amd64.pyd +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +60 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +91 -0
- nexaai/asr_impl/pybind_asr_impl.py +43 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +3 -0
- nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-base.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-cpu.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-cuda.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-vulkan.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml.dll +0 -0
- nexaai/binds/nexa_llama_cpp/llama.dll +0 -0
- nexaai/binds/nexa_llama_cpp/mtmd.dll +0 -0
- nexaai/binds/nexa_llama_cpp/nexa_plugin.dll +0 -0
- nexaai/common.py +61 -0
- nexaai/cv.py +87 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +88 -0
- nexaai/cv_impl/pybind_cv_impl.py +31 -0
- nexaai/embedder.py +68 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
- nexaai/image_gen.py +136 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
- nexaai/llm.py +89 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +249 -0
- nexaai/llm_impl/pybind_llm_impl.py +207 -0
- nexaai/rerank.py +51 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
- nexaai/runtime.py +64 -0
- nexaai/tts.py +70 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +93 -0
- nexaai/tts_impl/pybind_tts_impl.py +42 -0
- nexaai/utils/avatar_fetcher.py +104 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/model_manager.py +1195 -0
- nexaai/utils/progress_tracker.py +372 -0
- nexaai/vlm.py +120 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
- nexaai-1.0.4rc13.dist-info/METADATA +26 -0
- nexaai-1.0.4rc13.dist-info/RECORD +59 -0
- nexaai-1.0.4rc13.dist-info/WHEEL +5 -0
- nexaai-1.0.4rc13.dist-info/top_level.txt +1 -0
nexaai/utils/model_manager.py
@@ -0,0 +1,1195 @@
import os
import shutil
import json
from datetime import datetime
from dataclasses import dataclass
from typing import Optional, Callable, Dict, Any, List, Union
import functools
from tqdm.auto import tqdm
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError

from .progress_tracker import CustomProgressTqdm, DownloadProgressTracker
from .avatar_fetcher import get_avatar_url_for_repo

# Default path for model storage
DEFAULT_MODEL_SAVING_PATH = "~/.cache/nexa.ai/nexa_sdk/models/"


@dataclass
class DownloadedModel:
    """Data class representing a downloaded model with all its metadata."""
    repo_id: str
    files: List[str]
    folder_type: str  # 'owner_repo' or 'direct_repo'
    local_path: str
    size_bytes: int
    file_count: int
    full_repo_download_complete: bool = True  # True if no incomplete downloads detected
    pipeline_tag: Optional[str] = None  # Pipeline tag from HuggingFace model info
    download_time: Optional[str] = None  # ISO format timestamp of download
    avatar_url: Optional[str] = None  # Avatar URL for the model author

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format for backward compatibility."""
        result = {
            'repo_id': self.repo_id,
            'files': self.files,
            'folder_type': self.folder_type,
            'local_path': self.local_path,
            'size_bytes': self.size_bytes,
            'file_count': self.file_count,
            'full_repo_download_complete': self.full_repo_download_complete,
            'pipeline_tag': self.pipeline_tag,
            'download_time': self.download_time,
            'avatar_url': self.avatar_url
        }
        return result
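# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# DownloadedModel is a plain dataclass, so a hypothetical instance (all values
# below are made up) serializes straightforwardly via to_dict():
#
#     m = DownloadedModel(
#         repo_id="Qwen/Qwen3-4B-GGUF",
#         files=["Qwen3-4B-Q4_K_M.gguf"],
#         folder_type="owner_repo",
#         local_path="/home/user/.cache/nexa.ai/nexa_sdk/models/Qwen/Qwen3-4B-GGUF",
#         size_bytes=2_497_281_024,
#         file_count=1,
#     )
#     assert m.to_dict()["pipeline_tag"] is None  # optional fields default to None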
##########################################################################
#                         List downloaded models                        #
##########################################################################


def _check_for_incomplete_downloads(directory_path: str) -> bool:
    """
    Check if there are incomplete downloads in the model directory.

    This function checks for the presence of .incomplete or .lock files
    in the .cache/huggingface/download directory within the model folder,
    which indicates that the model download has not completed.

    Args:
        directory_path: Path to the model directory

    Returns:
        bool: True if download is complete (no incomplete files found),
              False if incomplete downloads are detected
    """
    # Check for .cache/huggingface/download directory
    cache_dir = os.path.join(directory_path, '.cache', 'huggingface', 'download')

    # If the cache directory doesn't exist, assume download is complete
    if not os.path.exists(cache_dir):
        return True

    try:
        # Walk through the cache directory to find incomplete or lock files
        for root, dirs, files in os.walk(cache_dir):
            for filename in files:
                # Check for .incomplete or .lock files
                if filename.endswith('.incomplete'):
                    return False  # Found incomplete download

        # No incomplete files found
        return True
    except (OSError, IOError):
        # If we can't access the directory, assume download is complete
        return True


def _load_download_metadata(directory_path: str) -> Dict[str, Any]:
    """Load download metadata from download_metadata.json if it exists."""
    metadata_path = os.path.join(directory_path, 'download_metadata.json')
    if os.path.exists(metadata_path):
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (json.JSONDecodeError, IOError):
            pass
    return {}


def _save_download_metadata(directory_path: str, metadata: Dict[str, Any]) -> None:
    """Save download metadata to download_metadata.json."""
    metadata_path = os.path.join(directory_path, 'download_metadata.json')
    try:
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
    except IOError:
        # If we can't save metadata, don't fail the download
        pass


def _get_directory_size_and_files(directory_path: str) -> tuple[int, List[str]]:
    """Get total size and list of files in a directory."""
    total_size = 0
    files = []

    try:
        for root, dirs, filenames in os.walk(directory_path):
            for filename in filenames:
                file_path = os.path.join(root, filename)
                try:
                    file_size = os.path.getsize(file_path)
                    total_size += file_size
                    # Store relative path from the directory
                    rel_path = os.path.relpath(file_path, directory_path)
                    files.append(rel_path)
                except (OSError, IOError):
                    # Skip files that can't be accessed
                    continue
    except (OSError, IOError):
        # Skip directories that can't be accessed
        pass

    return total_size, files


def _scan_for_repo_folders(base_path: str) -> List[DownloadedModel]:
    """Scan a directory for repository folders and return model information."""
    models = []

    try:
        if not os.path.exists(base_path):
            return models

        for item in os.listdir(base_path):
            item_path = os.path.join(base_path, item)

            # Skip non-directory items
            if not os.path.isdir(item_path):
                continue

            # Check if this might be an owner folder by looking for subdirectories
            has_subdirs = False
            direct_files = []

            try:
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        has_subdirs = True
                        # This looks like owner/repo structure
                        size_bytes, files = _get_directory_size_and_files(subitem_path)
                        if files:  # Only include if there are files
                            # Check if the download is complete
                            download_complete = _check_for_incomplete_downloads(subitem_path)
                            # Load metadata if it exists
                            metadata = _load_download_metadata(subitem_path)
                            models.append(DownloadedModel(
                                repo_id=f"{item}/{subitem}",
                                files=files,
                                folder_type='owner_repo',
                                local_path=subitem_path,
                                size_bytes=size_bytes,
                                file_count=len(files),
                                full_repo_download_complete=download_complete,
                                pipeline_tag=metadata.get('pipeline_tag'),
                                download_time=metadata.get('download_time'),
                                avatar_url=metadata.get('avatar_url')
                            ))
                    else:
                        direct_files.append(subitem)
            except (OSError, IOError):
                # Skip directories that can't be accessed
                continue

            # Direct repo folder (no owner structure)
            if not has_subdirs and direct_files:
                size_bytes, files = _get_directory_size_and_files(item_path)
                if files:  # Only include if there are files
                    # Check if the download is complete
                    download_complete = _check_for_incomplete_downloads(item_path)
                    # Load metadata if it exists
                    metadata = _load_download_metadata(item_path)
                    models.append(DownloadedModel(
                        repo_id=item,
                        files=files,
                        folder_type='direct_repo',
                        local_path=item_path,
                        size_bytes=size_bytes,
                        file_count=len(files),
                        full_repo_download_complete=download_complete,
                        pipeline_tag=metadata.get('pipeline_tag'),
                        download_time=metadata.get('download_time'),
                        avatar_url=metadata.get('avatar_url')
                    ))

    except (OSError, IOError):
        # Skip if base path can't be accessed
        pass

    return models


def list_downloaded_models(local_dir: Optional[str] = None) -> List[DownloadedModel]:
    """
    List all downloaded models in the specified directory.

    This function scans the local directory for downloaded models and returns
    information about each repository including files, size, and folder structure.

    It handles different folder naming conventions:
    - Owner/repo structure (e.g., "microsoft/DialoGPT-small")
    - Direct repo folders (repos without owner prefix)

    Args:
        local_dir (str, optional): Directory to scan for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        List[DownloadedModel]: List of DownloadedModel objects with attributes:
            - repo_id: str - Repository ID (e.g., "owner/repo")
            - files: List[str] - List of relative file paths in the repository
            - folder_type: str - 'owner_repo' or 'direct_repo'
            - local_path: str - Full path to the model directory
            - size_bytes: int - Total size of all files in bytes
            - file_count: int - Number of files in the repository
            - full_repo_download_complete: bool - True if no incomplete downloads detected,
              False if .incomplete or .lock files exist
            - pipeline_tag: Optional[str] - Pipeline tag from HuggingFace model info
            - download_time: Optional[str] - ISO format timestamp when the model was downloaded
            - avatar_url: Optional[str] - Avatar URL for the model author
    """

    # Set up local directory
    if local_dir is None:
        local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)

    local_dir = os.path.abspath(local_dir)

    if not os.path.exists(local_dir):
        return []

    # Scan for repository folders
    models = _scan_for_repo_folders(local_dir)

    # Sort by repo_id for consistent output
    models.sort(key=lambda x: x.repo_id)

    return models
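# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# Assuming the default cache layout, enumerating downloads and flagging
# incomplete ones could look like:
#
#     from nexaai.utils.model_manager import list_downloaded_models
#
#     for m in list_downloaded_models():
#         status = "ok" if m.full_repo_download_complete else "INCOMPLETE"
#         print(f"{m.repo_id}: {m.file_count} files, "
#               f"{m.size_bytes / 2**20:.1f} MiB [{status}]")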
##########################################################################
#                         Remove model functions                        #
##########################################################################


def _parse_model_path(model_path: str) -> tuple[str, str | None]:
    """
    Parse model_path to extract repo_id and optional filename.

    Examples:
        "microsoft/DialoGPT-small" -> ("microsoft/DialoGPT-small", None)
        "microsoft/DialoGPT-small/pytorch_model.bin" -> ("microsoft/DialoGPT-small", "pytorch_model.bin")
        "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" -> ("Qwen/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf")

    Args:
        model_path: The model path string

    Returns:
        Tuple of (repo_id, filename) where filename can be None
    """
    parts = model_path.strip().split('/')

    if len(parts) < 2:
        # Invalid format, assume it's just a repo name without owner
        return model_path, None
    elif len(parts) == 2:
        # Format: "owner/repo"
        return model_path, None
    else:
        # Format: "owner/repo/file" or "owner/repo/subdir/file"
        repo_id = f"{parts[0]}/{parts[1]}"
        filename = '/'.join(parts[2:])
        return repo_id, filename


def _validate_and_parse_input(model_path: str) -> tuple[str, Optional[str]]:
    """Validate input and parse model path."""
    if not model_path or not isinstance(model_path, str) or not model_path.strip():
        raise ValueError("model_path is required and must be a non-empty string")

    model_path = model_path.strip()
    return _parse_model_path(model_path)


def _find_target_model(repo_id: str, local_dir: str) -> DownloadedModel:
    """Find and validate the target model exists."""
    downloaded_models = list_downloaded_models(local_dir)

    for model in downloaded_models:
        if model.repo_id == repo_id:
            return model

    available_repos = [model.repo_id for model in downloaded_models]
    raise FileNotFoundError(
        f"Repository '{repo_id}' not found in downloaded models. "
        f"Available repositories: {available_repos}"
    )


def _clean_empty_owner_directory(target_model: DownloadedModel) -> None:
    """Remove empty owner directory if applicable."""
    if target_model.folder_type != 'owner_repo':
        return

    parent_dir = os.path.dirname(target_model.local_path)
    try:
        if os.path.exists(parent_dir) and not os.listdir(parent_dir):
            os.rmdir(parent_dir)
    except OSError:
        pass


def _remove_specific_file(target_model: DownloadedModel, file_name: str, local_dir: str) -> DownloadedModel:
    """Remove a specific file from the repository."""
    # Validate file exists in model
    if file_name not in target_model.files:
        raise FileNotFoundError(
            f"File '{file_name}' not found in repository '{target_model.repo_id}'. "
            f"Available files: {target_model.files[:10]}{'...' if len(target_model.files) > 10 else ''}"
        )

    # Construct full file path and validate it exists on disk
    file_path = os.path.join(target_model.local_path, file_name)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File does not exist on disk: {file_path}")

    # Get file size before removal
    try:
        file_size = os.path.getsize(file_path)
    except OSError:
        file_size = 0

    # Remove the file
    try:
        os.remove(file_path)
    except OSError as e:
        raise OSError(f"Failed to remove file '{file_path}': {e}")

    # Create updated model object
    updated_files = [f for f in target_model.files if f != file_name]
    updated_size = target_model.size_bytes - file_size
    # Re-check download completeness after file removal
    download_complete = _check_for_incomplete_downloads(target_model.local_path)
    updated_model = DownloadedModel(
        repo_id=target_model.repo_id,
        files=updated_files,
        folder_type=target_model.folder_type,
        local_path=target_model.local_path,
        size_bytes=updated_size,
        file_count=len(updated_files),
        full_repo_download_complete=download_complete
    )

    # If no files left, remove the entire directory
    if len(updated_files) == 0:
        try:
            shutil.rmtree(target_model.local_path)
            _clean_empty_owner_directory(target_model)
        except OSError:
            pass

    return updated_model


def _remove_entire_repository(target_model: DownloadedModel, local_dir: str) -> DownloadedModel:
    """Remove the entire repository and clean up."""
    # Remove the directory and all its contents
    try:
        shutil.rmtree(target_model.local_path)
    except OSError as e:
        raise OSError(f"Failed to remove directory '{target_model.local_path}': {e}")

    # Clean up associated resources
    _clean_empty_owner_directory(target_model)

    return target_model


def remove_model_or_file(
    model_path: str,
    local_dir: Optional[str] = None
) -> DownloadedModel:
    """
    Remove a downloaded model or specific file by repository ID or file path.

    This function supports two modes:
    1. Remove entire repository: "microsoft/DialoGPT-small"
    2. Remove specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"

    For entire repository removal, it removes the directory and all files.
    For specific file removal, it only removes that file and updates the
    repository metadata.

    Args:
        model_path (str): Required. Either:
            - Repository ID (e.g., "microsoft/DialoGPT-small") - removes entire repo
            - File path (e.g., "Qwen/Qwen3-4B-GGUF/model.gguf") - removes specific file
        local_dir (str, optional): Directory to search for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        DownloadedModel: The model object representing what was removed from disk.
            For file removal, returns updated model info after file removal.

    Raises:
        ValueError: If model_path is invalid (empty or None)
        FileNotFoundError: If the repository or file is not found in downloaded models
        OSError: If there's an error removing files from disk
    """
    # Validate input and parse path
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Set up local directory
    if local_dir is None:
        local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)

    local_dir = os.path.abspath(local_dir)

    if not os.path.exists(local_dir):
        raise FileNotFoundError(f"Local directory does not exist: {local_dir}")

    # Find the target model
    target_model = _find_target_model(repo_id, local_dir)

    # Delegate to appropriate removal function
    if file_name:
        return _remove_specific_file(target_model, file_name, local_dir)
    else:
        return _remove_entire_repository(target_model, local_dir)
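# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# The path argument selects the granularity: two segments remove a whole
# repository, three or more remove a single file inside it.
#
#     from nexaai.utils.model_manager import remove_model_or_file
#
#     # Remove one quantization file, keeping the rest of the repo:
#     remove_model_or_file("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf")
#     # Remove an entire repository:
#     remove_model_or_file("microsoft/DialoGPT-small")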
##########################################################################
#                    Check model existence functions                    #
##########################################################################


def check_model_existence(
    model_path: str,
    local_dir: Optional[str] = None
) -> bool:
    """
    Check if a downloaded model or specific file exists locally.

    This function supports two modes:
    1. Check entire repository: "microsoft/DialoGPT-small"
    2. Check specific file: "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"

    Args:
        model_path (str): Required. Either:
            - Repository ID (e.g., "microsoft/DialoGPT-small") - checks entire repo
            - File path (e.g., "Qwen/Qwen3-4B-GGUF/model.gguf") - checks specific file
        local_dir (str, optional): Directory to search for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        bool: True if the requested item exists, False otherwise

    Raises:
        ValueError: If model_path is invalid (empty or None)
    """
    # Validate input and parse path
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Set up local directory
    if local_dir is None:
        local_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH)

    local_dir = os.path.abspath(local_dir)

    # Return False if local directory doesn't exist
    if not os.path.exists(local_dir):
        return False

    # Get all downloaded models
    downloaded_models = list_downloaded_models(local_dir)

    # Find the target model
    for model in downloaded_models:
        if model.repo_id == repo_id:
            # If no specific file requested, repository existence is sufficient
            if file_name is None:
                return True
            else:
                # Check specific file existence
                return file_name in model.files

    return False
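# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# A typical guard before triggering a download:
#
#     from nexaai.utils.model_manager import (
#         check_model_existence, download_from_huggingface)
#
#     path = "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf"
#     if not check_model_existence(path):
#         download_from_huggingface("Qwen/Qwen3-4B-GGUF",
#                                   file_name="Qwen3-4B-Q4_K_M.gguf")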
##########################################################################
#                      HuggingFace Downloader Class                     #
##########################################################################


class HuggingFaceDownloader:
    """Class to handle downloads from HuggingFace Hub with unified API usage."""

    def __init__(
        self,
        endpoint: Optional[str] = None,
        token: Union[bool, str, None] = None,
        enable_transfer: bool = True
    ):
        """
        Initialize the downloader with HuggingFace API.

        Args:
            endpoint: Custom endpoint URL (e.g., "https://hf-mirror.com").
                If None, uses default HuggingFace Hub.
            token: Authentication token for private repositories.
            enable_transfer: Whether to enable HF transfer for faster downloads.
        """
        # Always create an HfApi instance - either with custom endpoint or default
        self.token = token if isinstance(token, str) else False  # False means disable authentication
        self.api = HfApi(endpoint=endpoint, token=self.token) if endpoint else HfApi(token=self.token)
        self.enable_transfer = enable_transfer
        self.original_hf_transfer = None
        self.endpoint = endpoint  # Store endpoint for avatar fetching

    def _create_repo_directory(self, local_dir: str, repo_id: str) -> str:
        """Create a directory structure for the repository following HF convention."""
        if '/' in repo_id:
            # Standard format: owner/repo
            owner, repo = repo_id.split('/', 1)
            repo_dir = os.path.join(local_dir, owner, repo)
        else:
            # Direct repo name without owner
            repo_dir = os.path.join(local_dir, repo_id)

        os.makedirs(repo_dir, exist_ok=True)
        return repo_dir

    def _created_dir_if_not_exists(self, local_dir: Optional[str]) -> str:
        """Create directory if it doesn't exist and return the expanded path."""
        if local_dir is None:
            local_dir = DEFAULT_MODEL_SAVING_PATH

        local_dir = os.path.expanduser(local_dir)
        os.makedirs(local_dir, exist_ok=True)
        return local_dir

    def _get_repo_info_for_progress(
        self,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]] = None
    ) -> tuple[int, int]:
        """Get total repository size and file count for progress tracking."""
        try:
            info = self.api.model_info(repo_id, files_metadata=True, token=self.token)

            total_size = 0
            file_count = 0

            if info.siblings:
                for sibling in info.siblings:
                    # Handle different file_name types
                    if file_name is not None:
                        if isinstance(file_name, str):
                            # Single file - only count if it matches
                            if sibling.rfilename != file_name:
                                continue
                        elif isinstance(file_name, list):
                            # Multiple files - only count if in the list
                            if sibling.rfilename not in file_name:
                                continue

                    # For all matching files (or all files if file_name is None)
                    if hasattr(sibling, 'size') and sibling.size is not None:
                        total_size += sibling.size
                        file_count += 1
                    else:
                        # Count files without size info
                        file_count += 1

            return total_size, file_count if file_count > 0 else 1
        except Exception:
            # If we can't get info, return defaults
            return 0, 1

    def _validate_and_setup_params(
        self,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]]
    ) -> tuple[str, Optional[Union[str, List[str]]]]:
        """Validate and normalize input parameters."""
        if not repo_id:
            raise ValueError("repo_id is required")

        repo_id = repo_id.strip()

        # Handle file_name parameter
        if file_name is not None:
            if isinstance(file_name, str):
                file_name = file_name.strip()
                if not file_name:
                    file_name = None
            elif isinstance(file_name, list):
                # Filter out empty strings and strip whitespace
                file_name = [f.strip() for f in file_name if f and f.strip()]
                if not file_name:
                    file_name = None
            else:
                raise ValueError("file_name must be a string, list of strings, or None")

        return repo_id, file_name

    def _setup_progress_tracker(
        self,
        progress_callback: Optional[Callable[[Dict[str, Any]], None]],
        show_progress: bool,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]]
    ) -> Optional[DownloadProgressTracker]:
        """Initialize progress tracker if callback is provided."""
        if not progress_callback:
            return None

        progress_tracker = DownloadProgressTracker(progress_callback, show_progress)
        # Get repo info for progress tracking - now handles all cases
        total_size, file_count = self._get_repo_info_for_progress(repo_id, file_name)
        progress_tracker.set_repo_info(total_size, file_count)
        return progress_tracker

    def _setup_hf_transfer_env(self) -> None:
        """Set up HF transfer environment."""
        self.original_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
        if self.enable_transfer:
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    def _cleanup_hf_transfer_env(self) -> None:
        """Restore original HF transfer environment."""
        if self.original_hf_transfer is not None:
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = self.original_hf_transfer
        else:
            os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)

    def _validate_repository_and_get_info(
        self,
        repo_id: str,
        progress_tracker: Optional[DownloadProgressTracker]
    ):
        """Validate repository exists and get info."""
        try:
            info = self.api.model_info(repo_id, token=self.token)
            return info
        except RepositoryNotFoundError:
            error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
            if progress_tracker:
                progress_tracker.set_error(error_msg)
            raise RepositoryNotFoundError(error_msg)
        except HfHubHTTPError as e:
            if e.response.status_code == 404:
                error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                raise RepositoryNotFoundError(error_msg)
            else:
                error_msg = f"HTTP error while accessing repository '{repo_id}': {e}"
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                raise HfHubHTTPError(error_msg)

    def _validate_file_exists_in_repo(
        self,
        file_name: str,
        info,
        repo_id: str,
        progress_tracker: Optional[DownloadProgressTracker]
    ) -> None:
        """Validate that the file exists in the repository."""
        file_exists = False
        if info.siblings:
            for sibling in info.siblings:
                if sibling.rfilename == file_name:
                    file_exists = True
                    break

        if not file_exists:
            available_files = [sibling.rfilename for sibling in info.siblings] if info.siblings else []
            error_msg = (
                f"File '{file_name}' not found in repository '{repo_id}'. "
                f"Available files: {available_files[:10]}{'...' if len(available_files) > 10 else ''}"
            )
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            raise ValueError(error_msg)

    def _check_file_exists_and_valid(
        self,
        file_path: str,
        expected_size: Optional[int] = None
    ) -> bool:
        """Check if a file exists and is valid (non-empty, correct size if known)."""
        if not os.path.exists(file_path):
            return False

        # Check file is not empty
        try:
            file_size = os.path.getsize(file_path)
            if file_size == 0:
                return False
        except (OSError, IOError):
            return False

        # If we have expected size, check it matches
        if expected_size is not None and file_size != expected_size:
            return False

        # If no expected size, just check that file is not empty
        return os.path.getsize(file_path) > 0

    def _fetch_and_save_metadata(self, repo_id: str, local_dir: str) -> None:
        """Fetch model info and save metadata after successful download."""
        try:
            # Fetch model info to get pipeline_tag
            info = self.api.model_info(repo_id, token=self.token)
            pipeline_tag = info.pipeline_tag if hasattr(info, 'pipeline_tag') else None

            # Get avatar URL
            avatar_url = get_avatar_url_for_repo(repo_id, custom_endpoint=self.endpoint)

            # Prepare metadata
            metadata = {
                'pipeline_tag': pipeline_tag,
                'download_time': datetime.now().isoformat(),
                'avatar_url': avatar_url
            }

            # Save metadata to the repository directory
            _save_download_metadata(local_dir, metadata)

        except Exception:
            # Don't fail the download if metadata fetch fails
            pass

    def _download_single_file(
        self,
        repo_id: str,
        file_name: str,
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False
    ) -> str:
        """Download a single file from the repository using HuggingFace Hub API."""
        # Create repo-specific directory for the single file
        file_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Check if file already exists
        local_file_path = os.path.join(file_local_dir, file_name)
        if not force_download and self._check_file_exists_and_valid(local_file_path):
            print(f"✓ File already exists, skipping: {file_name}")
            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()
            return local_file_path

        try:
            # Note: hf_hub_download doesn't support tqdm_class parameter
            # Progress tracking works through the global tqdm monkey patching
            downloaded_path = self.api.hf_hub_download(
                repo_id=repo_id,
                filename=file_name,
                local_dir=file_local_dir,
                local_dir_use_symlinks=False,
                token=self.token,
                force_download=force_download
            )

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Save metadata after successful download
            self._fetch_and_save_metadata(repo_id, file_local_dir)

            return downloaded_path

        except HfHubHTTPError as e:
            error_msg = f"Error downloading file '{file_name}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            if e.response.status_code == 404:
                raise ValueError(f"File '{file_name}' not found in repository '{repo_id}'")
            else:
                raise HfHubHTTPError(error_msg)

    def _download_entire_repository(
        self,
        repo_id: str,
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False
    ) -> str:
        """Download the entire repository."""
        # Create a subdirectory for this specific repo
        repo_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Check if repository already exists (basic check for directory existence)
        if not force_download and os.path.exists(repo_local_dir) and os.listdir(repo_local_dir):
            print(f"✓ Repository already exists, skipping: {repo_id}")
            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()
            return repo_local_dir

        try:
            download_kwargs = {
                'repo_id': repo_id,
                'local_dir': repo_local_dir,
                'local_dir_use_symlinks': False,
                'token': self.token,
                'force_download': force_download
            }

            # Add tqdm_class if progress tracking is enabled
            if progress_tracker:
                download_kwargs['tqdm_class'] = CustomProgressTqdm

            downloaded_path = self.api.snapshot_download(**download_kwargs)

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Save metadata after successful download
            self._fetch_and_save_metadata(repo_id, repo_local_dir)

            return downloaded_path

        except HfHubHTTPError as e:
            error_msg = f"Error downloading repository '{repo_id}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            raise HfHubHTTPError(error_msg)

    def _download_multiple_files_from_hf(
        self,
        repo_id: str,
        file_names: List[str],
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False
    ) -> str:
        """Download multiple specific files from HuggingFace Hub."""
        # Create repo-specific directory
        repo_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Create overall progress bar for multiple files
        overall_progress = tqdm(
            total=len(file_names),
            unit='file',
            desc=f"Downloading {len(file_names)} files from {repo_id}",
            position=0,
            leave=True
        )

        try:
            for file_name in file_names:
                overall_progress.set_postfix_str(f"Current: {os.path.basename(file_name)}")

                # Check if file already exists
                local_file_path = os.path.join(repo_local_dir, file_name)
                if not force_download and self._check_file_exists_and_valid(local_file_path):
                    print(f"✓ File already exists, skipping: {file_name}")
                    overall_progress.update(1)
                    continue

                # Download each file using hf_hub_download
                self.api.hf_hub_download(
                    repo_id=repo_id,
                    filename=file_name,
                    local_dir=repo_local_dir,
                    local_dir_use_symlinks=False,
                    token=self.token,
                    force_download=force_download
                )

                overall_progress.update(1)

            overall_progress.close()

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Save metadata after successful download
            self._fetch_and_save_metadata(repo_id, repo_local_dir)

            return repo_local_dir

        except HfHubHTTPError as e:
            overall_progress.close()
            error_msg = f"Error downloading files from '{repo_id}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            raise HfHubHTTPError(error_msg)
        except Exception as e:
            overall_progress.close()
            if progress_tracker:
                progress_tracker.set_error(str(e))
                progress_tracker.stop_tracking()
            raise

    def download(
        self,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]] = None,
        local_dir: Optional[str] = None,
        progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        show_progress: bool = True,
        force_download: bool = False
    ) -> str:
        """
        Main download method that handles all download scenarios.

        Args:
            repo_id: Repository ID to download from
            file_name: Optional file name(s) to download
            local_dir: Local directory to save files
            progress_callback: Callback for progress updates
            show_progress: Whether to show progress bar
            force_download: Force re-download even if files exist

        Returns:
            Path to downloaded file or directory
        """
        # Validate and normalize parameters
        repo_id, file_name = self._validate_and_setup_params(repo_id, file_name)

        # Set up local directory
        local_dir = self._created_dir_if_not_exists(local_dir)

        # Set up progress tracker
        file_name_for_progress = file_name if isinstance(file_name, str) else None
        progress_tracker = self._setup_progress_tracker(
            progress_callback, show_progress, repo_id, file_name_for_progress
        )

        # Set up HF transfer environment
        self._setup_hf_transfer_env()

        try:
            # Validate repository and get info
            info = self._validate_repository_and_get_info(repo_id, progress_tracker)

            # Start progress tracking
            if progress_tracker:
                progress_tracker.start_tracking()

            # Choose download strategy based on file_name
            if file_name is None:
                # Download entire repository
                return self._download_entire_repository(
                    repo_id, local_dir, progress_tracker, force_download
                )
            elif isinstance(file_name, str):
                # Download specific single file
                self._validate_file_exists_in_repo(file_name, info, repo_id, progress_tracker)
                return self._download_single_file(
                    repo_id, file_name, local_dir, progress_tracker, force_download
                )
            else:  # file_name is a list
                # Download multiple specific files
                # Validate all files exist
                for fname in file_name:
                    self._validate_file_exists_in_repo(fname, info, repo_id, progress_tracker)

                return self._download_multiple_files_from_hf(
                    repo_id, file_name, local_dir, progress_tracker, force_download
                )

        except Exception as e:
            # Handle any unexpected errors
            if progress_tracker and progress_tracker.download_status != "error":
                progress_tracker.set_error(str(e))
                progress_tracker.stop_tracking()
            raise

        finally:
            # Restore original HF transfer setting
            self._cleanup_hf_transfer_env()
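# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# The class can also be driven directly, e.g. against a mirror endpoint;
# most callers go through download_from_huggingface() below instead:
#
#     dl = HuggingFaceDownloader(endpoint="https://hf-mirror.com",
#                                token=None, enable_transfer=True)
#     path = dl.download("Qwen/Qwen3-4B-GGUF",
#                        file_name="Qwen3-4B-Q4_K_M.gguf")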
##########################################################################
#                        Public Download Function                       #
##########################################################################


def download_from_huggingface(
    repo_id: str,
    file_name: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    enable_transfer: bool = True,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    show_progress: bool = True,
    token: Union[bool, str, None] = None,
    custom_endpoint: Optional[str] = None,
    force_download: bool = False
) -> str:
    """
    Download models or files from HuggingFace Hub or custom mirror endpoints.

    Args:
        repo_id (str): Required. The repository ID to download from (e.g., "microsoft/DialoGPT-medium")
        file_name (Union[str, List[str]], optional): Single filename or list of filenames to download.
            If None, downloads entire repo.
        local_dir (str, optional): Local directory to save files. If None, uses DEFAULT_MODEL_SAVING_PATH.
        enable_transfer (bool, optional): Whether to enable HF transfer for faster downloads. Default True.
        progress_callback (Callable, optional): Callback function to receive progress updates.
            Function receives a dict with progress information.
        show_progress (bool, optional): Whether to show a unified progress bar in the terminal. Default True.
            Only works when progress_callback is provided.
        token (Union[bool, str, None], optional): A token to be used for the download.
            - If True, the token is read from the HuggingFace config folder.
            - If a string, it's used as the authentication token.
            - If None, uses default behavior.
        custom_endpoint (str, optional): A custom HuggingFace-compatible endpoint URL.
            Should be ONLY the base endpoint without any paths.
            Examples:
                - "https://hf-mirror.com"
                - "https://huggingface.co" (default)
            The endpoint will be used to initialize HfApi for all downloads.
        force_download (bool, optional): If True, download files even if they already exist locally.
            Default False (skip existing files).

    Returns:
        str: Path to the downloaded file or directory

    Raises:
        ValueError: If repo_id is invalid or file_name doesn't exist in the repo
        RepositoryNotFoundError: If the repository doesn't exist
        HfHubHTTPError: If there's an HTTP error during download

    Progress Callback Data Format:
        {
            'status': str,                  # 'idle', 'downloading', 'completed', 'error'
            'error_message': str,           # Only present if status is 'error'
            'progress': {
                'total_downloaded': int,    # Bytes downloaded
                'total_size': int,          # Total bytes to download
                'percentage': float,        # Progress percentage (0-100)
                'files_active': int,        # Number of files currently downloading
                'files_total': int,         # Total number of files
                'known_total': bool         # Whether total size is known
            },
            'speed': {
                'bytes_per_second': float,  # Download speed in bytes/sec
                'formatted': str            # Human readable speed (e.g., "1.2 MB/s")
            },
            'formatting': {
                'downloaded': str,          # Human readable downloaded size
                'total_size': str           # Human readable total size
            },
            'timing': {
                'elapsed_seconds': float,   # Time since download started
                'eta_seconds': float,       # Estimated time remaining
                'start_time': float         # Download start timestamp
            }
        }
    """
    # Create downloader instance with custom endpoint if provided
    downloader = HuggingFaceDownloader(
        endpoint=custom_endpoint,
        token=token,
        enable_transfer=enable_transfer
    )

    # Use the downloader to perform the download
    return downloader.download(
        repo_id=repo_id,
        file_name=file_name,
        local_dir=local_dir,
        progress_callback=progress_callback,
        show_progress=show_progress,
        force_download=force_download
    )
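# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# A minimal progress_callback consuming the documented dict shape:
#
#     def on_progress(update):
#         if update["status"] == "downloading":
#             p = update["progress"]
#             print(f"{p['percentage']:.1f}% "
#                   f"({update['speed']['formatted']})")
#         elif update["status"] == "error":
#             print("download failed:", update.get("error_message"))
#
#     download_from_huggingface("microsoft/DialoGPT-medium",
#                               progress_callback=on_progress)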
##########################################################################
#                        Auto-download decorator                        #
##########################################################################


def auto_download_model(func: Callable) -> Callable:
    """
    Decorator that automatically downloads models from HuggingFace if they don't exist locally.

    This decorator should be applied to __init__ methods that take a name_or_path parameter.
    If name_or_path doesn't exist as a local file/directory, it will attempt to download
    it from HuggingFace Hub using the download_from_huggingface function.

    The name_or_path can be in formats like:
    - "microsoft/DialoGPT-small" (downloads entire repo)
    - "microsoft/DialoGPT-small/pytorch_model.bin" (downloads specific file)
    - "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" (downloads specific file)

    Optional kwargs that are extracted and passed to download_from_huggingface:
    - progress_callback: Callback function for download progress updates
    - token: HuggingFace authentication token for private repositories

    Args:
        func: The __init__ method to wrap

    Returns:
        Wrapped function that handles automatic model downloading
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Find name_or_path in arguments
        # Assuming name_or_path is the first argument after self
        if len(args) >= 2:
            name_or_path = args[1]
            args_list = list(args)
            path_index = 1
            is_positional = True
        elif 'name_or_path' in kwargs:
            name_or_path = kwargs['name_or_path']
            path_index = None
            is_positional = False
        else:
            # No name_or_path found, call original function
            return func(*args, **kwargs)

        # Extract progress_callback and token from arguments
        progress_callback = None
        if 'progress_callback' in kwargs:
            progress_callback = kwargs.pop('progress_callback')  # Remove from kwargs to avoid passing to original func

        token = None
        if 'token' in kwargs:
            token = kwargs.pop('token')  # Remove from kwargs to avoid passing to original func

        # Check if name_or_path exists locally (file or directory)
        if os.path.exists(name_or_path):
            # Local path exists, use as-is without downloading
            return func(*args, **kwargs)

        # Model path doesn't exist locally, try to download from HuggingFace
        try:
            # Parse name_or_path to extract repo_id and filename
            repo_id, file_name = _parse_model_path(name_or_path)

            # Download the model
            downloaded_path = download_from_huggingface(
                repo_id=repo_id,
                file_name=file_name,
                local_dir=None,  # Use default cache directory
                enable_transfer=True,
                progress_callback=progress_callback,  # Use the extracted callback
                show_progress=True,
                token=token  # Use the extracted token
            )

            # Replace name_or_path with downloaded path
            if is_positional:
                args_list[path_index] = downloaded_path
                args = tuple(args_list)
            else:
                kwargs['name_or_path'] = downloaded_path

        except Exception as e:
            # Only handle download-related errors
            raise RuntimeError(f"Could not load model from '{name_or_path}': {e}")

        # Call original function with updated path (outside try-catch to let model creation errors bubble up)
        return func(*args, **kwargs)

    return wrapper
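# [Reviewer sketch -- illustrative only, not part of the packaged file.]
# Applied to a hypothetical model class, the decorator resolves a repo or
# file path to a local cache path before __init__ runs:
#
#     class DemoLLM:  # hypothetical consumer, for illustration only
#         @auto_download_model
#         def __init__(self, name_or_path: str):
#             self.path = name_or_path  # now a local filesystem path
#
#     llm = DemoLLM("Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf",
#                   progress_callback=print)  # kwargs consumed by the wrapper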