xiaoshiai-hub 0.1.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xiaoshiai_hub/__init__.py +2 -27
- xiaoshiai_hub/client.py +11 -58
- xiaoshiai_hub/download.py +20 -199
- xiaoshiai_hub/exceptions.py +1 -5
- xiaoshiai_hub/types.py +0 -16
- xiaoshiai_hub/upload.py +477 -711
- xiaoshiai_hub-1.0.1.dist-info/METADATA +473 -0
- xiaoshiai_hub-1.0.1.dist-info/RECORD +11 -0
- xiaoshiai_hub/encryption.py +0 -777
- xiaoshiai_hub-0.1.3.dist-info/METADATA +0 -560
- xiaoshiai_hub-0.1.3.dist-info/RECORD +0 -12
- {xiaoshiai_hub-0.1.3.dist-info → xiaoshiai_hub-1.0.1.dist-info}/WHEEL +0 -0
- {xiaoshiai_hub-0.1.3.dist-info → xiaoshiai_hub-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {xiaoshiai_hub-0.1.3.dist-info → xiaoshiai_hub-1.0.1.dist-info}/top_level.txt +0 -0
xiaoshiai_hub/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
XiaoShi AI Hub Python SDK
|
|
3
3
|
|
|
4
4
|
A Python library for interacting with XiaoShi AI Hub repositories.
|
|
5
|
+
Encryption is handled by xpai-enc package (https://github.com/poxiaoyun/xpai-enc).
|
|
5
6
|
"""
|
|
6
7
|
|
|
7
8
|
from .client import HubClient, DEFAULT_BASE_URL
|
|
@@ -14,7 +15,6 @@ from .exceptions import (
|
|
|
14
15
|
RepositoryNotFoundError,
|
|
15
16
|
FileNotFoundError,
|
|
16
17
|
AuthenticationError,
|
|
17
|
-
EncryptionError,
|
|
18
18
|
)
|
|
19
19
|
from .types import (
|
|
20
20
|
Repository,
|
|
@@ -30,54 +30,29 @@ try:
|
|
|
30
30
|
upload_folder,
|
|
31
31
|
UploadError,
|
|
32
32
|
)
|
|
33
|
-
_upload_available = True
|
|
34
33
|
except ImportError:
|
|
35
|
-
_upload_available = False
|
|
36
34
|
upload_file = None
|
|
37
35
|
upload_folder = None
|
|
38
36
|
UploadError = None
|
|
39
37
|
|
|
40
|
-
|
|
41
|
-
try:
|
|
42
|
-
from .encryption import (
|
|
43
|
-
EncryptionAlgorithm,
|
|
44
|
-
encrypt_file,
|
|
45
|
-
decrypt_file,
|
|
46
|
-
)
|
|
47
|
-
_encryption_available = True
|
|
48
|
-
except ImportError:
|
|
49
|
-
_encryption_available = False
|
|
50
|
-
EncryptionAlgorithm = None
|
|
51
|
-
encrypt_file = None
|
|
52
|
-
decrypt_file = None
|
|
53
|
-
|
|
54
|
-
__version__ = "0.1.0"
|
|
38
|
+
__version__ = "1.0.1"
|
|
55
39
|
|
|
56
40
|
__all__ = [
|
|
57
41
|
# Client
|
|
58
42
|
"HubClient",
|
|
59
43
|
"DEFAULT_BASE_URL",
|
|
60
|
-
# Download
|
|
61
44
|
"moha_hub_download",
|
|
62
45
|
"snapshot_download",
|
|
63
|
-
# Upload
|
|
64
46
|
"upload_file",
|
|
65
47
|
"upload_folder",
|
|
66
|
-
# Exceptions
|
|
67
48
|
"HubException",
|
|
68
49
|
"RepositoryNotFoundError",
|
|
69
50
|
"FileNotFoundError",
|
|
70
51
|
"AuthenticationError",
|
|
71
|
-
"EncryptionError",
|
|
72
52
|
"UploadError",
|
|
73
|
-
# Types
|
|
74
53
|
"Repository",
|
|
75
54
|
"Ref",
|
|
76
55
|
"GitContent",
|
|
77
56
|
"Commit",
|
|
78
|
-
# Encryption
|
|
79
|
-
"EncryptionAlgorithm",
|
|
80
|
-
"encrypt_file",
|
|
81
|
-
"decrypt_file",
|
|
82
57
|
]
|
|
83
58
|
|
xiaoshiai_hub/client.py
CHANGED
|
@@ -19,13 +19,13 @@ from .exceptions import (
|
|
|
19
19
|
HTTPError,
|
|
20
20
|
RepositoryNotFoundError,
|
|
21
21
|
)
|
|
22
|
-
from .types import
|
|
22
|
+
from .types import Repository, Ref, GitContent
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
# 默认基础 URL,可通过环境变量 MOHA_ENDPOINT 覆盖
|
|
26
26
|
DEFAULT_BASE_URL = os.environ.get(
|
|
27
27
|
"MOHA_ENDPOINT",
|
|
28
|
-
"https://rune-api.develop.xiaoshiai.cn/
|
|
28
|
+
"https://rune-api.develop.xiaoshiai.cn/"
|
|
29
29
|
)
|
|
30
30
|
|
|
31
31
|
|
|
@@ -85,60 +85,6 @@ class HubClient:
|
|
|
85
85
|
return response
|
|
86
86
|
except requests.RequestException as e:
|
|
87
87
|
raise HTTPError(f"Request failed: {str(e)}")
|
|
88
|
-
|
|
89
|
-
def get_moha_encryption(
|
|
90
|
-
self,
|
|
91
|
-
organization: str,
|
|
92
|
-
repo_type: str,
|
|
93
|
-
repo_name: str,
|
|
94
|
-
reference: str,
|
|
95
|
-
) -> Optional[EncryptionMetadata]:
|
|
96
|
-
"""
|
|
97
|
-
Get .moha_encryption file content from the repository.
|
|
98
|
-
|
|
99
|
-
Args:
|
|
100
|
-
organization: Organization name
|
|
101
|
-
repo_type: Repository type ("models" or "datasets")
|
|
102
|
-
repo_name: Repository name
|
|
103
|
-
reference: Branch/tag/commit reference
|
|
104
|
-
|
|
105
|
-
Returns:
|
|
106
|
-
Encryption metadata if the repository has encrypted files, None otherwise
|
|
107
|
-
|
|
108
|
-
Raises:
|
|
109
|
-
AuthenticationError: If authentication fails
|
|
110
|
-
RepositoryNotFoundError: If the repository is not found
|
|
111
|
-
HTTPError: If the request fails
|
|
112
|
-
"""
|
|
113
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/encryption/{reference}"
|
|
114
|
-
response = self._make_request("GET", url)
|
|
115
|
-
if response.status_code == 204 or not response.content:
|
|
116
|
-
return None
|
|
117
|
-
try:
|
|
118
|
-
data = response.json()
|
|
119
|
-
except ValueError:
|
|
120
|
-
return None
|
|
121
|
-
if not data:
|
|
122
|
-
return None
|
|
123
|
-
|
|
124
|
-
from datetime import datetime
|
|
125
|
-
files = None
|
|
126
|
-
if 'files' in data and data['files']:
|
|
127
|
-
from .types import FileEncryptionMetadata
|
|
128
|
-
files = [
|
|
129
|
-
FileEncryptionMetadata(
|
|
130
|
-
path=f.get('path', ''),
|
|
131
|
-
algorithm=f.get('algorithm', ''),
|
|
132
|
-
encryptedHash=f.get('encryptedHash', ''),
|
|
133
|
-
encryptedSize=f.get('encryptedSize', 0),
|
|
134
|
-
)
|
|
135
|
-
for f in data['files']
|
|
136
|
-
]
|
|
137
|
-
return EncryptionMetadata(
|
|
138
|
-
version=data.get('version', '1.0'),
|
|
139
|
-
createAt=datetime.fromisoformat(data['createAt'].replace('Z', '+00:00')) if 'createAt' in data else datetime.now(),
|
|
140
|
-
files=files,
|
|
141
|
-
)
|
|
142
88
|
|
|
143
89
|
def get_repository_info(
|
|
144
90
|
self,
|
|
@@ -351,9 +297,16 @@ class HubClient:
|
|
|
351
297
|
with open(local_path, 'wb') as f:
|
|
352
298
|
for chunk in response.iter_content(chunk_size=8192):
|
|
353
299
|
if chunk:
|
|
354
|
-
|
|
300
|
+
# Ensure chunk is bytes for type safety
|
|
301
|
+
if isinstance(chunk, str):
|
|
302
|
+
chunk_bytes = chunk.encode('utf-8')
|
|
303
|
+
elif isinstance(chunk, bytes):
|
|
304
|
+
chunk_bytes = chunk
|
|
305
|
+
else:
|
|
306
|
+
chunk_bytes = bytes(chunk)
|
|
307
|
+
f.write(chunk_bytes)
|
|
355
308
|
if progress_bar is not None:
|
|
356
|
-
progress_bar.update(len(
|
|
309
|
+
progress_bar.update(len(chunk_bytes))
|
|
357
310
|
finally:
|
|
358
311
|
if progress_bar is not None:
|
|
359
312
|
progress_bar.close()
|
xiaoshiai_hub/download.py
CHANGED
|
@@ -14,30 +14,7 @@ except ImportError:
|
|
|
14
14
|
|
|
15
15
|
from .client import HubClient, DEFAULT_BASE_URL
|
|
16
16
|
from .types import GitContent
|
|
17
|
-
from .exceptions import
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def _should_decrypt_file(filename: str, decryption_exclude: Optional[List[str]] = None) -> bool:
|
|
21
|
-
"""
|
|
22
|
-
Check if a file should be decrypted based on exclude patterns.
|
|
23
|
-
|
|
24
|
-
Args:
|
|
25
|
-
filename: Name of the file
|
|
26
|
-
decryption_exclude: List of patterns to exclude from decryption
|
|
27
|
-
|
|
28
|
-
Returns:
|
|
29
|
-
True if the file should be decrypted
|
|
30
|
-
"""
|
|
31
|
-
import fnmatch
|
|
32
|
-
|
|
33
|
-
if not decryption_exclude:
|
|
34
|
-
return True
|
|
35
|
-
|
|
36
|
-
for pattern in decryption_exclude:
|
|
37
|
-
if fnmatch.fnmatch(filename, pattern):
|
|
38
|
-
return False
|
|
39
|
-
|
|
40
|
-
return True
|
|
17
|
+
from .exceptions import RepositoryNotFoundError
|
|
41
18
|
|
|
42
19
|
|
|
43
20
|
def _match_pattern(name: str, pattern: str) -> bool:
|
|
@@ -160,11 +137,7 @@ def _download_repository_recursively(
|
|
|
160
137
|
local_dir: str,
|
|
161
138
|
allow_patterns: Optional[List[str]] = None,
|
|
162
139
|
ignore_patterns: Optional[List[str]] = None,
|
|
163
|
-
verbose: bool = True,
|
|
164
140
|
progress_bar = None,
|
|
165
|
-
encryption_metadata = None,
|
|
166
|
-
decryption_key: Optional[Union[str, bytes]] = None,
|
|
167
|
-
decryption_algorithm: Optional[str] = None,
|
|
168
141
|
) -> None:
|
|
169
142
|
"""
|
|
170
143
|
Recursively download repository contents.
|
|
@@ -179,12 +152,7 @@ def _download_repository_recursively(
|
|
|
179
152
|
local_dir: Local directory to save files
|
|
180
153
|
allow_patterns: Patterns to allow
|
|
181
154
|
ignore_patterns: Patterns to ignore
|
|
182
|
-
verbose: Print progress messages
|
|
183
155
|
progress_bar: Optional tqdm progress bar for overall progress
|
|
184
|
-
encryption_metadata: Encryption metadata from .moha_encryption file
|
|
185
|
-
decryption_key: Key to decrypt files
|
|
186
|
-
decryption_exclude: Patterns to exclude from decryption
|
|
187
|
-
decryption_algorithm: Algorithm to use for decryption
|
|
188
156
|
"""
|
|
189
157
|
content = client.get_repository_content(
|
|
190
158
|
organization=organization,
|
|
@@ -200,9 +168,7 @@ def _download_repository_recursively(
|
|
|
200
168
|
if entry.type == "file":
|
|
201
169
|
# 检查文件是否应该被下载
|
|
202
170
|
if _should_download_file(entry.path, allow_patterns, ignore_patterns):
|
|
203
|
-
|
|
204
|
-
print(f"Downloading file: {entry.path}")
|
|
205
|
-
|
|
171
|
+
print(f"Downloading file: {entry.path}")
|
|
206
172
|
local_path = os.path.join(local_dir, entry.path)
|
|
207
173
|
|
|
208
174
|
# Update progress bar description if available
|
|
@@ -218,54 +184,14 @@ def _download_repository_recursively(
|
|
|
218
184
|
local_path=local_path,
|
|
219
185
|
show_progress=progress_bar is None, # Show individual progress only if no overall progress
|
|
220
186
|
)
|
|
221
|
-
|
|
222
|
-
file_encryption_algorithm = None
|
|
223
|
-
if encryption_metadata and encryption_metadata.files:
|
|
224
|
-
for file_meta in encryption_metadata.files:
|
|
225
|
-
if file_meta.path == entry.path:
|
|
226
|
-
file_is_encrypted = True
|
|
227
|
-
file_encryption_algorithm = file_meta.algorithm
|
|
228
|
-
break
|
|
229
|
-
|
|
230
|
-
if file_is_encrypted:
|
|
231
|
-
if not decryption_key:
|
|
232
|
-
raise EncryptionError(
|
|
233
|
-
f"File '{entry.path}' is encrypted, but no decryption_key was provided. "
|
|
234
|
-
"Please provide a decryption_key parameter."
|
|
235
|
-
)
|
|
236
|
-
if not decryption_algorithm:
|
|
237
|
-
raise EncryptionError(
|
|
238
|
-
f"File '{entry.path}' is encrypted, but no decryption_algorithm was provided. "
|
|
239
|
-
f"The file was encrypted with '{file_encryption_algorithm}'. "
|
|
240
|
-
"Please provide a decryption_algorithm parameter."
|
|
241
|
-
)
|
|
242
|
-
if decryption_algorithm != file_encryption_algorithm:
|
|
243
|
-
raise EncryptionError(
|
|
244
|
-
f"File '{entry.path}' is encrypted with '{file_encryption_algorithm}', "
|
|
245
|
-
f"but decryption_algorithm '{decryption_algorithm}' was provided. "
|
|
246
|
-
"Please use the correct decryption algorithm."
|
|
247
|
-
)
|
|
248
|
-
if verbose and progress_bar is None:
|
|
249
|
-
print(f"Decrypting file: {entry.path}")
|
|
250
|
-
from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
|
|
251
|
-
try:
|
|
252
|
-
algorithm = EncryptionAlgorithm(decryption_algorithm)
|
|
253
|
-
except ValueError:
|
|
254
|
-
raise EncryptionError(
|
|
255
|
-
f"Invalid decryption algorithm: {decryption_algorithm}. "
|
|
256
|
-
f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
|
|
257
|
-
)
|
|
258
|
-
decrypt_file_func(Path(local_path), decryption_key, algorithm)
|
|
259
|
-
|
|
187
|
+
|
|
260
188
|
if progress_bar is not None:
|
|
261
189
|
progress_bar.update(1)
|
|
262
190
|
else:
|
|
263
|
-
|
|
264
|
-
print(f"Skipping file: {entry.path}")
|
|
191
|
+
print(f"Skipping file: {entry.path}")
|
|
265
192
|
|
|
266
193
|
elif entry.type == "dir":
|
|
267
|
-
|
|
268
|
-
print(f"Entering directory: {entry.path}")
|
|
194
|
+
print(f"Entering directory: {entry.path}")
|
|
269
195
|
# 递归下载
|
|
270
196
|
_download_repository_recursively(
|
|
271
197
|
client=client,
|
|
@@ -277,15 +203,10 @@ def _download_repository_recursively(
|
|
|
277
203
|
local_dir=local_dir,
|
|
278
204
|
allow_patterns=allow_patterns,
|
|
279
205
|
ignore_patterns=ignore_patterns,
|
|
280
|
-
verbose=verbose,
|
|
281
206
|
progress_bar=progress_bar,
|
|
282
|
-
encryption_metadata=encryption_metadata,
|
|
283
|
-
decryption_key=decryption_key,
|
|
284
|
-
decryption_algorithm=decryption_algorithm,
|
|
285
207
|
)
|
|
286
208
|
else:
|
|
287
|
-
|
|
288
|
-
print(f"Skipping {entry.type}: {entry.path}")
|
|
209
|
+
print(f"Skipping {entry.type}: {entry.path}")
|
|
289
210
|
|
|
290
211
|
|
|
291
212
|
def moha_hub_download(
|
|
@@ -300,49 +221,37 @@ def moha_hub_download(
|
|
|
300
221
|
password: Optional[str] = None,
|
|
301
222
|
token: Optional[str] = None,
|
|
302
223
|
show_progress: bool = True,
|
|
303
|
-
|
|
304
|
-
decryption_algorithm: Optional[str] = None,
|
|
305
|
-
) -> str:
|
|
224
|
+
) -> Union[str, bytes]:
|
|
306
225
|
"""
|
|
307
226
|
Download a single file from a repository.
|
|
308
227
|
|
|
309
228
|
Similar to huggingface_hub.hf_hub_download().
|
|
310
229
|
|
|
230
|
+
Note: This function does not support decryption. If you need to download encrypted files,
|
|
231
|
+
use xpai-enc CLI tool to decrypt them after download.
|
|
232
|
+
|
|
311
233
|
Args:
|
|
312
234
|
repo_id: Repository ID in the format "organization/repo_name"
|
|
313
235
|
filename: Path to the file in the repository
|
|
314
236
|
repo_type: Type of repository ("models" or "datasets")
|
|
315
237
|
revision: Branch/tag/commit to download from (default: main branch)
|
|
316
|
-
cache_dir: Directory to cache downloaded files
|
|
317
238
|
local_dir: Directory to save the file (if not using cache)
|
|
318
239
|
base_url: Base URL of the Hub API (default: from MOHA_ENDPOINT env var)
|
|
319
240
|
username: Username for authentication
|
|
320
241
|
password: Password for authentication
|
|
321
242
|
token: Token for authentication
|
|
322
243
|
show_progress: Whether to show download progress bar
|
|
323
|
-
decryption_key: Key to decrypt the file if repository is encrypted (string for symmetric, PEM for asymmetric)
|
|
324
|
-
decryption_algorithm: Decryption algorithm to use (default: 'aes-256-cbc')
|
|
325
|
-
- Symmetric: 'aes-256-cbc', 'aes-256-gcm'
|
|
326
|
-
- Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)
|
|
327
244
|
|
|
328
245
|
Returns:
|
|
329
|
-
|
|
246
|
+
File path (str) or file content (bytes) based on return_content parameter
|
|
330
247
|
|
|
331
248
|
Example:
|
|
332
|
-
>>> file_path =
|
|
249
|
+
>>> file_path = moha_hub_download(
|
|
333
250
|
... repo_id="demo/demo",
|
|
334
251
|
... filename="data/config.yaml",
|
|
335
252
|
... username="your-username",
|
|
336
253
|
... password="your-password",
|
|
337
254
|
... )
|
|
338
|
-
|
|
339
|
-
>>> # Download from encrypted repository
|
|
340
|
-
>>> file_path = hf_hub_download(
|
|
341
|
-
... repo_id="demo/encrypted-model",
|
|
342
|
-
... filename="model.bin",
|
|
343
|
-
... decryption_key="my-secret-key",
|
|
344
|
-
... token="your-token",
|
|
345
|
-
... )
|
|
346
255
|
"""
|
|
347
256
|
parts = repo_id.split('/')
|
|
348
257
|
if len(parts) != 2:
|
|
@@ -358,46 +267,6 @@ def moha_hub_download(
|
|
|
358
267
|
if revision is None:
|
|
359
268
|
revision = client.get_default_branch(organization, repo_type, repo_name)
|
|
360
269
|
|
|
361
|
-
# 获取加密元数据
|
|
362
|
-
encryption_metadata = client.get_moha_encryption(
|
|
363
|
-
organization=organization,
|
|
364
|
-
repo_type=repo_type,
|
|
365
|
-
repo_name=repo_name,
|
|
366
|
-
reference=revision,
|
|
367
|
-
)
|
|
368
|
-
|
|
369
|
-
# 文件加密标识
|
|
370
|
-
file_is_encrypted = False
|
|
371
|
-
file_encryption_algorithm = None
|
|
372
|
-
# 查询文件是否加密
|
|
373
|
-
if encryption_metadata and encryption_metadata.files:
|
|
374
|
-
for file_meta in encryption_metadata.files:
|
|
375
|
-
if file_meta.path == filename:
|
|
376
|
-
file_is_encrypted = True
|
|
377
|
-
file_encryption_algorithm = file_meta.algorithm
|
|
378
|
-
break
|
|
379
|
-
|
|
380
|
-
# 如果该文件加密了,检查解密参数
|
|
381
|
-
if file_is_encrypted:
|
|
382
|
-
if not decryption_key:
|
|
383
|
-
raise EncryptionError(
|
|
384
|
-
f"File '{filename}' is encrypted, but no decryption_key was provided. "
|
|
385
|
-
"Please provide a decryption_key parameter."
|
|
386
|
-
)
|
|
387
|
-
|
|
388
|
-
if not decryption_algorithm:
|
|
389
|
-
raise EncryptionError(
|
|
390
|
-
f"File '{filename}' is encrypted, but no decryption_algorithm was provided. "
|
|
391
|
-
f"The file was encrypted with '{file_encryption_algorithm}'. "
|
|
392
|
-
"Please provide a decryption_algorithm parameter."
|
|
393
|
-
)
|
|
394
|
-
|
|
395
|
-
if decryption_algorithm != file_encryption_algorithm:
|
|
396
|
-
raise EncryptionError(
|
|
397
|
-
f"File '{filename}' is encrypted with '{file_encryption_algorithm}', "
|
|
398
|
-
f"but decryption_algorithm '{decryption_algorithm}' was provided. "
|
|
399
|
-
"Please use the correct decryption algorithm."
|
|
400
|
-
)
|
|
401
270
|
if local_dir:
|
|
402
271
|
local_path = os.path.join(local_dir, filename)
|
|
403
272
|
else:
|
|
@@ -413,20 +282,6 @@ def moha_hub_download(
|
|
|
413
282
|
local_path=local_path,
|
|
414
283
|
show_progress=show_progress,
|
|
415
284
|
)
|
|
416
|
-
|
|
417
|
-
# 解密文件
|
|
418
|
-
if file_is_encrypted and decryption_key:
|
|
419
|
-
from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
|
|
420
|
-
try:
|
|
421
|
-
algorithm = EncryptionAlgorithm(decryption_algorithm)
|
|
422
|
-
except ValueError:
|
|
423
|
-
raise EncryptionError(
|
|
424
|
-
f"Invalid decryption algorithm: {decryption_algorithm}. "
|
|
425
|
-
f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
|
|
426
|
-
)
|
|
427
|
-
|
|
428
|
-
decrypt_file_func(Path(local_path), decryption_key, algorithm)
|
|
429
|
-
|
|
430
285
|
return local_path
|
|
431
286
|
|
|
432
287
|
|
|
@@ -441,16 +296,16 @@ def snapshot_download(
|
|
|
441
296
|
username: Optional[str] = None,
|
|
442
297
|
password: Optional[str] = None,
|
|
443
298
|
token: Optional[str] = None,
|
|
444
|
-
verbose: bool = True,
|
|
445
299
|
show_progress: bool = True,
|
|
446
|
-
decryption_key: Optional[Union[str, bytes]] = None,
|
|
447
|
-
decryption_algorithm: Optional[str] = None,
|
|
448
300
|
) -> str:
|
|
449
301
|
"""
|
|
450
302
|
Download an entire repository snapshot.
|
|
451
303
|
|
|
452
304
|
Similar to huggingface_hub.snapshot_download().
|
|
453
305
|
|
|
306
|
+
Note: This function does not support decryption. If you need to download encrypted repositories,
|
|
307
|
+
use xpai-enc CLI tool to decrypt them after download.
|
|
308
|
+
|
|
454
309
|
Args:
|
|
455
310
|
repo_id: Repository ID in the format "organization/repo_name"
|
|
456
311
|
repo_type: Type of repository ("models" or "datasets")
|
|
@@ -462,12 +317,7 @@ def snapshot_download(
|
|
|
462
317
|
username: Username for authentication
|
|
463
318
|
password: Password for authentication
|
|
464
319
|
token: Token for authentication
|
|
465
|
-
verbose: Print progress messages
|
|
466
320
|
show_progress: Whether to show overall progress bar
|
|
467
|
-
decryption_key: Key to decrypt files if repository is encrypted (string for symmetric, PEM for asymmetric)
|
|
468
|
-
decryption_algorithm: Decryption algorithm to use (default: 'aes-256-cbc')
|
|
469
|
-
- Symmetric: 'aes-256-cbc', 'aes-256-gcm'
|
|
470
|
-
- Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)
|
|
471
321
|
|
|
472
322
|
Returns:
|
|
473
323
|
Path to the downloaded repository
|
|
@@ -481,15 +331,6 @@ def snapshot_download(
|
|
|
481
331
|
... username="your-username",
|
|
482
332
|
... password="your-password",
|
|
483
333
|
... )
|
|
484
|
-
|
|
485
|
-
>>> # Download from encrypted repository
|
|
486
|
-
>>> repo_path = snapshot_download(
|
|
487
|
-
... repo_id="demo/encrypted-model",
|
|
488
|
-
... repo_type="models",
|
|
489
|
-
... decryption_key="my-secret-key",
|
|
490
|
-
... decryption_exclude=["README.md", "*.txt"], # Don't decrypt these files
|
|
491
|
-
... token="your-token",
|
|
492
|
-
... )
|
|
493
334
|
"""
|
|
494
335
|
parts = repo_id.split('/')
|
|
495
336
|
if len(parts) != 2:
|
|
@@ -509,16 +350,6 @@ def snapshot_download(
|
|
|
509
350
|
)
|
|
510
351
|
if revision is None:
|
|
511
352
|
revision = client.get_default_branch(organization, repo_type, repo_name)
|
|
512
|
-
encryption_metadata = client.get_moha_encryption(
|
|
513
|
-
organization=organization,
|
|
514
|
-
repo_type=repo_type,
|
|
515
|
-
repo_name=repo_name,
|
|
516
|
-
reference=revision,
|
|
517
|
-
)
|
|
518
|
-
|
|
519
|
-
if encryption_metadata and encryption_metadata.files:
|
|
520
|
-
if verbose:
|
|
521
|
-
print(f"Repository has {len(encryption_metadata.files)} encrypted file(s). Files will be decrypted after download.")
|
|
522
353
|
|
|
523
354
|
# Determine local directory
|
|
524
355
|
if local_dir:
|
|
@@ -527,16 +358,8 @@ def snapshot_download(
|
|
|
527
358
|
# Default to downloads directory
|
|
528
359
|
download_dir = f"./downloads/{organization}_{repo_type}_{repo_name}"
|
|
529
360
|
|
|
530
|
-
if verbose and not show_progress:
|
|
531
|
-
print(f"Downloading repository: {repo_id}")
|
|
532
|
-
print(f"Repository type: {repo_type}")
|
|
533
|
-
print(f"Revision: {revision}")
|
|
534
|
-
print(f"Destination: {download_dir}")
|
|
535
|
-
|
|
536
361
|
progress_bar = None
|
|
537
362
|
if show_progress and tqdm is not None:
|
|
538
|
-
if verbose:
|
|
539
|
-
print(f"Fetching repository info...")
|
|
540
363
|
# 计算需要下载的文件总数
|
|
541
364
|
total_files = _count_files_to_download(
|
|
542
365
|
client=client,
|
|
@@ -569,17 +392,13 @@ def snapshot_download(
|
|
|
569
392
|
local_dir=download_dir,
|
|
570
393
|
allow_patterns=allow_patterns,
|
|
571
394
|
ignore_patterns=ignore_patterns,
|
|
572
|
-
verbose=verbose,
|
|
573
395
|
progress_bar=progress_bar,
|
|
574
|
-
encryption_metadata=encryption_metadata,
|
|
575
|
-
decryption_key=decryption_key,
|
|
576
|
-
decryption_algorithm=decryption_algorithm,
|
|
577
396
|
)
|
|
578
397
|
finally:
|
|
579
398
|
if progress_bar is not None:
|
|
580
399
|
progress_bar.close()
|
|
581
400
|
|
|
582
|
-
if
|
|
401
|
+
if show_progress:
|
|
583
402
|
print(f"Download completed to: {download_dir}")
|
|
584
403
|
|
|
585
404
|
# Add download count
|
|
@@ -591,8 +410,10 @@ def snapshot_download(
|
|
|
591
410
|
)
|
|
592
411
|
except Exception as e:
|
|
593
412
|
# Don't fail the download if adding count fails
|
|
594
|
-
|
|
595
|
-
print(f"Warning: Failed to add download count: {e}")
|
|
413
|
+
print(f"Warning: Failed to add download count: {e}")
|
|
596
414
|
|
|
597
415
|
return download_dir
|
|
598
416
|
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
|
xiaoshiai_hub/exceptions.py
CHANGED
|
@@ -26,12 +26,8 @@ class AuthenticationError(HubException):
|
|
|
26
26
|
class HTTPError(HubException):
|
|
27
27
|
"""Raised when an HTTP error occurs."""
|
|
28
28
|
|
|
29
|
-
def __init__(self, message: str, status_code: int = None):
|
|
29
|
+
def __init__(self, message: str, status_code: int | None = None):
|
|
30
30
|
super().__init__(message)
|
|
31
31
|
self.status_code = status_code
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
class EncryptionError(HubException):
|
|
35
|
-
"""Raised when encryption is required but not provided or encryption fails."""
|
|
36
|
-
pass
|
|
37
|
-
|
xiaoshiai_hub/types.py
CHANGED
|
@@ -18,22 +18,6 @@ class Repository:
|
|
|
18
18
|
annotations: Dict[str, str] = field(default_factory=dict) # Repository annotations/metadata
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
@dataclass
|
|
22
|
-
class FileEncryptionMetadata:
|
|
23
|
-
"""Encryption metadata for a file."""
|
|
24
|
-
path: str
|
|
25
|
-
algorithm: str
|
|
26
|
-
encryptedHash: str
|
|
27
|
-
encryptedSize: int
|
|
28
|
-
|
|
29
|
-
@dataclass
|
|
30
|
-
class EncryptionMetadata:
|
|
31
|
-
"""Encryption metadata for a repository."""
|
|
32
|
-
version: str
|
|
33
|
-
createAt: datetime
|
|
34
|
-
files: Optional[List[FileEncryptionMetadata]] = None
|
|
35
|
-
|
|
36
|
-
|
|
37
21
|
|
|
38
22
|
@dataclass
|
|
39
23
|
class Signature:
|