xiaoshiai-hub 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xiaoshiai_hub/__init__.py +83 -0
- xiaoshiai_hub/client.py +360 -0
- xiaoshiai_hub/download.py +598 -0
- xiaoshiai_hub/encryption.py +777 -0
- xiaoshiai_hub/exceptions.py +37 -0
- xiaoshiai_hub/types.py +109 -0
- xiaoshiai_hub/upload.py +875 -0
- xiaoshiai_hub-0.1.1.dist-info/METADATA +560 -0
- xiaoshiai_hub-0.1.1.dist-info/RECORD +12 -0
- xiaoshiai_hub-0.1.1.dist-info/WHEEL +5 -0
- xiaoshiai_hub-0.1.1.dist-info/licenses/LICENSE +56 -0
- xiaoshiai_hub-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,598 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Download utilities for XiaoShi AI Hub SDK
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import fnmatch
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import List, Optional, Union
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from tqdm.auto import tqdm
|
|
12
|
+
except ImportError:
|
|
13
|
+
tqdm = None
|
|
14
|
+
|
|
15
|
+
from .client import HubClient, DEFAULT_BASE_URL
|
|
16
|
+
from .types import GitContent
|
|
17
|
+
from .exceptions import EncryptionError, RepositoryNotFoundError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _should_decrypt_file(filename: str, decryption_exclude: Optional[List[str]] = None) -> bool:
|
|
21
|
+
"""
|
|
22
|
+
Check if a file should be decrypted based on exclude patterns.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
filename: Name of the file
|
|
26
|
+
decryption_exclude: List of patterns to exclude from decryption
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
True if the file should be decrypted
|
|
30
|
+
"""
|
|
31
|
+
import fnmatch
|
|
32
|
+
|
|
33
|
+
if not decryption_exclude:
|
|
34
|
+
return True
|
|
35
|
+
|
|
36
|
+
for pattern in decryption_exclude:
|
|
37
|
+
if fnmatch.fnmatch(filename, pattern):
|
|
38
|
+
return False
|
|
39
|
+
|
|
40
|
+
return True
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _match_pattern(name: str, pattern: str) -> bool:
|
|
44
|
+
"""
|
|
45
|
+
Match a filename against a pattern.
|
|
46
|
+
|
|
47
|
+
Supports wildcards:
|
|
48
|
+
- * matches any characters
|
|
49
|
+
- *.ext matches files with extension
|
|
50
|
+
- prefix* matches files starting with prefix
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
name: Filename to match
|
|
54
|
+
pattern: Pattern to match against
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
True if the name matches the pattern
|
|
58
|
+
"""
|
|
59
|
+
return fnmatch.fnmatch(name, pattern)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _should_download_file(
|
|
63
|
+
file_path: str,
|
|
64
|
+
allow_patterns: Optional[List[str]] = None,
|
|
65
|
+
ignore_patterns: Optional[List[str]] = None,
|
|
66
|
+
) -> bool:
|
|
67
|
+
"""
|
|
68
|
+
Determine if a file should be downloaded based on patterns.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
file_path: Path of the file
|
|
72
|
+
allow_patterns: List of patterns to allow (if None, allow all)
|
|
73
|
+
ignore_patterns: List of patterns to ignore
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
True if the file should be downloaded
|
|
77
|
+
"""
|
|
78
|
+
filename = os.path.basename(file_path)
|
|
79
|
+
|
|
80
|
+
# Check ignore patterns first
|
|
81
|
+
if ignore_patterns:
|
|
82
|
+
for pattern in ignore_patterns:
|
|
83
|
+
if _match_pattern(filename, pattern) or _match_pattern(file_path, pattern):
|
|
84
|
+
return False
|
|
85
|
+
|
|
86
|
+
# If no allow patterns, allow all (except ignored)
|
|
87
|
+
if not allow_patterns:
|
|
88
|
+
return True
|
|
89
|
+
|
|
90
|
+
# Check allow patterns
|
|
91
|
+
for pattern in allow_patterns:
|
|
92
|
+
if _match_pattern(filename, pattern) or _match_pattern(file_path, pattern):
|
|
93
|
+
return True
|
|
94
|
+
|
|
95
|
+
return False
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _count_files_to_download(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> int:
    """
    Count the files under a repository path that pass the pattern filters.

    The listing for ``path`` is fetched from the client, and every ``dir``
    entry is descended into with a further listing request.

    Args:
        client: Hub client instance
        organization: Organization name
        repo_type: Repository type
        repo_name: Repository name
        branch: Branch name
        path: Repository path whose contents are counted
        allow_patterns: Patterns to allow
        ignore_patterns: Patterns to ignore

    Returns:
        Number of files at or below ``path`` that would be downloaded
    """
    listing = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    # Nothing listed under this path: nothing to count.
    if not listing.entries:
        return 0

    total = 0
    for item in listing.entries:
        if item.type == "dir":
            # Directories contribute the count of their own subtree.
            total += _count_files_to_download(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=item.path,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
        elif item.type == "file" and _should_download_file(
            item.path, allow_patterns, ignore_patterns
        ):
            total += 1

    return total
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _resolve_decryption_algorithm(
    file_path: str,
    encryption_metadata,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
):
    """
    Validate the decryption inputs for one file and return the algorithm to use.

    Args:
        file_path: Repository path of the file about to be downloaded
        encryption_metadata: Encryption metadata object, or None. Its ``files``
            entries are compared by ``path`` and carry an ``algorithm``.
        decryption_key: Key supplied by the caller
        decryption_algorithm: Algorithm name supplied by the caller

    Returns:
        None when the file has no entry in the metadata, otherwise the
        ``EncryptionAlgorithm`` member built from ``decryption_algorithm``

    Raises:
        EncryptionError: If the file is listed in the metadata and the key is
            missing, the algorithm is missing, the algorithm differs from the
            recorded one, or the name is not a valid ``EncryptionAlgorithm``
    """
    if not (encryption_metadata and encryption_metadata.files):
        return None

    for file_meta in encryption_metadata.files:
        if file_meta.path == file_path:
            file_encryption_algorithm = file_meta.algorithm
            break
    else:
        # No metadata entry for this path: the file is not treated as encrypted.
        return None

    if not decryption_key:
        raise EncryptionError(
            f"File '{file_path}' is encrypted, but no decryption_key was provided. "
            "Please provide a decryption_key parameter."
        )
    if not decryption_algorithm:
        raise EncryptionError(
            f"File '{file_path}' is encrypted, but no decryption_algorithm was provided. "
            f"The file was encrypted with '{file_encryption_algorithm}'. "
            "Please provide a decryption_algorithm parameter."
        )
    if decryption_algorithm != file_encryption_algorithm:
        raise EncryptionError(
            f"File '{file_path}' is encrypted with '{file_encryption_algorithm}', "
            f"but decryption_algorithm '{decryption_algorithm}' was provided. "
            "Please use the correct decryption algorithm."
        )

    # Imported lazily so the encryption module is only loaded when a file
    # actually needs decrypting.
    from .encryption import EncryptionAlgorithm

    try:
        return EncryptionAlgorithm(decryption_algorithm)
    except ValueError as err:
        raise EncryptionError(
            f"Invalid decryption algorithm: {decryption_algorithm}. "
            f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
        ) from err


def _download_repository_recursively(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    local_dir: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
    verbose: bool = True,
    progress_bar = None,
    encryption_metadata = None,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> None:
    """
    Recursively download repository contents.

    Files that pass the allow/ignore filters are downloaded to
    ``local_dir`` joined with their repository path. Files listed in
    ``encryption_metadata`` are decrypted in place after download. Their
    decryption parameters are validated before the transfer starts, so a
    missing or wrong key/algorithm fails without downloading that file.

    Args:
        client: Hub client instance
        organization: Organization name
        repo_type: Repository type
        repo_name: Repository name
        branch: Branch name
        path: Current path in the repository
        local_dir: Local directory to save files
        allow_patterns: Patterns to allow
        ignore_patterns: Patterns to ignore
        verbose: Print progress messages (only when no progress bar is given)
        progress_bar: Optional tqdm progress bar for overall progress
        encryption_metadata: Encryption metadata for the repository, or None
        decryption_key: Key to decrypt files
        decryption_algorithm: Algorithm to use for decryption

    Raises:
        EncryptionError: If an encrypted file is encountered and the
            decryption key or algorithm is missing, mismatched or invalid
    """
    content = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    # Text messages are printed only when no overall progress bar is in use.
    announce = verbose and progress_bar is None

    for entry in content.entries or ():
        if entry.type == "file":
            # Check whether this file passes the allow/ignore filters.
            if not _should_download_file(entry.path, allow_patterns, ignore_patterns):
                if announce:
                    print(f"Skipping file: {entry.path}")
                continue

            if announce:
                print(f"Downloading file: {entry.path}")

            # Fail fast: validate decryption inputs before transferring data.
            algorithm = _resolve_decryption_algorithm(
                entry.path,
                encryption_metadata,
                decryption_key,
                decryption_algorithm,
            )

            # NOTE(review): entry.path comes from the listing response and is
            # joined into local_dir unchecked - confirm it cannot contain
            # components such as ".." that escape local_dir.
            local_path = os.path.join(local_dir, entry.path)

            # Update progress bar description if available
            if progress_bar is not None:
                progress_bar.set_description(f"Downloading {entry.path}")

            client.download_file(
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                file_path=entry.path,
                local_path=local_path,
                show_progress=progress_bar is None,  # Show individual progress only if no overall progress
            )

            if algorithm is not None:
                if announce:
                    print(f"Decrypting file: {entry.path}")
                from .encryption import decrypt_file as decrypt_file_func

                decrypt_file_func(Path(local_path), decryption_key, algorithm)

            if progress_bar is not None:
                progress_bar.update(1)

        elif entry.type == "dir":
            if announce:
                print(f"Entering directory: {entry.path}")
            # Recurse into the sub-directory.
            _download_repository_recursively(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=entry.path,
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                verbose=verbose,
                progress_bar=progress_bar,
                encryption_metadata=encryption_metadata,
                decryption_key=decryption_key,
                decryption_algorithm=decryption_algorithm,
            )
        else:
            if announce:
                print(f"Skipping {entry.type}: {entry.path}")
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def moha_hub_download(
    repo_id: str,
    filename: str,
    *,
    repo_type: str = "models",
    revision: Optional[str] = None,
    local_dir: Optional[Union[str, Path]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    show_progress: bool = True,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> str:
    """
    Download a single file from a repository.

    Similar to huggingface_hub.hf_hub_download().

    If the repository's encryption metadata lists the file, the decryption
    parameters are validated before the download and the file is decrypted
    in place afterwards.

    Args:
        repo_id: Repository ID in the format "organization/repo_name"
        filename: Path to the file in the repository
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from (default: the
            repository's default branch, as reported by the client)
        local_dir: Directory to save the file under; when omitted the file
            is written to ``filename`` relative to the current directory
        base_url: Base URL of the Hub API, forwarded to HubClient
            (documented default: from MOHA_ENDPOINT env var)
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        show_progress: Whether to show download progress bar
        decryption_key: Key to decrypt the file if it is encrypted (string for symmetric, PEM for asymmetric)
        decryption_algorithm: Decryption algorithm to use. There is no
            default: it is required when the file is encrypted and must
            equal the algorithm recorded for the file.
            - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
            - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)

    Returns:
        Path to the downloaded file

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name"
        EncryptionError: If the file is encrypted and the decryption key or
            algorithm is missing, mismatched or invalid

    Example:
        >>> file_path = moha_hub_download(
        ...     repo_id="demo/demo",
        ...     filename="data/config.yaml",
        ...     username="your-username",
        ...     password="your-password",
        ... )

        >>> # Download from encrypted repository
        >>> file_path = moha_hub_download(
        ...     repo_id="demo/encrypted-model",
        ...     filename="model.bin",
        ...     decryption_key="my-secret-key",
        ...     decryption_algorithm="aes-256-cbc",
        ...     token="your-token",
        ... )
    """
    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
    organization, repo_name = parts
    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )
    # Fall back to the repository's default branch.
    if revision is None:
        revision = client.get_default_branch(organization, repo_type, repo_name)

    # Fetch the encryption metadata for this revision.
    encryption_metadata = client.get_moha_encryption(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        reference=revision,
    )

    # Look the file up in the encryption metadata.
    file_is_encrypted = False
    file_encryption_algorithm = None
    if encryption_metadata and encryption_metadata.files:
        for file_meta in encryption_metadata.files:
            if file_meta.path == filename:
                file_is_encrypted = True
                file_encryption_algorithm = file_meta.algorithm
                break

    # If the file is encrypted, validate every decryption input up front so
    # that nothing is downloaded when decryption cannot succeed.
    algorithm = None
    if file_is_encrypted:
        if not decryption_key:
            raise EncryptionError(
                f"File '{filename}' is encrypted, but no decryption_key was provided. "
                "Please provide a decryption_key parameter."
            )

        if not decryption_algorithm:
            raise EncryptionError(
                f"File '{filename}' is encrypted, but no decryption_algorithm was provided. "
                f"The file was encrypted with '{file_encryption_algorithm}'. "
                "Please provide a decryption_algorithm parameter."
            )

        if decryption_algorithm != file_encryption_algorithm:
            raise EncryptionError(
                f"File '{filename}' is encrypted with '{file_encryption_algorithm}', "
                f"but decryption_algorithm '{decryption_algorithm}' was provided. "
                "Please use the correct decryption algorithm."
            )

        # Imported lazily so the encryption module is only loaded when needed.
        from .encryption import EncryptionAlgorithm

        try:
            algorithm = EncryptionAlgorithm(decryption_algorithm)
        except ValueError as err:
            raise EncryptionError(
                f"Invalid decryption algorithm: {decryption_algorithm}. "
                f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
            ) from err

    if local_dir:
        local_path = os.path.join(local_dir, filename)
    else:
        local_path = filename

    # Download the file.
    client.download_file(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=revision,
        file_path=filename,
        local_path=local_path,
        show_progress=show_progress,
    )

    # Decrypt the downloaded file in place.
    if file_is_encrypted:
        from .encryption import decrypt_file as decrypt_file_func

        decrypt_file_func(Path(local_path), decryption_key, algorithm)

    return local_path
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def snapshot_download(
    repo_id: str,
    repo_type: str = "models",
    revision: Optional[str] = None,
    local_dir: Optional[Union[str, Path]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    verbose: bool = True,
    show_progress: bool = True,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> str:
    """
    Download an entire repository snapshot.

    Similar to huggingface_hub.snapshot_download().

    Args:
        repo_id: Repository ID in the format "organization/repo_name"
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from (default: the
            repository's default branch, as reported by the client)
        local_dir: Directory to save files (default:
            "./downloads/{organization}_{repo_type}_{repo_name}")
        allow_patterns: Pattern or list of patterns to allow (e.g., "*.yaml", "*.yml")
        ignore_patterns: Pattern or list of patterns to ignore (e.g., ".git*")
        base_url: Base URL of the Hub API, forwarded to HubClient
            (documented default: from MOHA_ENDPOINT env var)
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        verbose: Print progress messages
        show_progress: Whether to show overall progress bar (requires tqdm;
            without it, text messages are used instead)
        decryption_key: Key to decrypt files if repository is encrypted (string for symmetric, PEM for asymmetric)
        decryption_algorithm: Decryption algorithm to use. There is no
            default: it is required when an encrypted file is downloaded
            and must equal the algorithm recorded for that file.
            - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
            - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)

    Returns:
        Path to the downloaded repository

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name"
        EncryptionError: Propagated from the recursive download when an
            encrypted file cannot be decrypted with the given parameters

    Example:
        >>> repo_path = snapshot_download(
        ...     repo_id="demo/demo",
        ...     repo_type="models",
        ...     allow_patterns=["*.yaml", "*.yml"],
        ...     ignore_patterns=[".git*"],
        ...     username="your-username",
        ...     password="your-password",
        ... )

        >>> # Download from encrypted repository
        >>> repo_path = snapshot_download(
        ...     repo_id="demo/encrypted-model",
        ...     repo_type="models",
        ...     decryption_key="my-secret-key",
        ...     decryption_algorithm="aes-256-cbc",
        ...     token="your-token",
        ... )
    """
    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts

    # Accept a single pattern string as shorthand for a one-element list.
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )
    if revision is None:
        revision = client.get_default_branch(organization, repo_type, repo_name)
    encryption_metadata = client.get_moha_encryption(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        reference=revision,
    )

    if encryption_metadata and encryption_metadata.files:
        if verbose:
            print(f"Repository has {len(encryption_metadata.files)} encrypted file(s). Files will be decrypted after download.")

    # Determine local directory
    if local_dir:
        download_dir = str(local_dir)
    else:
        # Default to downloads directory
        download_dir = f"./downloads/{organization}_{repo_type}_{repo_name}"

    # A progress bar can only be shown when tqdm imported successfully; the
    # text header/footer are used whenever no bar is available, matching the
    # per-file messages, which are printed whenever no bar exists.
    bar_available = show_progress and tqdm is not None

    if verbose and not bar_available:
        print(f"Downloading repository: {repo_id}")
        print(f"Repository type: {repo_type}")
        print(f"Revision: {revision}")
        print(f"Destination: {download_dir}")

    progress_bar = None
    if bar_available:
        if verbose:
            print("Fetching repository info...")
        # Count the files to download so the bar has a total.
        total_files = _count_files_to_download(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

        if total_files > 0:
            progress_bar = tqdm(
                total=total_files,
                unit='file',
                desc=f"Downloading {repo_id}",
                leave=True,
            )

    # Download by walking the content listing recursively (not via git).
    try:
        _download_repository_recursively(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            local_dir=download_dir,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            verbose=verbose,
            progress_bar=progress_bar,
            encryption_metadata=encryption_metadata,
            decryption_key=decryption_key,
            decryption_algorithm=decryption_algorithm,
        )
    finally:
        # Always release the bar, even when a download or decryption fails.
        if progress_bar is not None:
            progress_bar.close()

    if verbose and not bar_available:
        print(f"Download completed to: {download_dir}")

    # Add download count
    try:
        client.add_download_count(
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
        )
    except Exception as e:
        # Don't fail the download if adding count fails
        if verbose:
            print(f"Warning: Failed to add download count: {e}")

    return download_dir
|
|
598
|
+
|