xiaoshiai-hub 0.1.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xiaoshiai_hub/upload.py CHANGED
@@ -1,32 +1,30 @@
1
1
  """
2
- Upload utilities for XiaoShi AI Hub SDK using GitPython
2
+ Upload utilities for XiaoShi AI Hub SDK using HTTP API
3
3
  """
4
4
 
5
- import os
6
- import shutil
7
- import json
5
+ import base64
8
6
  import hashlib
7
+ import os
8
+ import tempfile
9
9
  from pathlib import Path
10
- from typing import List, Optional, Union, Callable, Dict, Any
11
- from urllib.parse import urlparse, urlunparse
12
- from datetime import datetime, timezone
10
+ from typing import List, Optional, Union, Dict
13
11
 
14
- from xiaoshiai_hub.client import DEFAULT_BASE_URL
12
+ import requests
15
13
 
16
- try:
17
- from git import Repo, GitCommandError, InvalidGitRepositoryError
18
- except ImportError:
19
- raise ImportError(
20
- "GitPython is required for upload functionality. "
21
- "Install it with: pip install gitpython"
22
- )
14
+ from xiaoshiai_hub.client import DEFAULT_BASE_URL, HubClient
15
+ from .exceptions import HubException, AuthenticationError, RepositoryNotFoundError
16
+
17
+ from xpai_enc import enc_file
18
+ from key_manager import KeyManager
23
19
 
24
20
  try:
25
21
  from tqdm.auto import tqdm
26
22
  except ImportError:
27
23
  tqdm = None
28
24
 
29
- from .exceptions import HubException, AuthenticationError, RepositoryNotFoundError, EncryptionError
25
+
26
+ # File extensions that require encryption
27
+ ENCRYPTABLE_EXTENSIONS = {".safetensors", ".bin", ".pt", ".pth", ".ckpt"}
30
28
 
31
29
 
32
30
  class UploadError(HubException):
@@ -34,176 +32,253 @@ class UploadError(HubException):
34
32
  pass
35
33
 
36
34
 
