xiaoshiai-hub 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xiaoshiai_hub/__init__.py +83 -0
- xiaoshiai_hub/client.py +360 -0
- xiaoshiai_hub/download.py +598 -0
- xiaoshiai_hub/encryption.py +777 -0
- xiaoshiai_hub/exceptions.py +37 -0
- xiaoshiai_hub/types.py +109 -0
- xiaoshiai_hub/upload.py +875 -0
- xiaoshiai_hub-0.1.1.dist-info/METADATA +560 -0
- xiaoshiai_hub-0.1.1.dist-info/RECORD +12 -0
- xiaoshiai_hub-0.1.1.dist-info/WHEEL +5 -0
- xiaoshiai_hub-0.1.1.dist-info/licenses/LICENSE +56 -0
- xiaoshiai_hub-0.1.1.dist-info/top_level.txt +1 -0
xiaoshiai_hub/upload.py
ADDED
|
@@ -0,0 +1,875 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Upload utilities for XiaoShi AI Hub SDK using GitPython
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import shutil
|
|
7
|
+
import json
|
|
8
|
+
import hashlib
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import List, Optional, Union, Callable, Dict, Any
|
|
11
|
+
from urllib.parse import urlparse, urlunparse
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
|
|
14
|
+
from xiaoshiai_hub.client import DEFAULT_BASE_URL
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from git import Repo, GitCommandError, InvalidGitRepositoryError
|
|
18
|
+
except ImportError:
|
|
19
|
+
raise ImportError(
|
|
20
|
+
"GitPython is required for upload functionality. "
|
|
21
|
+
"Install it with: pip install gitpython"
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from tqdm.auto import tqdm
|
|
26
|
+
except ImportError:
|
|
27
|
+
tqdm = None
|
|
28
|
+
|
|
29
|
+
from .exceptions import HubException, AuthenticationError, RepositoryNotFoundError, EncryptionError
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UploadError(HubException):
    """Signals that uploading content to a Hub repository did not succeed.

    Raised by ``upload_folder`` and ``upload_file`` when cloning the
    repository, pushing the new commit, or another Git operation fails.
    """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _build_git_url(
|
|
38
|
+
base_url: Optional[str],
|
|
39
|
+
organization: str,
|
|
40
|
+
repo_type: str,
|
|
41
|
+
repo_name: str,
|
|
42
|
+
username: Optional[str] = None,
|
|
43
|
+
password: Optional[str] = None,
|
|
44
|
+
token: Optional[str] = None,
|
|
45
|
+
) -> str:
|
|
46
|
+
"""
|
|
47
|
+
Build Git repository URL with authentication.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
base_url: Base URL of the Hub API
|
|
51
|
+
organization: Organization name
|
|
52
|
+
repo_type: Repository type ("models" or "datasets")
|
|
53
|
+
repo_name: Repository name
|
|
54
|
+
username: Username for authentication
|
|
55
|
+
password: Password for authentication
|
|
56
|
+
token: Token for authentication
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
Git repository URL with embedded credentials
|
|
60
|
+
"""
|
|
61
|
+
# Parse base URL to get the host
|
|
62
|
+
base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
|
|
63
|
+
parsed = urlparse(base_url)
|
|
64
|
+
host = parsed.netloc
|
|
65
|
+
scheme = parsed.scheme or 'https'
|
|
66
|
+
|
|
67
|
+
# Build repository path
|
|
68
|
+
repo_path = f"moha/{organization}/{repo_type}/{repo_name}.git"
|
|
69
|
+
|
|
70
|
+
# Add authentication to URL
|
|
71
|
+
if token:
|
|
72
|
+
# Use token as username with empty password
|
|
73
|
+
netloc = f"oauth2:{token}@{host}"
|
|
74
|
+
elif username and password:
|
|
75
|
+
netloc = f"{username}:{password}@{host}"
|
|
76
|
+
else:
|
|
77
|
+
netloc = host
|
|
78
|
+
|
|
79
|
+
# Construct full URL
|
|
80
|
+
git_url = urlunparse((scheme, netloc, repo_path, '', '', ''))
|
|
81
|
+
|
|
82
|
+
return git_url
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _calculate_file_hash(file_path: Path) -> str:
|
|
86
|
+
"""
|
|
87
|
+
Calculate SHA256 hash of a file.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
file_path: Path to the file
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
Hexadecimal hash string
|
|
94
|
+
"""
|
|
95
|
+
sha256_hash = hashlib.sha256()
|
|
96
|
+
with open(file_path, "rb") as f:
|
|
97
|
+
for byte_block in iter(lambda: f.read(4096), b""):
|
|
98
|
+
sha256_hash.update(byte_block)
|
|
99
|
+
return sha256_hash.hexdigest()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _should_encrypt_file(file_path: str, encryption_exclude: List[str]) -> bool:
|
|
103
|
+
"""
|
|
104
|
+
Check if a file should be encrypted based on exclude patterns.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
file_path: Relative path of the file (relative to folder_path)
|
|
108
|
+
encryption_exclude: List of patterns to exclude from encryption
|
|
109
|
+
|
|
110
|
+
Returns:
|
|
111
|
+
True if file should be encrypted, False otherwise
|
|
112
|
+
"""
|
|
113
|
+
import fnmatch
|
|
114
|
+
|
|
115
|
+
if not encryption_exclude:
|
|
116
|
+
return True
|
|
117
|
+
|
|
118
|
+
for pattern in encryption_exclude:
|
|
119
|
+
# Match against the full relative path
|
|
120
|
+
if fnmatch.fnmatch(file_path, pattern):
|
|
121
|
+
return False
|
|
122
|
+
|
|
123
|
+
return True
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _create_encryption_metadata(
|
|
127
|
+
encrypted_files: List[Dict[str, Any]],
|
|
128
|
+
algorithm: str,
|
|
129
|
+
version: str = "1.0"
|
|
130
|
+
) -> Dict[str, Any]:
|
|
131
|
+
"""
|
|
132
|
+
Create encryption metadata structure.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
encrypted_files: List of encrypted file information
|
|
136
|
+
algorithm: Encryption algorithm used
|
|
137
|
+
version: Metadata format version
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
Encryption metadata dictionary
|
|
141
|
+
"""
|
|
142
|
+
return {
|
|
143
|
+
"version": version,
|
|
144
|
+
"createAt": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
|
|
145
|
+
"files": encrypted_files
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _write_encryption_metadata(repo_path: Path, metadata: Dict[str, Any]) -> None:
|
|
150
|
+
"""
|
|
151
|
+
Write encryption metadata to .moha_encryption file.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
repo_path: Path to the repository
|
|
155
|
+
metadata: Encryption metadata dictionary
|
|
156
|
+
"""
|
|
157
|
+
metadata_file = repo_path / ".moha_encryption"
|
|
158
|
+
with open(metadata_file, 'w', encoding='utf-8') as f:
|
|
159
|
+
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
|
160
|
+
print(f"Encryption metadata written to .moha_encryption")
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _count_files_in_directory(directory: str, ignore_patterns: Optional[List[str]] = None) -> int:
|
|
164
|
+
"""
|
|
165
|
+
Count total number of files in a directory.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
directory: Directory path
|
|
169
|
+
ignore_patterns: List of patterns to ignore (e.g., ['.git', '__pycache__'])
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
Total number of files
|
|
173
|
+
"""
|
|
174
|
+
import fnmatch
|
|
175
|
+
|
|
176
|
+
count = 0
|
|
177
|
+
|
|
178
|
+
for root, dirs, files in os.walk(directory):
|
|
179
|
+
# Get relative path from directory
|
|
180
|
+
rel_root = os.path.relpath(root, directory)
|
|
181
|
+
|
|
182
|
+
if ignore_patterns:
|
|
183
|
+
dirs[:] = [d for d in dirs if not any(
|
|
184
|
+
fnmatch.fnmatch(d, pattern) for pattern in ignore_patterns
|
|
185
|
+
)]
|
|
186
|
+
|
|
187
|
+
if ignore_patterns:
|
|
188
|
+
filtered_files = []
|
|
189
|
+
for f in files:
|
|
190
|
+
# Construct relative file path
|
|
191
|
+
if rel_root == '.':
|
|
192
|
+
rel_file_path = f
|
|
193
|
+
else:
|
|
194
|
+
rel_file_path = os.path.join(rel_root, f)
|
|
195
|
+
|
|
196
|
+
# Check if file matches any ignore pattern
|
|
197
|
+
should_ignore = any(
|
|
198
|
+
fnmatch.fnmatch(rel_file_path, pattern) for pattern in ignore_patterns
|
|
199
|
+
)
|
|
200
|
+
if not should_ignore:
|
|
201
|
+
filtered_files.append(f)
|
|
202
|
+
files = filtered_files
|
|
203
|
+
|
|
204
|
+
count += len(files)
|
|
205
|
+
|
|
206
|
+
return count
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def upload_folder(
    folder_path: Union[str, Path],
    repo_id: str,
    repo_type: str = "models",
    revision: str = "main",
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    ignore_patterns: Optional[List[str]] = None,  # files to skip when uploading
    encryption_key: Optional[Union[str, bytes]] = None,
    encryption_exclude: Optional[List[str]] = None,  # files to leave unencrypted
    encryption_algorithm: Optional[str] = None,
    temp_dir: Optional[Union[str, Path]] = None,
    skip_lfs: Optional[bool] = True,
) -> str:
    """
    Upload a folder to a repository using Git.

    The target branch is shallow-cloned (``depth=1``) into a temporary
    directory. The folder's files are copied over the clone and encrypted if
    the repository requires it, then committed and pushed. Files that exist
    in the repository but not in ``folder_path`` are left untouched.

    In addition to ``ignore_patterns``, ``.git`` entries and the top-level
    ``.moha_encryption`` file of ``folder_path`` are always skipped. Copying
    them would overwrite the clone's own Git metadata or the encryption index
    that this function maintains.

    For encrypted repositories the ``.moha_encryption`` index is merged rather
    than rebuilt:
    - entries for files not touched by this upload are kept;
    - entries for re-uploaded files are replaced;
    - entries for files now excluded from encryption are dropped.

    Args:
        folder_path: Path to the folder to upload
        repo_id: Repository ID in the format "organization/repo_name"
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch to upload to (default: "main")
        commit_message: Commit message (default: "Upload folder from <folder name>")
        commit_description: Additional commit description, appended after a blank line
        base_url: Base URL of the Hub API
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication (preferred over username/password)
        ignore_patterns: fnmatch patterns to skip. Directories are matched by
            basename, files by their path relative to ``folder_path``
            (e.g. ['*.pyc', '__pycache__'])
        encryption_key: Encryption key for encrypted repositories (string for symmetric, PEM for asymmetric)
        encryption_exclude: fnmatch patterns (matched against the path relative to
            ``folder_path``) of files to leave unencrypted (e.g. ['*.txt', 'README.md'])
        encryption_algorithm: Encryption algorithm to use (default: 'aes-256-cbc')
            - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
            - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA public key in PEM format)
        temp_dir: Custom directory under which a unique working directory is created
            for the clone (default: system temp directory)
        skip_lfs: Fetch only LFS pointer files, not LFS objects, when cloning (default: True)

    Returns:
        Hash of the new commit, or of the current HEAD when there was nothing to commit

    Raises:
        ValueError: If ``repo_id`` is malformed or ``folder_path`` is not a directory
        FileNotFoundError: If ``folder_path`` does not exist
        RepositoryNotFoundError: If the repository does not exist
        EncryptionError: If the repository requires encryption but ``encryption_key``
            is not provided, or ``encryption_algorithm`` is not supported
        AuthenticationError: If Git reports an authentication failure on clone or push
        UploadError: If cloning, committing or pushing fails, including a push
            that the remote rejects (e.g. non-fast-forward)

    Example:
        >>> commit_hash = upload_folder(
        ...     folder_path="./my_model",
        ...     repo_id="demo/my-model",
        ...     repo_type="models",
        ...     commit_message="Upload model files",
        ...     token="your-token",
        ... )

        >>> # Upload to encrypted repository
        >>> commit_hash = upload_folder(
        ...     folder_path="./my_model",
        ...     repo_id="demo/encrypted-model",
        ...     repo_type="models",
        ...     encryption_key="my-secret-key",
        ...     encryption_exclude=["README.md", "*.txt"],
        ...     token="your-token",
        ... )
    """
    import tempfile
    import fnmatch
    from .client import HubClient

    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts
    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )

    try:
        repo_info = client.get_repository_info(
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
        )
    except RepositoryNotFoundError as e:
        raise RepositoryNotFoundError(
            f"Repository '{repo_id}' does not exist. "
            "Please create the repository before uploading."
        ) from e

    # Check if repository requires encryption
    is_encrypted = repo_info.annotations.get('encryption') == 'true'
    algorithm = None
    if is_encrypted:
        if not encryption_key:
            raise EncryptionError(
                f"Repository '{repo_id}' requires encryption, but no encryption_key was provided. "
                "Please provide an encryption_key parameter."
            )
        # Resolve the algorithm before cloning so an invalid name fails fast.
        from .encryption import EncryptionAlgorithm, encrypt_file as encrypt_file_func

        if encryption_algorithm is None:
            algorithm = EncryptionAlgorithm.AES_256_CBC
        else:
            try:
                algorithm = EncryptionAlgorithm(encryption_algorithm)
            except ValueError:
                raise EncryptionError(
                    f"Invalid encryption algorithm: {encryption_algorithm}. "
                    f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
                ) from None
        print("Repository is encrypted. Files will be encrypted before upload.")

    # Validate folder path
    folder_path = Path(folder_path)
    if not folder_path.exists():
        raise FileNotFoundError(f"Folder not found: {folder_path}")
    if not folder_path.is_dir():
        raise ValueError(f"Path is not a directory: {folder_path}")

    # Never copy these over the clone:
    # - a source ".git" (directory or gitfile, at any depth) would clobber the
    #   clone's own Git metadata or create embedded repositories;
    # - the top-level ".moha_encryption" index is maintained below.
    # A new list is built so the caller's list is not mutated.
    ignore_patterns = list(ignore_patterns or []) + [".git", "*/.git", ".moha_encryption"]

    if encryption_exclude is None:
        encryption_exclude = []

    # Build Git URL
    git_url = _build_git_url(
        base_url=base_url,
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        username=username,
        password=password,
        token=token,
    )

    # Unique working directory, under temp_dir when given, otherwise under
    # the system temp dir. Removed in the finally block below.
    temp_root = None
    if temp_dir is not None:
        temp_root = Path(temp_dir)
        temp_root.mkdir(parents=True, exist_ok=True)
    temp_context = tempfile.TemporaryDirectory(
        prefix="upload_",
        dir=str(temp_root) if temp_root is not None else None,
    )

    try:
        repo_path = Path(temp_context.name) / "repo"

        try:
            # Shallow clone of the target branch only, to minimize transfer.
            env_mapping = {'GIT_LFS_SKIP_SMUDGE': "1"} if skip_lfs else None
            try:
                repo = Repo.clone_from(
                    git_url,
                    repo_path,
                    branch=revision,
                    depth=1,
                    env=env_mapping,
                )
            except GitCommandError as e:
                if "authentication" in str(e).lower():
                    raise AuthenticationError(f"Authentication failed: {e}") from e
                raise UploadError(f"Failed to clone repository: {e}") from e

            # Total number of files to upload after filtering (progress bar only).
            total_files = _count_files_in_directory(str(folder_path), ignore_patterns)
            print(f"Copying files from {folder_path} to repository...")
            progress_bar = None
            if tqdm is not None and total_files > 0:
                progress_bar = tqdm(
                    total=total_files,
                    unit='file',
                    desc="Copying files",
                    leave=True,
                )

            files_copied = 0
            files_to_encrypt = []  # destination paths that must be encrypted
            uploaded_paths = set()  # repo-relative POSIX paths of every copied file

            try:
                for root, dirs, files in os.walk(folder_path):
                    # Prune ignored directories in place (matched by basename).
                    dirs[:] = [d for d in dirs if not any(
                        fnmatch.fnmatch(d, pattern) for pattern in ignore_patterns
                    )]

                    rel_root = os.path.relpath(root, folder_path)
                    dest_root = repo_path if rel_root == '.' else repo_path / rel_root
                    dest_root.mkdir(parents=True, exist_ok=True)

                    for file in files:
                        # Path relative to folder_path, used for all pattern matching.
                        rel_file_path = file if rel_root == '.' else os.path.join(rel_root, file)

                        if any(fnmatch.fnmatch(rel_file_path, pattern) for pattern in ignore_patterns):
                            continue

                        dest_file = dest_root / file
                        shutil.copy2(Path(root) / file, dest_file)
                        files_copied += 1
                        uploaded_paths.add(Path(rel_file_path).as_posix())

                        if is_encrypted and _should_encrypt_file(rel_file_path, encryption_exclude):
                            files_to_encrypt.append(dest_file)

                        if progress_bar is not None:
                            progress_bar.update(1)
            finally:
                if progress_bar is not None:
                    progress_bar.close()

            print(f"Copied {files_copied} files")

            metadata_file = repo_path / ".moha_encryption"

            if is_encrypted:
                # Start from the existing index so previously encrypted files
                # that this upload does not touch remain decryptable.
                existing_entries = []
                if metadata_file.exists():
                    try:
                        with open(metadata_file, 'r', encoding='utf-8') as f:
                            existing_metadata = json.load(f)
                    except (OSError, ValueError) as err:
                        print(f"Warning: ignoring unreadable .moha_encryption file ({err})")
                        existing_metadata = None
                    if isinstance(existing_metadata, dict) and isinstance(existing_metadata.get("files"), list):
                        existing_entries = existing_metadata["files"]

                # Drop entries for re-uploaded files. Those that are encrypted
                # again get fresh entries below; excluded ones are now plaintext.
                merged_entries = [
                    entry for entry in existing_entries
                    if not (isinstance(entry, dict) and entry.get("path") in uploaded_paths)
                ]

                if files_to_encrypt:
                    print(f"Encrypting {len(files_to_encrypt)} files using {algorithm.value}...")
                    encrypt_progress = None
                    if tqdm is not None:
                        encrypt_progress = tqdm(
                            total=len(files_to_encrypt),
                            unit='file',
                            desc="Encrypting files",
                            leave=True,
                        )
                    try:
                        for file_path in files_to_encrypt:
                            encrypt_file_func(file_path, encryption_key, algorithm)
                            merged_entries.append({
                                # POSIX separators keep the index platform-independent.
                                "path": file_path.relative_to(repo_path).as_posix(),
                                "algorithm": algorithm.value,
                                "encryptedSize": file_path.stat().st_size,
                                "encryptedHash": _calculate_file_hash(file_path),
                            })
                            if encrypt_progress is not None:
                                encrypt_progress.update(1)
                    finally:
                        if encrypt_progress is not None:
                            encrypt_progress.close()
                    print(f"Encrypted {len(files_to_encrypt)} files")

                if not merged_entries:
                    if metadata_file.exists():
                        metadata_file.unlink()
                        print("Removed .moha_encryption file (no encrypted files remaining)")
                elif merged_entries != existing_entries:
                    # Rewrite only on change. Otherwise the fresh createAt
                    # timestamp would produce a spurious commit.
                    encryption_metadata = _create_encryption_metadata(
                        encrypted_files=merged_entries,
                        algorithm=algorithm.value,
                    )
                    _write_encryption_metadata(repo_path, encryption_metadata)
            elif metadata_file.exists():
                # Non-encrypted repository: drop any stale encryption index.
                metadata_file.unlink()
                print("Removed .moha_encryption file (non-encrypted upload)")

            repo.git.add(A=True)

            if not repo.is_dirty() and not repo.untracked_files:
                print("No changes to commit")
                return repo.head.commit.hexsha

            if commit_message is None:
                commit_message = f"Upload folder from {folder_path.name}"

            full_message = commit_message
            if commit_description:
                full_message = f"{commit_message}\n\n{commit_description}"

            print(f"Committing changes: {commit_message}")
            commit = repo.index.commit(full_message)

            try:
                origin = repo.remote(name='origin')
                push_results = origin.push(refspec=f'{revision}:{revision}')
            except GitCommandError as e:
                if "authentication" in str(e).lower():
                    raise AuthenticationError(f"Authentication failed during push: {e}") from e
                raise UploadError(f"Failed to push changes: {e}") from e

            # GitPython does not raise when the remote rejects a ref (e.g. a
            # non-fast-forward update); it sets the ERROR bit on the PushInfo.
            rejected = [info for info in push_results if info.flags & info.ERROR]
            if rejected:
                details = "; ".join(info.summary.strip() for info in rejected)
                raise UploadError(f"Failed to push changes to '{revision}': {details}")

            print(f"Successfully uploaded to {repo_id}")
            print(f"Commit hash: {commit.hexsha}")

            return commit.hexsha

        except (GitCommandError, InvalidGitRepositoryError) as e:
            raise UploadError(f"Git operation failed: {e}") from e
    finally:
        # Best-effort cleanup: never let a cleanup failure mask the upload's
        # own result or exception.
        try:
            temp_context.cleanup()
        except OSError:
            pass
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def upload_file(
    path_or_fileobj: Union[str, Path, bytes],
    path_in_repo: str,
    repo_id: str,
    repo_type: str = "models",
    revision: str = "main",
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    encryption_key: Optional[Union[str, bytes]] = None,
    encryption_algorithm: Optional[str] = None,
    temp_dir: Optional[Union[str, Path]] = None,
    skip_lfs: Optional[bool] = True,  # on clone, fetch LFS pointer files only, not the large objects
) -> str:
    """
    Upload a single file to a repository.

    The target branch is shallow-cloned (``depth=1``) into a temporary
    directory. The file is written at ``path_in_repo`` (encrypted if the
    repository requires it), and the change is committed and pushed.

    ``path_in_repo`` is treated as a repository-relative POSIX path.
    Backslashes count as separators, and empty or ``.`` segments (including a
    leading ``/``) are dropped, so ``"/cfg/a.yaml"`` and ``"cfg\\a.yaml"``
    both mean ``cfg/a.yaml``. Paths containing ``..`` or a ``.git`` segment
    are rejected, as is the reserved ``.moha_encryption`` index itself. This
    ensures the upload can never write outside the clone or into Git's
    metadata.

    Args:
        path_or_fileobj: Path to a local file, or the file content as a
            bytes-like object (bytes, bytearray or memoryview)
        path_in_repo: Path where the file should be stored in the repository
        repo_id: Repository ID in the format "organization/repo_name"
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch to upload to (default: "main")
        commit_message: Commit message (default: "Upload <path_in_repo>")
        commit_description: Additional commit description, appended after a blank line
        base_url: Base URL of the Hub API
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        encryption_key: Encryption key for encrypted repositories (string for symmetric, PEM for asymmetric)
        encryption_algorithm: Encryption algorithm to use (default: 'aes-256-cbc')
            - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
            - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA public key in PEM format)
        temp_dir: Custom directory under which a unique working directory is created
            for the clone (default: system temp directory)
        skip_lfs: Fetch only LFS pointer files, not LFS objects, when cloning (default: True)

    Returns:
        Hash of the new commit, or of the current HEAD when there was nothing to commit

    Raises:
        ValueError: If ``repo_id`` or ``path_in_repo`` is invalid
        FileNotFoundError: If ``path_or_fileobj`` is a path that is not an existing file
        RepositoryNotFoundError: If the repository does not exist
        EncryptionError: If the repository requires encryption but ``encryption_key``
            is not provided, or ``encryption_algorithm`` is not supported
        AuthenticationError: If Git reports an authentication failure on clone or push
        UploadError: If cloning, committing or pushing fails, including a push
            that the remote rejects (e.g. non-fast-forward)

    Example:
        >>> commit_hash = upload_file(
        ...     path_or_fileobj="./config.yaml",
        ...     path_in_repo="config.yaml",
        ...     repo_id="demo/my-model",
        ...     commit_message="Upload config file",
        ...     token="your-token",
        ... )

        >>> # Upload to encrypted repository
        >>> commit_hash = upload_file(
        ...     path_or_fileobj="./model.bin",
        ...     path_in_repo="model.bin",
        ...     repo_id="demo/encrypted-model",
        ...     encryption_key="my-secret-key",
        ...     token="your-token",
        ... )
    """
    import tempfile
    from .client import HubClient

    def _read_encryption_index(metadata_path: Path) -> Optional[Dict[str, Any]]:
        """Return the parsed ``.moha_encryption`` index.

        Returns None if the file is missing, unreadable, or not a JSON object
        with a ``files`` list.
        """
        if not metadata_path.exists():
            return None
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            print(f"Warning: ignoring unreadable .moha_encryption file ({err})")
            return None
        if isinstance(data, dict) and isinstance(data.get("files"), list):
            return data
        print("Warning: ignoring malformed .moha_encryption file")
        return None

    # Parse repo_id
    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts

    # Normalize path_in_repo to a safe repo-relative POSIX path. Without this,
    # an absolute path would make ``repo_path / path_in_repo`` point outside the
    # clone, and ".." could write anywhere on disk.
    path_segments = [
        segment for segment in path_in_repo.replace("\\", "/").split("/")
        if segment not in ("", ".")
    ]
    if not path_segments:
        raise ValueError(f"Invalid path_in_repo: {path_in_repo!r}")
    if ".." in path_segments or any(segment.lower() == ".git" for segment in path_segments):
        raise ValueError(
            f"Invalid path_in_repo: {path_in_repo!r} "
            "(must stay inside the repository and outside .git)"
        )
    repo_rel_path = "/".join(path_segments)
    if repo_rel_path == ".moha_encryption":
        raise ValueError(f"Invalid path_in_repo: {path_in_repo!r} (.moha_encryption is reserved)")

    # Validate a local source file before any network work.
    if isinstance(path_or_fileobj, (bytes, bytearray, memoryview)):
        source_path = None
    else:
        source_path = Path(path_or_fileobj)
        if not source_path.is_file():
            raise FileNotFoundError(f"File not found: {source_path}")

    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )

    try:
        repo_info = client.get_repository_info(
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
        )
    except RepositoryNotFoundError as e:
        raise RepositoryNotFoundError(
            f"Repository '{repo_id}' does not exist. "
            "Please create the repository before uploading."
        ) from e

    # Check whether the repository is an encrypted repository.
    is_encrypted = repo_info.annotations.get('encryption') == 'true'
    algorithm = None
    if is_encrypted:
        if not encryption_key:
            raise EncryptionError(
                f"Repository '{repo_id}' requires encryption, but no encryption_key was provided. "
                "Please provide an encryption_key parameter."
            )
        # Resolve the algorithm before cloning so an invalid name fails fast.
        from .encryption import EncryptionAlgorithm, encrypt_file as encrypt_file_func

        if encryption_algorithm is None:
            algorithm = EncryptionAlgorithm.AES_256_CBC
        else:
            try:
                algorithm = EncryptionAlgorithm(encryption_algorithm)
            except ValueError:
                raise EncryptionError(
                    f"Invalid encryption algorithm: {encryption_algorithm}. "
                    f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
                ) from None
        print("Repository is encrypted. File will be encrypted before upload.")

    # Build the Git URL
    git_url = _build_git_url(
        base_url=base_url,
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        username=username,
        password=password,
        token=token,
    )

    # Unique working directory for the clone, under temp_dir when given,
    # otherwise under the system temp dir. Removed in the finally block below.
    temp_root = None
    if temp_dir is not None:
        temp_root = Path(temp_dir)
        temp_root.mkdir(parents=True, exist_ok=True)
    temp_context = tempfile.TemporaryDirectory(
        prefix="upload_",
        dir=str(temp_root) if temp_root is not None else None,
    )

    try:
        repo_path = Path(temp_context.name) / "repo"

        try:
            # Shallow clone (depth=1): only the latest commit, to reduce transfer.
            env_mapping = {'GIT_LFS_SKIP_SMUDGE': "1"} if skip_lfs else None
            try:
                repo = Repo.clone_from(
                    git_url,
                    repo_path,
                    branch=revision,
                    depth=1,
                    env=env_mapping,
                )
            except GitCommandError as e:
                if "authentication" in str(e).lower():
                    raise AuthenticationError(f"Authentication failed: {e}") from e
                raise UploadError(f"Failed to clone repository: {e}") from e

            # Prepare file path in repository
            file_path = repo_path.joinpath(*path_segments)
            file_path.parent.mkdir(parents=True, exist_ok=True)

            # Write file content
            if source_path is None:
                file_path.write_bytes(bytes(path_or_fileobj))
            else:
                shutil.copy2(source_path, file_path)

            print(f"Added file: {repo_rel_path}")

            metadata_file = repo_path / ".moha_encryption"
            metadata_updated = False

            if is_encrypted:
                print(f"Encrypting file: {repo_rel_path} using {algorithm.value}")
                encrypt_file_func(file_path, encryption_key, algorithm)
                print("File encrypted")

                file_info = {
                    "path": repo_rel_path,
                    "algorithm": algorithm.value,
                    "encryptedSize": file_path.stat().st_size,
                    "encryptedHash": _calculate_file_hash(file_path),
                }

                # Read the existing index, or start a new one.
                encryption_metadata = _read_encryption_index(metadata_file)
                if encryption_metadata is None:
                    encryption_metadata = _create_encryption_metadata(
                        encrypted_files=[],
                        algorithm=algorithm.value,
                    )

                # Replace an existing entry for this path, otherwise append.
                files_list = encryption_metadata["files"]
                for index, entry in enumerate(files_list):
                    if isinstance(entry, dict) and entry.get("path") == repo_rel_path:
                        files_list[index] = file_info
                        break
                else:
                    files_list.append(file_info)

                _write_encryption_metadata(repo_path, encryption_metadata)
                metadata_updated = True
            else:
                # Non-encrypted upload: the file is now stored in plaintext, so
                # drop its entry from the index if one exists.
                encryption_metadata = _read_encryption_index(metadata_file)
                if encryption_metadata is not None:
                    remaining = [
                        entry for entry in encryption_metadata["files"]
                        if not (isinstance(entry, dict) and entry.get("path") == repo_rel_path)
                    ]
                    if len(remaining) < len(encryption_metadata["files"]):
                        if not remaining:
                            metadata_file.unlink()
                            print("Removed .moha_encryption file (no encrypted files remaining)")
                        else:
                            encryption_metadata["files"] = remaining
                            _write_encryption_metadata(repo_path, encryption_metadata)
                            print(f"Updated .moha_encryption file (removed {repo_rel_path})")
                        metadata_updated = True

            # Add file to git
            repo.git.add(repo_rel_path)

            # Stage the index change as well, if it was touched.
            if metadata_updated:
                if metadata_file.exists():
                    repo.git.add(".moha_encryption")
                else:
                    try:
                        repo.git.rm(".moha_encryption")
                    except GitCommandError:
                        # The index file may never have been tracked.
                        pass

            # Check if there are changes
            if not repo.is_dirty() and not repo.untracked_files:
                print("No changes to commit (file already exists with same content)")
                return repo.head.commit.hexsha

            if commit_message is None:
                commit_message = f"Upload {repo_rel_path}"

            full_message = commit_message
            if commit_description:
                full_message = f"{commit_message}\n\n{commit_description}"

            print(f"Committing changes: {commit_message}")
            commit = repo.index.commit(full_message)

            print(f"Pushing to {revision}...")
            try:
                origin = repo.remote(name='origin')
                push_results = origin.push(refspec=f'{revision}:{revision}')
            except GitCommandError as e:
                if "authentication" in str(e).lower():
                    raise AuthenticationError(f"Authentication failed during push: {e}") from e
                raise UploadError(f"Failed to push changes: {e}") from e

            # GitPython does not raise when the remote rejects a ref (e.g. a
            # non-fast-forward update); it sets the ERROR bit on the PushInfo.
            rejected = [info for info in push_results if info.flags & info.ERROR]
            if rejected:
                details = "; ".join(info.summary.strip() for info in rejected)
                raise UploadError(f"Failed to push changes to '{revision}': {details}")

            print(f"Successfully uploaded {repo_rel_path} to {repo_id}")
            print(f"Commit hash: {commit.hexsha}")

            return commit.hexsha

        except (GitCommandError, InvalidGitRepositoryError) as e:
            raise UploadError(f"Git operation failed: {e}") from e
    finally:
        # Best-effort cleanup: never let a cleanup failure mask the upload's
        # own result or exception.
        try:
            temp_context.cleanup()
        except OSError:
            pass
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
|