37
- def _build_git_url(
35
+ def _build_api_url(
38
36
  base_url: Optional[str],
39
37
  organization: str,
40
38
  repo_type: str,
41
39
  repo_name: str,
42
- username: Optional[str] = None,
43
- password: Optional[str] = None,
44
- token: Optional[str] = None,
45
40
  ) -> str:
46
- """
47
- Build Git repository URL with authentication.
48
-
49
- Args:
50
- base_url: Base URL of the Hub API
51
- organization: Organization name
52
- repo_type: Repository type ("models" or "datasets")
53
- repo_name: Repository name
54
- username: Username for authentication
55
- password: Password for authentication
56
- token: Token for authentication
57
-
58
- Returns:
59
- Git repository URL with embedded credentials
60
- """
61
- # Parse base URL to get the host
41
+ """Build API upload URL."""
62
42
  base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
63
- parsed = urlparse(base_url)
64
- host = parsed.netloc
65
- scheme = parsed.scheme or 'https'
66
-
67
- # Build repository path
68
- repo_path = f"moha/{organization}/{repo_type}/{repo_name}.git"
43
+ return f"{base_url}/{organization}/{repo_type}/{repo_name}/api/upload"
44
+
45
+
46
+ def _create_session(
47
+ token: Optional[str] = None,
48
+ username: Optional[str] = None,
49
+ password: Optional[str] = None
50
+ ) -> requests.Session:
51
+ """Create HTTP session with authentication."""
52
+ session = requests.Session()
69
53
 
70
- # Add authentication to URL
71
54
  if token:
72
- # Use token as username with empty password
73
- netloc = f"oauth2:{token}@{host}"
55
+ session.headers.update({'Authorization': f'Bearer {token}'})
74
56
  elif username and password:
75
- netloc = f"{username}:{password}@{host}"
76
- else:
77
- netloc = host
78
-
79
- # Construct full URL
80
- git_url = urlunparse((scheme, netloc, repo_path, '', '', ''))
57
+ from requests.auth import HTTPBasicAuth
58
+ session.auth = HTTPBasicAuth(username, password) # type: ignore
81
59
 
82
- return git_url
60
+ session.headers.update({'Content-Type': 'application/json'})
61
+ return session
83
62
 
84
63
 
85
- def _calculate_file_hash(file_path: Path) -> str:
86
- """
87
- Calculate SHA256 hash of a file.
88
-
89
- Args:
90
- file_path: Path to the file
91
-
92
- Returns:
93
- Hexadecimal hash string
94
- """
64
+ def _calculate_file_sha256(file_path: Path) -> str:
65
+ """Calculate SHA256 hash of a file."""
95
66
  sha256_hash = hashlib.sha256()
96
- with open(file_path, "rb") as f:
97
- for byte_block in iter(lambda: f.read(4096), b""):
98
- sha256_hash.update(byte_block)
67
+ with open(file_path, 'rb') as f:
68
+ for chunk in iter(lambda: f.read(8192), b''):
69
+ sha256_hash.update(chunk)
99
70
  return sha256_hash.hexdigest()
100
71
 
101
72
 
102
- def _should_encrypt_file(file_path: str, encryption_exclude: List[str]) -> bool:
73
+ class _TqdmUploadWrapper:
74
+ """Wrapper to add tqdm progress bar to file upload."""
75
+
76
+ def __init__(self, file_obj, total_size, desc=None):
77
+ self.file_obj = file_obj
78
+ self.total_size = total_size
79
+ self.pbar = None
80
+ if tqdm:
81
+ self.pbar = tqdm(
82
+ total=total_size,
83
+ unit='B',
84
+ unit_scale=True,
85
+ unit_divisor=1024,
86
+ desc=desc,
87
+ )
88
+
89
+ def read(self, size=-1):
90
+ """Read data and update progress bar."""
91
+ data = self.file_obj.read(size)
92
+ if self.pbar is not None and data:
93
+ self.pbar.update(len(data))
94
+ return data
95
+
96
+ def __enter__(self):
97
+ return self
98
+
99
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
100
+ if self.pbar is not None:
101
+ self.pbar.close()
102
+ self.file_obj.close()
103
+
104
+
105
+ def _upload_file_with_progress(
106
+ upload_url: str,
107
+ file_path: Path,
108
+ desc: Optional[str] = None,
109
+ ) -> None:
110
+ """Upload file to URL with progress bar."""
111
+ file_size = file_path.stat().st_size
112
+
113
+ with open(file_path, 'rb') as f:
114
+ with _TqdmUploadWrapper(f, file_size, desc=desc) as wrapped_file:
115
+ upload_response = requests.put(
116
+ upload_url,
117
+ data=wrapped_file,
118
+ headers={'Content-Type': 'application/octet-stream'}
119
+ )
120
+ upload_response.raise_for_status()
121
+
122
+
123
+ def _encrypt_file_if_needed(
124
+ file_path: Path,
125
+ encryption_password: Optional[str] = None,
126
+ ) -> tuple[Optional[Path], Optional[Path]]:
103
127
  """
104
- Check if a file should be encrypted based on exclude patterns.
128
+ Encrypt file if encryption_password is provided, file is large enough, and file extension is encryptable.
105
129
 
106
130
  Args:
107
- file_path: Relative path of the file (relative to folder_path)
108
- encryption_exclude: List of patterns to exclude from encryption
131
+ file_path: Path to the file to encrypt
132
+ encryption_password: Password for encryption
109
133
 
110
134
  Returns:
111
- True if file should be encrypted, False otherwise
112
- """
113
- import fnmatch
114
-
115
- if not encryption_exclude:
116
- return True
117
-
118
- for pattern in encryption_exclude:
119
- # Match against the full relative path
120
- if fnmatch.fnmatch(file_path, pattern):
121
- return False
122
-
123
- return True
124
-
125
-
126
- def _create_encryption_metadata(
127
- encrypted_files: List[Dict[str, Any]],
128
- algorithm: str,
129
- version: str = "1.0"
130
- ) -> Dict[str, Any]:
135
+ Tuple of (encrypted_file_path, temp_dir_path)
136
+ If no encryption, returns (None, None)
131
137
  """
132
- Create encryption metadata structure.
133
-
134
- Args:
135
- encrypted_files: List of encrypted file information
136
- algorithm: Encryption algorithm used
137
- version: Metadata format version
138
+ if not encryption_password:
139
+ return None, None
140
+
141
+ # Check file size first (only encrypt files >= 5MB)
142
+ file_size = file_path.stat().st_size
143
+ if file_size < 5 * 1024 * 1024:
144
+ return None, None
145
+
146
+ # Check if file extension requires encryption
147
+ if file_path.suffix.lower() not in ENCRYPTABLE_EXTENSIONS:
148
+ return None, None
149
+
150
+ # Generate encryption key from password
151
+ encryption_key = KeyManager.generate_key(encryption_password)
152
+
153
+ # Create temporary directory for encrypted file
154
+ temp_dir = Path(tempfile.mkdtemp())
155
+ encrypted_file = temp_dir / file_path.name
156
+ manifest_path = temp_dir / "xpai_encryption_manifest.enc"
157
+
158
+ # Encrypt the file
159
+ enc_file(
160
+ source=file_path,
161
+ dest=encrypted_file,
162
+ encryption_key=encryption_key,
163
+ manifest_path=manifest_path,
164
+ )
138
165
 
139
- Returns:
140
- Encryption metadata dictionary
141
- """
142
- return {
143
- "version": version,
144
- "createAt": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
145
- "files": encrypted_files
166
+ return encrypted_file, temp_dir
167
+
168
+
169
+ def _upload_files_via_api(
170
+ session: requests.Session,
171
+ api_url: str,
172
+ files: List[Dict],
173
+ message: str,
174
+ branch: str,
175
+ author: Optional[Dict] = None,
176
+ ) -> Dict:
177
+ """Upload files via HTTP API."""
178
+ payload = {
179
+ "files": files,
180
+ "message": message,
181
+ "branch": branch,
146
182
  }
183
+ if author:
184
+ payload["author"] = author
185
+
186
+ try:
187
+ response = session.post(api_url, json=payload)
188
+ response.raise_for_status()
189
+ return response.json()
190
+ except requests.exceptions.HTTPError as e:
191
+ if e.response.status_code == 401:
192
+ raise AuthenticationError(f"Authentication failed: {e}")
193
+ elif e.response.status_code == 404:
194
+ raise RepositoryNotFoundError(f"Repository not found: {e}")
195
+ else:
196
+ raise UploadError(f"Upload failed: {e}")
197
+ except requests.exceptions.RequestException as e:
198
+ raise UploadError(f"Request failed: {e}")
199
+
200
+
201
+ def _upload_small_file(
202
+ session: requests.Session,
203
+ api_url: str,
204
+ local_path: Path,
205
+ remote_path: str,
206
+ message: str,
207
+ branch: str,
208
+ ) -> Dict:
209
+ """Upload small file (< 5MB) with base64 encoding."""
210
+ file_size = local_path.stat().st_size
211
+
212
+ # Show progress bar for reading file
213
+ pbar = None
214
+ if tqdm and file_size > 0:
215
+ pbar = tqdm(
216
+ total=file_size,
217
+ unit='B',
218
+ unit_scale=True,
219
+ unit_divisor=1024,
220
+ desc=f"Uploading {local_path.name}",
221
+ )
147
222
 
148
-
149
- def _write_encryption_metadata(repo_path: Path, metadata: Dict[str, Any]) -> None:
150
- """
151
- Write encryption metadata to .moha_encryption file.
152
-
153
- Args:
154
- repo_path: Path to the repository
155
- metadata: Encryption metadata dictionary
156
- """
157
- metadata_file = repo_path / ".moha_encryption"
158
- with open(metadata_file, 'w', encoding='utf-8') as f:
159
- json.dump(metadata, f, indent=2, ensure_ascii=False)
160
- print(f"Encryption metadata written to .moha_encryption")
161
-
162
-
163
- def _count_files_in_directory(directory: str, ignore_patterns: Optional[List[str]] = None) -> int:
164
- """
165
- Count total number of files in a directory.
166
-
167
- Args:
168
- directory: Directory path
169
- ignore_patterns: List of patterns to ignore (e.g., ['.git', '__pycache__'])
170
-
171
- Returns:
172
- Total number of files
173
- """
174
- import fnmatch
175
-
176
- count = 0
177
-
178
- for root, dirs, files in os.walk(directory):
179
- # Get relative path from directory
180
- rel_root = os.path.relpath(root, directory)
181
-
182
- if ignore_patterns:
183
- dirs[:] = [d for d in dirs if not any(
184
- fnmatch.fnmatch(d, pattern) for pattern in ignore_patterns
185
- )]
186
-
187
- if ignore_patterns:
188
- filtered_files = []
189
- for f in files:
190
- # Construct relative file path
191
- if rel_root == '.':
192
- rel_file_path = f
193
- else:
194
- rel_file_path = os.path.join(rel_root, f)
195
-
196
- # Check if file matches any ignore pattern
197
- should_ignore = any(
198
- fnmatch.fnmatch(rel_file_path, pattern) for pattern in ignore_patterns
199
- )
200
- if not should_ignore:
201
- filtered_files.append(f)
202
- files = filtered_files
203
-
204
- count += len(files)
205
-
206
- return count
223
+ try:
224
+ with open(local_path, 'rb') as f:
225
+ content = f.read()
226
+ if pbar is not None:
227
+ pbar.update(file_size)
228
+
229
+ content_b64 = base64.b64encode(content).decode('utf-8')
230
+
231
+ files = [{
232
+ "path": remote_path,
233
+ "content": content_b64,
234
+ "size": len(content),
235
+ }]
236
+
237
+ result = _upload_files_via_api(session, api_url, files, message, branch)
238
+
239
+ if pbar is not None:
240
+ pbar.close()
241
+
242
+ return result
243
+ except Exception:
244
+ if pbar is not None:
245
+ pbar.close()
246
+ raise
247
+
248
+
249
+ def _upload_large_file(
250
+ session: requests.Session,
251
+ api_url: str,
252
+ local_path: Path,
253
+ remote_path: str,
254
+ message: str,
255
+ branch: str,
256
+ ) -> Dict:
257
+ """Upload large file (≥ 5MB) using LFS."""
258
+ file_size = os.path.getsize(local_path)
259
+ sha256 = _calculate_file_sha256(local_path)
260
+
261
+ files = [{
262
+ "path": remote_path,
263
+ "size": file_size,
264
+ "sha256": sha256,
265
+ }]
266
+
267
+ # Step 1: Request upload URL
268
+ result = _upload_files_via_api(session, api_url, files, message, branch)
269
+
270
+ # Step 2: Upload to S3 if needed
271
+ if result.get('needUpload'):
272
+ upload_url = result['uploadUrls'].get(remote_path)
273
+ if upload_url:
274
+ _upload_file_with_progress(
275
+ upload_url,
276
+ local_path,
277
+ desc=f"Uploading {local_path.name}"
278
+ )
279
+ print(f"Upload completed: {remote_path}")
280
+
281
+ return result
207
282
 
208
283
 
209
284
  def upload_folder(
@@ -217,331 +292,222 @@ def upload_folder(
217
292
  username: Optional[str] = None,
218
293
  password: Optional[str] = None,
219
294
  token: Optional[str] = None,
220
- ignore_patterns: Optional[List[str]] = None, # 上传的时候忽略的文件
221
- encryption_key: Optional[Union[str, bytes]] = None,
222
- encryption_exclude: Optional[List[str]] = None, # 加密的时候排除的文件
223
- encryption_algorithm: Optional[str] = None,
295
+ encryption_password: Optional[str] = None,
296
+ ignore_patterns: Optional[List[str]] = None,
224
297
  temp_dir: Optional[Union[str, Path]] = None,
225
- skip_lfs: Optional[bool] = True,
226
- ) -> str:
298
+ ) -> Dict:
227
299
  """
228
- Upload a folder to a repository using Git.
229
-
230
- This function clones the repository, copies files from the folder,
231
- commits the changes, and pushes to the remote repository.
300
+ Upload a folder to a repository using HTTP API.
232
301
 
233
302
  Args:
234
303
  folder_path: Path to the folder to upload
235
304
  repo_id: Repository ID in the format "organization/repo_name"
236
305
  repo_type: Type of repository ("models" or "datasets")
237
306
  revision: Branch to upload to (default: "main")
238
- commit_message: Commit message (default: "Upload folder")
307
+ commit_message: Commit message
239
308
  commit_description: Additional commit description
240
309
  base_url: Base URL of the Hub API
241
310
  username: Username for authentication
242
311
  password: Password for authentication
243
- token: Token for authentication (preferred over username/password)
244
- ignore_patterns: List of patterns to ignore (e.g., ['.git', '*.pyc', '__pycache__'])
245
- encryption_key: Encryption key for encrypted repositories (string for symmetric, PEM for asymmetric)
246
- encryption_exclude: List of file patterns to exclude from encryption (e.g., ['*.txt', 'README.md'])
247
- encryption_algorithm: Encryption algorithm to use (default: 'aes-256-cbc')
248
- - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
249
- - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA public key in PEM format)
250
- temp_dir: Custom temporary directory path for cloning repository (default: system temp directory)
251
- skip_lfs: Skip LFS files when cloning the repository (default: True)
312
+ token: Token for authentication (preferred)
313
+ encryption_password: Password for file encryption (optional)
314
+ ignore_patterns: List of patterns to ignore
315
+ temp_dir: Temporary directory for encrypted files (optional, auto-created if not specified)
252
316
 
253
317
  Returns:
254
- Commit hash of the uploaded changes
255
-
256
- Raises:
257
- RepositoryNotFoundError: If the repository does not exist
258
- EncryptionError: If repository requires encryption but encryption_key is not provided
259
-
260
- Example:
261
- >>> commit_hash = upload_folder(
262
- ... folder_path="./my_model",
263
- ... repo_id="demo/my-model",
264
- ... repo_type="models",
265
- ... commit_message="Upload model files",
266
- ... token="your-token",
267
- ... )
268
-
269
- >>> # Upload to encrypted repository
270
- >>> commit_hash = upload_folder(
271
- ... folder_path="./my_model",
272
- ... repo_id="demo/encrypted-model",
273
- ... repo_type="models",
274
- ... encryption_key="my-secret-key",
275
- ... encryption_exclude=["README.md", "*.txt"],
276
- ... token="your-token",
277
- ... )
318
+ Upload response
278
319
  """
279
- import tempfile
280
320
  import fnmatch
281
- from .client import HubClient
321
+ import shutil
282
322
 
323
+ # Parse repo_id
283
324
  parts = repo_id.split('/')
284
325
  if len(parts) != 2:
285
326
  raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
286
-
287
327
  organization, repo_name = parts
288
- client = HubClient(
289
- base_url=base_url,
290
- username=username,
291
- password=password,
292
- token=token,
293
- )
328
+
329
+ # Check if repository exists
330
+ client = HubClient(base_url=base_url, username=username, password=password, token=token)
294
331
  try:
295
- client.get_repository_info(
296
- organization=organization,
297
- repo_type=repo_type,
298
- repo_name=repo_name,
299
- )
332
+ client.get_repository_info(organization, repo_type, repo_name)
300
333
  except RepositoryNotFoundError:
301
334
  raise RepositoryNotFoundError(
302
- f"Repository '{repo_id}' does not exist. "
303
- "Please create the repository before uploading."
335
+ f"Repository not found: {organization}/{repo_type}/{repo_name}. "
336
+ f"Please create the repository first."
304
337
  )
305
- is_encrypted = bool(encryption_key)
306
- # Validate folder path
338
+
339
+ # Validate folder
307
340
  folder_path = Path(folder_path)
308
341
  if not folder_path.exists():
309
342
  raise FileNotFoundError(f"Folder not found: {folder_path}")
310
343
  if not folder_path.is_dir():
311
344
  raise ValueError(f"Path is not a directory: {folder_path}")
312
345
 
346
+ # Setup ignore patterns
313
347
  if ignore_patterns is None:
314
348
  ignore_patterns = []
315
-
316
- if encryption_exclude is None:
317
- encryption_exclude = []
318
-
319
- # Build Git URL
320
- git_url = _build_git_url(
321
- base_url=base_url,
322
- organization=organization,
323
- repo_type=repo_type,
324
- repo_name=repo_name,
325
- username=username,
326
- password=password,
327
- token=token,
328
- )
329
-
330
- if temp_dir is not None:
331
- import uuid
332
- temp_dir_path = Path(temp_dir)
333
- temp_dir_path.mkdir(parents=True, exist_ok=True)
334
- unique_dir = temp_dir_path / f"upload_{uuid.uuid4().hex[:8]}"
335
- unique_dir.mkdir(parents=True, exist_ok=True)
336
- temp_context = None
337
- working_temp_dir = unique_dir
338
- else:
339
- temp_context = tempfile.TemporaryDirectory()
340
- working_temp_dir = Path(temp_context.__enter__())
341
- unique_dir = None
349
+ if '.git' not in ignore_patterns and '.git/' not in ignore_patterns:
350
+ ignore_patterns.append('.git')
351
+ ignore_patterns.append('.gitattributes')
352
+
353
+ # Prepare encryption if needed
354
+ encryption_key = None
355
+ if encryption_password:
356
+ encryption_key = KeyManager.generate_key(encryption_password)
357
+
358
+ # Create or use temporary directory for encrypted files
359
+ temp_dir_path: Optional[Path] = None
360
+ if encryption_key:
361
+ if temp_dir:
362
+ # Use user-specified temp directory
363
+ temp_dir_path = Path(temp_dir)
364
+ temp_dir_path.mkdir(parents=True, exist_ok=True)
365
+ else:
366
+ # Auto-create temp directory
367
+ temp_dir_path = Path(tempfile.mkdtemp())
342
368
 
343
369
  try:
344
- repo_path = working_temp_dir / "repo"
345
-
346
- try:
347
- try:
348
- env_mapping = None
349
- if skip_lfs:
350
- env_mapping={'GIT_LFS_SKIP_SMUDGE': "1"}
351
- repo = Repo.clone_from(
352
- git_url,
353
- repo_path,
354
- branch=revision,
355
- depth=1,
356
- env=env_mapping,
357
- )
358
- except GitCommandError as e:
359
- if "Authentication failed" in str(e) or "authentication" in str(e).lower():
360
- raise AuthenticationError(f"Authentication failed: {e}")
361
- raise UploadError(f"Failed to clone repository: {e}")
362
-
363
- # 过滤掉忽略的文件,需要上传的文件总数
364
- total_files = _count_files_in_directory(str(folder_path), ignore_patterns)
365
- print(f"Copying files from {folder_path} to repository...")
366
- progress_bar = None
367
- if tqdm is not None and total_files > 0:
368
- progress_bar = tqdm(
369
- total=total_files,
370
- unit='file',
371
- desc="Copying files",
372
- leave=True,
373
- )
374
-
375
- files_copied = 0
376
- files_to_encrypt = [] # 标记哪些文件需要加密
377
-
378
- for root, dirs, files in os.walk(folder_path):
379
- dirs[:] = [d for d in dirs if not any(
380
- fnmatch.fnmatch(d, pattern) for pattern in ignore_patterns
381
- )]
382
-
383
- # 相对路径
384
- rel_root = os.path.relpath(root, folder_path)
385
- if rel_root == '.':
386
- dest_root = repo_path
387
- else:
388
- dest_root = repo_path / rel_root
389
-
390
- # Create destination directory
391
- dest_root.mkdir(parents=True, exist_ok=True)
392
-
393
- # Copy files
394
- for file in files:
395
- # Construct relative file path (relative to folder_path)
396
- if rel_root == '.':
397
- rel_file_path = file
398
- else:
399
- rel_file_path = os.path.join(rel_root, file)
400
-
401
- # Check ignore patterns using full relative path
402
- if any(fnmatch.fnmatch(rel_file_path, pattern) for pattern in ignore_patterns):
403
- continue
404
-
405
- src_file = Path(root) / file
406
- dest_file = dest_root / file
407
-
408
- # Copy file
409
- shutil.copy2(src_file, dest_file)
410
- files_copied += 1
411
-
412
- # 把需要加密的文件写入列表 (使用完整相对路径)
413
- if is_encrypted and _should_encrypt_file(rel_file_path, encryption_exclude):
414
- files_to_encrypt.append(dest_file)
415
-
416
- if progress_bar is not None:
417
- progress_bar.update(1)
418
-
419
- if progress_bar is not None:
420
- progress_bar.close()
421
-
422
- print(f"Copied {files_copied} files")
423
-
424
- # 处理加密元数据文件
425
- metadata_file = repo_path / ".moha_encryption"
370
+ # Create session and API URL
371
+ session = _create_session(token, username, password)
372
+ api_url = _build_api_url(base_url, organization, repo_type, repo_name)
373
+
374
+ # Prepare commit message
375
+ if commit_message is None:
376
+ commit_message = f"Upload folder from {folder_path.name}"
377
+ if commit_description:
378
+ commit_message = f"{commit_message}\n\n{commit_description}"
379
+
380
+ # Collect all files
381
+ files_to_upload = []
382
+ large_files = [] # Files >= 5MB
383
+ small_files = [] # Files < 5MB
384
+ for root, dirs, files in os.walk(folder_path):
385
+ # Filter directories
386
+ dirs[:] = [d for d in dirs if not any(
387
+ fnmatch.fnmatch(d, pattern) for pattern in ignore_patterns
388
+ )]
426
389
 
427
- # 加密
428
- if is_encrypted and files_to_encrypt:
429
- # Determine encryption algorithm
430
- from .encryption import EncryptionAlgorithm, encrypt_file as encrypt_file_func
390
+ rel_root = os.path.relpath(root, folder_path)
431
391
 
432
- if encryption_algorithm is None:
433
- algorithm = EncryptionAlgorithm.AES_256_CBC
392
+ for file in files:
393
+ # Construct relative file path
394
+ if rel_root == '.':
395
+ rel_file_path = file
434
396
  else:
435
- try:
436
- algorithm = EncryptionAlgorithm(encryption_algorithm)
437
- except ValueError:
438
- raise EncryptionError(
439
- f"Invalid encryption algorithm: {encryption_algorithm}. "
440
- f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
441
- )
442
-
443
- print(f"Encrypting {len(files_to_encrypt)} files using {algorithm.value}...")
444
- encrypt_progress = None
445
- if tqdm is not None:
446
- encrypt_progress = tqdm(
447
- total=len(files_to_encrypt),
448
- unit='file',
449
- desc="Encrypting files",
450
- leave=True,
397
+ rel_file_path = os.path.join(rel_root, file)
398
+
399
+ # Check if file matches ignore pattern
400
+ if any(fnmatch.fnmatch(rel_file_path, pattern) for pattern in ignore_patterns):
401
+ continue
402
+
403
+ local_file = Path(root) / file
404
+ file_size = local_file.stat().st_size
405
+
406
+ # Encrypt file if needed (only for large files with specific extensions)
407
+ actual_file = local_file
408
+ if (encryption_key and temp_dir_path and
409
+ file_size >= 5 * 1024 * 1024 and # Only encrypt files >= 5MB
410
+ local_file.suffix.lower() in ENCRYPTABLE_EXTENSIONS):
411
+ # Preserve directory structure in temp dir
412
+ encrypted_file = temp_dir_path / rel_file_path
413
+ encrypted_file.parent.mkdir(parents=True, exist_ok=True)
414
+ enc_file(
415
+ source=local_file,
416
+ dest=encrypted_file,
417
+ encryption_key=encryption_key,
418
+ manifest_path=temp_dir_path / "xpai_encryption_manifest.enc",
451
419
  )
452
-
453
- # 收集加密文件的元数据
454
- encrypted_files_metadata = []
455
-
456
- for file_path in files_to_encrypt:
457
- # 加密文件
458
- encrypt_file_func(file_path, encryption_key, algorithm)
459
-
460
- # 计算加密后的文件信息
461
- encrypted_size = file_path.stat().st_size
462
- encrypted_hash = _calculate_file_hash(file_path)
463
-
464
- # 获取相对于 repo_path 的路径
465
- rel_path = file_path.relative_to(repo_path)
466
-
467
- # 添加到元数据列表
468
- encrypted_files_metadata.append({
469
- "path": str(rel_path),
470
- "algorithm": algorithm.value,
471
- "encryptedSize": encrypted_size,
472
- "encryptedHash": encrypted_hash,
420
+ actual_file = encrypted_file
421
+
422
+ # Determine upload method based on size
423
+ if file_size >= 5 * 1024 * 1024: # 5MB
424
+ # Large file - use LFS
425
+ sha256 = _calculate_file_sha256(actual_file)
426
+ files_to_upload.append({
427
+ "path": rel_file_path,
428
+ "size": file_size,
429
+ "sha256": sha256,
430
+ })
431
+ large_files.append((actual_file, rel_file_path))
432
+ else:
433
+ # Small file - use base64
434
+ with open(actual_file, 'rb') as f:
435
+ content = f.read()
436
+ content_b64 = base64.b64encode(content).decode('utf-8')
437
+ files_to_upload.append({
438
+ "path": rel_file_path,
439
+ "content": content_b64,
440
+ "size": file_size,
473
441
  })
442
+ small_files.append((rel_file_path, file_size))
443
+
444
+ # Add encryption manifest file if it exists
445
+ if temp_dir_path:
446
+ manifest_file = temp_dir_path / "xpai_encryption_manifest.enc"
447
+ if manifest_file.exists():
448
+ manifest_size = manifest_file.stat().st_size
449
+ # Read manifest file content
450
+ with open(manifest_file, 'rb') as f:
451
+ manifest_content = f.read()
452
+ manifest_b64 = base64.b64encode(manifest_content).decode('utf-8')
453
+
454
+ # Add to files to upload
455
+ files_to_upload.append({
456
+ "path": "xpai_encryption_manifest.enc",
457
+ "content": manifest_b64,
458
+ "size": manifest_size,
459
+ })
460
+ small_files.append(("xpai_encryption_manifest.enc", manifest_size))
461
+
462
+ print(f"Found {len(files_to_upload)} files to upload ({len(large_files)} large files, {len(small_files)} small files)")
463
+ # Show progress for small files upload
464
+ small_files_pbar = None
465
+ if small_files and tqdm:
466
+ total_small_size = sum(size for _, size in small_files)
467
+ small_files_pbar = tqdm(
468
+ total=total_small_size,
469
+ unit='B',
470
+ unit_scale=True,
471
+ unit_divisor=1024,
472
+ desc="Uploading small files",
473
+ )
474
+
475
+ # Upload files via API
476
+ try:
477
+ result = _upload_files_via_api(session, api_url, files_to_upload, commit_message, revision)
478
+
479
+ # Update progress bar for small files
480
+ if small_files_pbar is not None:
481
+ for _, size in small_files:
482
+ small_files_pbar.update(size)
483
+ small_files_pbar.close()
484
+ except Exception:
485
+ if small_files_pbar is not None:
486
+ small_files_pbar.close()
487
+ raise
488
+
489
+ # Upload large files to S3 if needed
490
+ if result.get('needUpload') and large_files:
491
+ upload_urls = result.get('uploadUrls', {})
492
+ for local_file, remote_path in large_files:
493
+ upload_url = upload_urls.get(remote_path)
494
+ if upload_url:
495
+ _upload_file_with_progress(
496
+ upload_url,
497
+ local_file,
498
+ desc=f"Uploading {local_file.name}"
499
+ )
500
+ print(f"Upload completed: {remote_path}")
474
501
 
475
- if encrypt_progress is not None:
476
- encrypt_progress.update(1)
477
-
478
- if encrypt_progress is not None:
479
- encrypt_progress.close()
480
-
481
- print(f"Encrypted {len(files_to_encrypt)} files")
482
-
483
- # 写入加密元数据文件
484
- encryption_metadata = _create_encryption_metadata(
485
- encrypted_files=encrypted_files_metadata,
486
- algorithm=algorithm.value
487
- )
488
- _write_encryption_metadata(repo_path, encryption_metadata)
489
- else:
490
- # 如果不是加密上传,删除可能存在的加密元数据文件
491
- if metadata_file.exists():
492
- metadata_file.unlink()
493
- print("Removed .moha_encryption file (non-encrypted upload)")
494
-
495
- repo.git.add(A=True)
496
-
497
- if not repo.is_dirty() and not repo.untracked_files:
498
- print("No changes to commit")
499
- return repo.head.commit.hexsha
500
-
501
- if commit_message is None:
502
- commit_message = f"Upload folder from {folder_path.name}"
503
-
504
- full_message = commit_message
505
- if commit_description:
506
- full_message = f"{commit_message}\n\n{commit_description}"
507
-
508
- # Commit changes
509
- print(f"Committing changes: {commit_message}")
510
- commit = repo.index.commit(full_message)
511
- try:
512
- origin = repo.remote(name='origin')
513
- origin.push(refspec=f'{revision}:{revision}')
514
- except GitCommandError as e:
515
- if "Authentication failed" in str(e) or "authentication" in str(e).lower():
516
- raise AuthenticationError(f"Authentication failed during push: {e}")
517
- raise UploadError(f"Failed to push changes: {e}")
518
-
519
- print(f"Successfully uploaded to {repo_id}")
520
- print(f"Commit hash: {commit.hexsha}")
521
-
522
- return commit.hexsha
523
-
524
- except (GitCommandError, InvalidGitRepositoryError) as e:
525
- raise UploadError(f"Git operation failed: {e}")
502
+ print(f"Successfully uploaded to {repo_id}")
503
+ return result
526
504
  finally:
527
- # Clean up temporary directory
528
- if temp_context is not None:
529
- # Clean up system temp directory
530
- try:
531
- temp_context.__exit__(None, None, None)
532
- except Exception:
533
- pass # Ignore cleanup errors
534
- elif unique_dir is not None:
535
- # Clean up custom temp directory
536
- try:
537
- import shutil as shutil_cleanup
538
- shutil_cleanup.rmtree(unique_dir, ignore_errors=True)
539
- except Exception:
540
- pass # Ignore cleanup errors
505
+ if temp_dir_path and temp_dir_path.exists():
506
+ shutil.rmtree(temp_dir_path)
541
507
 
542
508
 
543
509
  def upload_file(
544
- path_or_fileobj: Union[str, Path, bytes],
510
+ path_file: Union[str, Path],
545
511
  path_in_repo: str,
546
512
  repo_id: str,
547
513
  repo_type: str = "models",
@@ -552,17 +518,14 @@ def upload_file(
552
518
  username: Optional[str] = None,
553
519
  password: Optional[str] = None,
554
520
  token: Optional[str] = None,
555
- encryption_key: Optional[Union[str, bytes]] = None,
556
- encryption_algorithm: Optional[str] = None,
557
- temp_dir: Optional[Union[str, Path]] = None,
558
- skip_lfs: Optional[bool] = True, # 当克隆仓库的时候,跳过lfs的大文件下载,只需要下载lfs文件指针
559
- ) -> str:
521
+ encryption_password: Optional[str] = None,
522
+ ) -> Dict:
560
523
  """
561
- Upload a single file to a repository.
524
+ Upload a single file to a repository using HTTP API.
562
525
 
563
526
  Args:
564
- path_or_fileobj: Path to the file or file content as bytes
565
- path_in_repo: Path where the file should be stored in the repository
527
+ path_file: Path to the local file
528
+ path_in_repo: Path in the repository
566
529
  repo_id: Repository ID in the format "organization/repo_name"
567
530
  repo_type: Type of repository ("models" or "datasets")
568
531
  revision: Branch to upload to (default: "main")
@@ -571,283 +534,86 @@ def upload_file(
571
534
  base_url: Base URL of the Hub API
572
535
  username: Username for authentication
573
536
  password: Password for authentication
574
- token: Token for authentication
575
- encryption_key: Encryption key for encrypted repositories (string for symmetric, PEM for asymmetric)
576
- encryption_algorithm: Encryption algorithm to use (default: 'aes-256-cbc')
577
- - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
578
- - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA public key in PEM format)
579
- temp_dir: Custom temporary directory path for cloning repository (default: system temp directory)
580
- skip_lfs: Skip LFS files when cloning the repository (default: True)
537
+ token: Token for authentication (preferred)
538
+ encryption_password: Password for file encryption (optional)
581
539
 
582
540
  Returns:
583
- Commit hash of the uploaded file
584
-
585
- Raises:
586
- RepositoryNotFoundError: If the repository does not exist
587
- EncryptionError: If repository requires encryption but encryption_key is not provided
588
-
589
- Example:
590
- >>> commit_hash = upload_file(
591
- ... path_or_fileobj="./config.yaml",
592
- ... path_in_repo="config.yaml",
593
- ... repo_id="demo/my-model",
594
- ... commit_message="Upload config file",
595
- ... token="your-token",
596
- ... )
597
-
598
- >>> # Upload to encrypted repository
599
- >>> commit_hash = upload_file(
600
- ... path_or_fileobj="./model.bin",
601
- ... path_in_repo="model.bin",
602
- ... repo_id="demo/encrypted-model",
603
- ... encryption_key="my-secret-key",
604
- ... token="your-token",
605
- ... )
541
+ Upload response
606
542
  """
607
- import tempfile
608
- from .client import HubClient
543
+ import shutil
609
544
 
610
545
  # Parse repo_id
611
546
  parts = repo_id.split('/')
612
547
  if len(parts) != 2:
613
548
  raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
614
-
615
549
  organization, repo_name = parts
616
550
 
617
- client = HubClient(
618
- base_url=base_url,
619
- username=username,
620
- password=password,
621
- token=token,
622
- )
551
+ # Check if repository exists
552
+ client = HubClient(base_url=base_url, username=username, password=password, token=token)
623
553
  try:
624
- client.get_repository_info(
625
- organization=organization,
626
- repo_type=repo_type,
627
- repo_name=repo_name,
628
- )
554
+ client.get_repository_info(organization, repo_type, repo_name)
629
555
  except RepositoryNotFoundError:
630
556
  raise RepositoryNotFoundError(
631
- f"Repository '{repo_id}' does not exist. "
632
- "Please create the repository before uploading."
557
+ f"Repository not found: {organization}/{repo_type}/{repo_name}. "
558
+ f"Please create the repository first."
633
559
  )
634
- is_encrypted = bool(encryption_key)
635
- # 构建 Git URL
636
- git_url = _build_git_url(
637
- base_url=base_url,
638
- organization=organization,
639
- repo_type=repo_type,
640
- repo_name=repo_name,
641
- username=username,
642
- password=password,
643
- token=token,
644
- )
645
-
646
- # 该临时目录用来存储克隆的仓库
647
- if temp_dir is not None:
648
- # 创建子目录
649
- import uuid
650
- temp_dir_path = Path(temp_dir)
651
- temp_dir_path.mkdir(parents=True, exist_ok=True)
652
- unique_dir = temp_dir_path / f"upload_{uuid.uuid4().hex[:8]}"
653
- unique_dir.mkdir(parents=True, exist_ok=True)
654
- temp_context = None
655
- working_temp_dir = unique_dir
656
- else:
657
- # 使用系统的临时目录
658
- temp_context = tempfile.TemporaryDirectory()
659
- working_temp_dir = Path(temp_context.__enter__())
660
- unique_dir = None
661
-
662
- try:
663
- repo_path = working_temp_dir / "repo"
664
560
 
665
- try:
666
- # 克隆远程仓库到临时目录,depth=1,只克隆最新的提交,减少数据量
667
- try:
668
- env_mapping = None
669
- if skip_lfs:
670
- env_mapping={'GIT_LFS_SKIP_SMUDGE': "1"}
671
- repo = Repo.clone_from(
672
- git_url,
673
- repo_path,
674
- branch=revision,
675
- depth=1,
676
- env=env_mapping,
677
- )
678
- except GitCommandError as e:
679
- if "Authentication failed" in str(e) or "authentication" in str(e).lower():
680
- raise AuthenticationError(f"Authentication failed: {e}")
681
- raise UploadError(f"Failed to clone repository: {e}")
682
-
683
- # Prepare file path in repository
684
- file_path = repo_path / path_in_repo
685
- file_path.parent.mkdir(parents=True, exist_ok=True)
686
-
687
- # Write file content
688
- if isinstance(path_or_fileobj, bytes):
689
- # Write bytes directly
690
- file_path.write_bytes(path_or_fileobj)
691
- else:
692
- # Copy from source file
693
- source_path = Path(path_or_fileobj)
694
- if not source_path.exists():
695
- raise FileNotFoundError(f"File not found: {source_path}")
696
- shutil.copy2(source_path, file_path)
697
-
698
- print(f"Added file: {path_in_repo}")
699
-
700
- # 处理加密元数据文件
701
- metadata_file = repo_path / ".moha_encryption"
702
- metadata_updated = False
703
-
704
- # Encrypt file if repository is encrypted
705
- if is_encrypted:
706
- # Determine encryption algorithm
707
- from .encryption import EncryptionAlgorithm, encrypt_file as encrypt_file_func
708
-
709
- if encryption_algorithm is None:
710
- algorithm = EncryptionAlgorithm.AES_256_CBC
711
- else:
712
- try:
713
- algorithm = EncryptionAlgorithm(encryption_algorithm)
714
- except ValueError:
715
- raise EncryptionError(
716
- f"Invalid encryption algorithm: {encryption_algorithm}. "
717
- f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
718
- )
719
-
720
- print(f"Encrypting file: {path_in_repo} using {algorithm.value}")
721
- encrypt_file_func(file_path, encryption_key, algorithm)
722
- print(f"File encrypted")
723
-
724
- # 计算加密后的文件信息
725
- encrypted_size = file_path.stat().st_size
726
- encrypted_hash = _calculate_file_hash(file_path)
727
-
728
- # 读取或创建加密元数据文件
729
- if metadata_file.exists():
730
- with open(metadata_file, 'r', encoding='utf-8') as f:
731
- encryption_metadata = json.load(f)
732
- else:
733
- encryption_metadata = {
734
- "version": "1.0",
735
- "createAt": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
736
- "files": []
737
- }
738
-
739
- # 更新或添加文件信息
740
- file_info = {
741
- "path": path_in_repo,
742
- "algorithm": algorithm.value,
743
- "encryptedSize": encrypted_size,
744
- "encryptedHash": encrypted_hash,
745
- }
746
-
747
- # 检查文件是否已存在,如果存在则更新
748
- existing_index = None
749
- for i, f in enumerate(encryption_metadata["files"]):
750
- if f["path"] == path_in_repo:
751
- existing_index = i
752
- break
753
-
754
- if existing_index is not None:
755
- encryption_metadata["files"][existing_index] = file_info
756
- else:
757
- encryption_metadata["files"].append(file_info)
758
-
759
- # 写入元数据文件
760
- _write_encryption_metadata(repo_path, encryption_metadata)
761
- metadata_updated = True
762
- else:
763
- # 如果不是加密上传,从元数据文件中移除该文件(如果存在)
764
- if metadata_file.exists():
765
- with open(metadata_file, 'r', encoding='utf-8') as f:
766
- encryption_metadata = json.load(f)
767
-
768
- # 查找并移除该文件
769
- original_count = len(encryption_metadata["files"])
770
- encryption_metadata["files"] = [
771
- f for f in encryption_metadata["files"]
772
- if f["path"] != path_in_repo
773
- ]
774
-
775
- if len(encryption_metadata["files"]) < original_count:
776
- # 文件被移除了
777
- if len(encryption_metadata["files"]) == 0:
778
- # 如果没有加密文件了,删除元数据文件
779
- metadata_file.unlink()
780
- print(f"Removed .moha_encryption file (no encrypted files remaining)")
781
- else:
782
- # 更新元数据文件
783
- _write_encryption_metadata(repo_path, encryption_metadata)
784
- print(f"Updated .moha_encryption file (removed {path_in_repo})")
785
- metadata_updated = True
786
-
787
- # Add file to git
788
- repo.git.add(path_in_repo)
789
-
790
- # 如果元数据文件被更新,也添加到 git
791
- if metadata_updated:
792
- if metadata_file.exists():
793
- repo.git.add(".moha_encryption")
794
- else:
795
- # 如果文件被删除,确保从 git 中移除
796
- try:
797
- repo.git.rm(".moha_encryption")
798
- except GitCommandError:
799
- # 文件可能不在 git 中,忽略错误
800
- pass
801
-
802
- # Check if there are changes
803
- if not repo.is_dirty() and not repo.untracked_files:
804
- print("No changes to commit (file already exists with same content)")
805
- return repo.head.commit.hexsha
806
-
807
- # Create commit message
808
- if commit_message is None:
809
- commit_message = f"Upload {path_in_repo}"
810
-
811
- full_message = commit_message
812
- if commit_description:
813
- full_message = f"{commit_message}\n\n{commit_description}"
814
-
815
- # Commit changes
816
- print(f"Committing changes: {commit_message}")
817
- commit = repo.index.commit(full_message)
818
-
819
- # Push to remote
820
- print(f"Pushing to {revision}...")
821
- try:
822
- origin = repo.remote(name='origin')
823
- origin.push(refspec=f'{revision}:{revision}')
824
- except GitCommandError as e:
825
- if "Authentication failed" in str(e) or "authentication" in str(e).lower():
826
- raise AuthenticationError(f"Authentication failed during push: {e}")
827
- raise UploadError(f"Failed to push changes: {e}")
828
-
829
- print(f"Successfully uploaded {path_in_repo} to {repo_id}")
830
- print(f"Commit hash: {commit.hexsha}")
831
-
832
- return commit.hexsha
833
-
834
- except (GitCommandError, InvalidGitRepositoryError) as e:
835
- raise UploadError(f"Git operation failed: {e}")
836
- finally:
837
- # Clean up temporary directory
838
- if temp_context is not None:
839
- # Clean up system temp directory
840
- try:
841
- temp_context.__exit__(None, None, None)
842
- except Exception:
843
- pass # Ignore cleanup errors
844
- elif unique_dir is not None:
845
- # Clean up custom temp directory
846
- try:
847
- import shutil as shutil_cleanup
848
- shutil_cleanup.rmtree(unique_dir, ignore_errors=True)
849
- except Exception:
850
- pass # Ignore cleanup errors
561
+ # Validate file
562
+ path_file = Path(path_file)
563
+ if not path_file.exists():
564
+ raise FileNotFoundError(f"File not found: {path_file}")
565
+ if not path_file.is_file():
566
+ raise ValueError(f"Path is not a file: {path_file}")
851
567
 
568
+ # Encrypt file if needed
569
+ encrypted_file, temp_dir = _encrypt_file_if_needed(path_file, encryption_password)
570
+ actual_file = encrypted_file if encrypted_file else path_file
852
571
 
572
+ try:
573
+ # Create session and API URL
574
+ session = _create_session(token, username, password)
575
+ api_url = _build_api_url(base_url, organization, repo_type, repo_name)
576
+
577
+ # Prepare commit message
578
+ if commit_message is None:
579
+ commit_message = f"Upload {path_in_repo}"
580
+ if commit_description:
581
+ commit_message = f"{commit_message}\n\n{commit_description}"
582
+
583
+ # Check file size
584
+ file_size = actual_file.stat().st_size
585
+
586
+ # Check if encryption manifest file exists
587
+ manifest_file = None
588
+ if temp_dir:
589
+ manifest_path = temp_dir / "xpai_encryption_manifest.enc"
590
+ if manifest_path.exists():
591
+ manifest_file = manifest_path
592
+ print(f"Found encryption manifest file: xpai_encryption_manifest.enc")
593
+
594
+ # Upload main file
595
+ if file_size >= 5 * 1024 * 1024: # 5MB
596
+ # Large file
597
+ result = _upload_large_file(session, api_url, actual_file, path_in_repo, commit_message, revision)
598
+ else:
599
+ # Small file
600
+ result = _upload_small_file(session, api_url, actual_file, path_in_repo, commit_message, revision)
601
+
602
+ # Upload manifest file if it exists
603
+ if manifest_file:
604
+ _upload_small_file(
605
+ session,
606
+ api_url,
607
+ manifest_file,
608
+ "xpai_encryption_manifest.enc",
609
+ f"{commit_message} (manifest)",
610
+ revision
611
+ )
612
+
613
+ print(f"Successfully uploaded {path_in_repo} to {repo_id}")
614
+ return result
615
+ finally:
616
+ # Clean up temporary encrypted file
617
+ if temp_dir and temp_dir.exists():
618
+ shutil.rmtree(temp_dir)
853
619