xiaoshiai-hub 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xiaoshiai_hub/download.py CHANGED
@@ -14,6 +14,30 @@ except ImportError:
14
14
 
15
15
  from .client import HubClient, DEFAULT_BASE_URL
16
16
  from .types import GitContent
17
+ from .exceptions import EncryptionError, RepositoryNotFoundError
18
+
19
+
20
+ def _should_decrypt_file(filename: str, decryption_exclude: Optional[List[str]] = None) -> bool:
21
+ """
22
+ Check if a file should be decrypted based on exclude patterns.
23
+
24
+ Args:
25
+ filename: Name of the file
26
+ decryption_exclude: List of patterns to exclude from decryption
27
+
28
+ Returns:
29
+ True if the file should be decrypted
30
+ """
31
+ import fnmatch
32
+
33
+ if not decryption_exclude:
34
+ return True
35
+
36
+ for pattern in decryption_exclude:
37
+ if fnmatch.fnmatch(filename, pattern):
38
+ return False
39
+
40
+ return True
17
41
 
18
42
 
19
43
  def _match_pattern(name: str, pattern: str) -> bool:
@@ -138,6 +162,9 @@ def _download_repository_recursively(
138
162
  ignore_patterns: Optional[List[str]] = None,
139
163
  verbose: bool = True,
140
164
  progress_bar = None,
165
+ encryption_metadata = None,
166
+ decryption_key: Optional[Union[str, bytes]] = None,
167
+ decryption_algorithm: Optional[str] = None,
141
168
  ) -> None:
142
169
  """
143
170
  Recursively download repository contents.
@@ -154,8 +181,11 @@ def _download_repository_recursively(
154
181
  ignore_patterns: Patterns to ignore
155
182
  verbose: Print progress messages
156
183
  progress_bar: Optional tqdm progress bar for overall progress
184
+ encryption_metadata: Encryption metadata from .moha_encryption file
185
+ decryption_key: Key to decrypt files
186
+ decryption_exclude: Patterns to exclude from decryption
187
+ decryption_algorithm: Algorithm to use for decryption
157
188
  """
158
- # Get content at current path
159
189
  content = client.get_repository_content(
160
190
  organization=organization,
161
191
  repo_type=repo_type,
@@ -168,7 +198,7 @@ def _download_repository_recursively(
168
198
  if content.entries:
169
199
  for entry in content.entries:
170
200
  if entry.type == "file":
171
- # Check if file should be downloaded
201
+ # 检查文件是否应该被下载
172
202
  if _should_download_file(entry.path, allow_patterns, ignore_patterns):
173
203
  if verbose and progress_bar is None:
174
204
  print(f"Downloading file: {entry.path}")
@@ -188,8 +218,45 @@ def _download_repository_recursively(
188
218
  local_path=local_path,
189
219
  show_progress=progress_bar is None, # Show individual progress only if no overall progress
190
220
  )
191
-
192
- # Update overall progress
221
+ file_is_encrypted = False
222
+ file_encryption_algorithm = None
223
+ if encryption_metadata and encryption_metadata.files:
224
+ for file_meta in encryption_metadata.files:
225
+ if file_meta.path == entry.path:
226
+ file_is_encrypted = True
227
+ file_encryption_algorithm = file_meta.algorithm
228
+ break
229
+
230
+ if file_is_encrypted:
231
+ if not decryption_key:
232
+ raise EncryptionError(
233
+ f"File '{entry.path}' is encrypted, but no decryption_key was provided. "
234
+ "Please provide a decryption_key parameter."
235
+ )
236
+ if not decryption_algorithm:
237
+ raise EncryptionError(
238
+ f"File '{entry.path}' is encrypted, but no decryption_algorithm was provided. "
239
+ f"The file was encrypted with '{file_encryption_algorithm}'. "
240
+ "Please provide a decryption_algorithm parameter."
241
+ )
242
+ if decryption_algorithm != file_encryption_algorithm:
243
+ raise EncryptionError(
244
+ f"File '{entry.path}' is encrypted with '{file_encryption_algorithm}', "
245
+ f"but decryption_algorithm '{decryption_algorithm}' was provided. "
246
+ "Please use the correct decryption algorithm."
247
+ )
248
+ if verbose and progress_bar is None:
249
+ print(f"Decrypting file: {entry.path}")
250
+ from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
251
+ try:
252
+ algorithm = EncryptionAlgorithm(decryption_algorithm)
253
+ except ValueError:
254
+ raise EncryptionError(
255
+ f"Invalid decryption algorithm: {decryption_algorithm}. "
256
+ f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
257
+ )
258
+ decrypt_file_func(Path(local_path), decryption_key, algorithm)
259
+
193
260
  if progress_bar is not None:
194
261
  progress_bar.update(1)
195
262
  else:
@@ -199,8 +266,7 @@ def _download_repository_recursively(
199
266
  elif entry.type == "dir":
200
267
  if verbose and progress_bar is None:
201
268
  print(f"Entering directory: {entry.path}")
202
-
203
- # Recursively download directory contents
269
+ # 递归下载
204
270
  _download_repository_recursively(
205
271
  client=client,
206
272
  organization=organization,
@@ -213,26 +279,29 @@ def _download_repository_recursively(
213
279
  ignore_patterns=ignore_patterns,
214
280
  verbose=verbose,
215
281
  progress_bar=progress_bar,
282
+ encryption_metadata=encryption_metadata,
283
+ decryption_key=decryption_key,
284
+ decryption_algorithm=decryption_algorithm,
216
285
  )
217
-
218
286
  else:
219
287
  if verbose and progress_bar is None:
220
288
  print(f"Skipping {entry.type}: {entry.path}")
221
289
 
222
290
 
223
- def hf_hub_download(
291
+ def moha_hub_download(
224
292
  repo_id: str,
225
293
  filename: str,
226
294
  *,
227
295
  repo_type: str = "models",
228
296
  revision: Optional[str] = None,
229
- cache_dir: Optional[Union[str, Path]] = None,
230
297
  local_dir: Optional[Union[str, Path]] = None,
231
298
  base_url: Optional[str] = None,
232
299
  username: Optional[str] = None,
233
300
  password: Optional[str] = None,
234
301
  token: Optional[str] = None,
235
302
  show_progress: bool = True,
303
+ decryption_key: Optional[Union[str, bytes]] = None,
304
+ decryption_algorithm: Optional[str] = None,
236
305
  ) -> str:
237
306
  """
238
307
  Download a single file from a repository.
@@ -251,6 +320,10 @@ def hf_hub_download(
251
320
  password: Password for authentication
252
321
  token: Token for authentication
253
322
  show_progress: Whether to show download progress bar
323
+ decryption_key: Key to decrypt the file if repository is encrypted (string for symmetric, PEM for asymmetric)
324
+ decryption_algorithm: Decryption algorithm to use (default: 'aes-256-cbc')
325
+ - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
326
+ - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)
254
327
 
255
328
  Returns:
256
329
  Path to the downloaded file
@@ -258,49 +331,79 @@ def hf_hub_download(
258
331
  Example:
259
332
  >>> file_path = hf_hub_download(
260
333
  ... repo_id="demo/demo",
261
- ... filename="config.yaml",
334
+ ... filename="data/config.yaml",
262
335
  ... username="your-username",
263
336
  ... password="your-password",
264
337
  ... )
338
+
339
+ >>> # Download from encrypted repository
340
+ >>> file_path = hf_hub_download(
341
+ ... repo_id="demo/encrypted-model",
342
+ ... filename="model.bin",
343
+ ... decryption_key="my-secret-key",
344
+ ... token="your-token",
345
+ ... )
265
346
  """
266
- # Parse repo_id
267
347
  parts = repo_id.split('/')
268
348
  if len(parts) != 2:
269
349
  raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
270
-
271
350
  organization, repo_name = parts
272
-
273
- # Create client
274
351
  client = HubClient(
275
352
  base_url=base_url,
276
353
  username=username,
277
354
  password=password,
278
355
  token=token,
279
356
  )
280
-
281
- # Get repository info to determine branch
357
+ # 获取默认分支
282
358
  if revision is None:
283
- repo_info = client.get_repository_info(organization, repo_type, repo_name)
284
- revision = repo_info.default_branch or "main"
359
+ revision = client.get_default_branch(organization, repo_type, repo_name)
360
+
361
+ # 获取加密元数据
362
+ encryption_metadata = client.get_moha_encryption(
363
+ organization=organization,
364
+ repo_type=repo_type,
365
+ repo_name=repo_name,
366
+ reference=revision,
367
+ )
368
+
369
+ # 文件加密标识
370
+ file_is_encrypted = False
371
+ file_encryption_algorithm = None
372
+ # 查询文件是否加密
373
+ if encryption_metadata and encryption_metadata.files:
374
+ for file_meta in encryption_metadata.files:
375
+ if file_meta.path == filename:
376
+ file_is_encrypted = True
377
+ file_encryption_algorithm = file_meta.algorithm
378
+ break
379
+
380
+ # 如果该文件加密了,检查解密参数
381
+ if file_is_encrypted:
382
+ if not decryption_key:
383
+ raise EncryptionError(
384
+ f"File '{filename}' is encrypted, but no decryption_key was provided. "
385
+ "Please provide a decryption_key parameter."
386
+ )
387
+
388
+ if not decryption_algorithm:
389
+ raise EncryptionError(
390
+ f"File '{filename}' is encrypted, but no decryption_algorithm was provided. "
391
+ f"The file was encrypted with '{file_encryption_algorithm}'. "
392
+ "Please provide a decryption_algorithm parameter."
393
+ )
285
394
 
286
- # Determine local path
395
+ if decryption_algorithm != file_encryption_algorithm:
396
+ raise EncryptionError(
397
+ f"File '{filename}' is encrypted with '{file_encryption_algorithm}', "
398
+ f"but decryption_algorithm '{decryption_algorithm}' was provided. "
399
+ "Please use the correct decryption algorithm."
400
+ )
287
401
  if local_dir:
288
402
  local_path = os.path.join(local_dir, filename)
289
- elif cache_dir:
290
- # Create cache structure similar to huggingface_hub
291
- cache_path = os.path.join(
292
- cache_dir,
293
- f"{repo_type}--{organization}--{repo_name}",
294
- "snapshots",
295
- revision,
296
- filename,
297
- )
298
- local_path = cache_path
299
403
  else:
300
- # Default to current directory
301
404
  local_path = filename
302
405
 
303
- # Download file
406
+ # 下载文件
304
407
  client.download_file(
305
408
  organization=organization,
306
409
  repo_type=repo_type,
@@ -311,15 +414,26 @@ def hf_hub_download(
311
414
  show_progress=show_progress,
312
415
  )
313
416
 
417
+ # 解密文件
418
+ if file_is_encrypted and decryption_key:
419
+ from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
420
+ try:
421
+ algorithm = EncryptionAlgorithm(decryption_algorithm)
422
+ except ValueError:
423
+ raise EncryptionError(
424
+ f"Invalid decryption algorithm: {decryption_algorithm}. "
425
+ f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
426
+ )
427
+
428
+ decrypt_file_func(Path(local_path), decryption_key, algorithm)
429
+
314
430
  return local_path
315
431
 
316
432
 
317
433
  def snapshot_download(
318
434
  repo_id: str,
319
- *,
320
435
  repo_type: str = "models",
321
436
  revision: Optional[str] = None,
322
- cache_dir: Optional[Union[str, Path]] = None,
323
437
  local_dir: Optional[Union[str, Path]] = None,
324
438
  allow_patterns: Optional[Union[List[str], str]] = None,
325
439
  ignore_patterns: Optional[Union[List[str], str]] = None,
@@ -329,6 +443,8 @@ def snapshot_download(
329
443
  token: Optional[str] = None,
330
444
  verbose: bool = True,
331
445
  show_progress: bool = True,
446
+ decryption_key: Optional[Union[str, bytes]] = None,
447
+ decryption_algorithm: Optional[str] = None,
332
448
  ) -> str:
333
449
  """
334
450
  Download an entire repository snapshot.
@@ -339,7 +455,6 @@ def snapshot_download(
339
455
  repo_id: Repository ID in the format "organization/repo_name"
340
456
  repo_type: Type of repository ("models" or "datasets")
341
457
  revision: Branch/tag/commit to download from (default: main branch)
342
- cache_dir: Directory to cache downloaded files
343
458
  local_dir: Directory to save files (if not using cache)
344
459
  allow_patterns: Pattern or list of patterns to allow (e.g., "*.yaml", "*.yml")
345
460
  ignore_patterns: Pattern or list of patterns to ignore (e.g., ".git*")
@@ -349,6 +464,10 @@ def snapshot_download(
349
464
  token: Token for authentication
350
465
  verbose: Print progress messages
351
466
  show_progress: Whether to show overall progress bar
467
+ decryption_key: Key to decrypt files if repository is encrypted (string for symmetric, PEM for asymmetric)
468
+ decryption_algorithm: Decryption algorithm to use (default: 'aes-256-cbc')
469
+ - Symmetric: 'aes-256-cbc', 'aes-256-gcm'
470
+ - Asymmetric: 'rsa-oaep', 'rsa-pkcs1v15' (requires RSA private key in PEM format)
352
471
 
353
472
  Returns:
354
473
  Path to the downloaded repository
@@ -362,46 +481,48 @@ def snapshot_download(
362
481
  ... username="your-username",
363
482
  ... password="your-password",
364
483
  ... )
484
+
485
+ >>> # Download from encrypted repository
486
+ >>> repo_path = snapshot_download(
487
+ ... repo_id="demo/encrypted-model",
488
+ ... repo_type="models",
489
+ ... decryption_key="my-secret-key",
490
+ ... decryption_exclude=["README.md", "*.txt"], # Don't decrypt these files
491
+ ... token="your-token",
492
+ ... )
365
493
  """
366
- # Parse repo_id
367
494
  parts = repo_id.split('/')
368
495
  if len(parts) != 2:
369
496
  raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
370
497
 
371
498
  organization, repo_name = parts
372
499
 
373
- # Normalize patterns to lists
374
500
  if isinstance(allow_patterns, str):
375
501
  allow_patterns = [allow_patterns]
376
502
  if isinstance(ignore_patterns, str):
377
503
  ignore_patterns = [ignore_patterns]
378
-
379
- # Create client
380
504
  client = HubClient(
381
505
  base_url=base_url,
382
506
  username=username,
383
507
  password=password,
384
508
  token=token,
385
509
  )
386
-
387
- # Get repository info
388
- repo_info = client.get_repository_info(organization, repo_type, repo_name)
389
-
390
- # Determine revision
391
510
  if revision is None:
392
- revision = repo_info.default_branch or "main"
511
+ revision = client.get_default_branch(organization, repo_type, repo_name)
512
+ encryption_metadata = client.get_moha_encryption(
513
+ organization=organization,
514
+ repo_type=repo_type,
515
+ repo_name=repo_name,
516
+ reference=revision,
517
+ )
518
+
519
+ if encryption_metadata and encryption_metadata.files:
520
+ if verbose:
521
+ print(f"Repository has {len(encryption_metadata.files)} encrypted file(s). Files will be decrypted after download.")
393
522
 
394
523
  # Determine local directory
395
524
  if local_dir:
396
525
  download_dir = str(local_dir)
397
- elif cache_dir:
398
- # Create cache structure
399
- download_dir = os.path.join(
400
- cache_dir,
401
- f"{repo_type}--{organization}--{repo_name}",
402
- "snapshots",
403
- revision,
404
- )
405
526
  else:
406
527
  # Default to downloads directory
407
528
  download_dir = f"./downloads/{organization}_{repo_type}_{repo_name}"
@@ -412,13 +533,11 @@ def snapshot_download(
412
533
  print(f"Revision: {revision}")
413
534
  print(f"Destination: {download_dir}")
414
535
 
415
- # Create progress bar if requested
416
536
  progress_bar = None
417
537
  if show_progress and tqdm is not None:
418
- # Count total files to download
419
538
  if verbose:
420
539
  print(f"Fetching repository info...")
421
-
540
+ # 计算需要下载的文件总数
422
541
  total_files = _count_files_to_download(
423
542
  client=client,
424
543
  organization=organization,
@@ -438,7 +557,7 @@ def snapshot_download(
438
557
  leave=True,
439
558
  )
440
559
 
441
- # Download recursively
560
+ # 递归下载,不是使用git的方式
442
561
  try:
443
562
  _download_repository_recursively(
444
563
  client=client,
@@ -452,6 +571,9 @@ def snapshot_download(
452
571
  ignore_patterns=ignore_patterns,
453
572
  verbose=verbose,
454
573
  progress_bar=progress_bar,
574
+ encryption_metadata=encryption_metadata,
575
+ decryption_key=decryption_key,
576
+ decryption_algorithm=decryption_algorithm,
455
577
  )
456
578
  finally:
457
579
  if progress_bar is not None:
@@ -460,5 +582,17 @@ def snapshot_download(
460
582
  if verbose and not show_progress:
461
583
  print(f"Download completed to: {download_dir}")
462
584
 
585
+ # Add download count
586
+ try:
587
+ client.add_download_count(
588
+ organization=organization,
589
+ repo_type=repo_type,
590
+ repo_name=repo_name,
591
+ )
592
+ except Exception as e:
593
+ # Don't fail the download if adding count fails
594
+ if verbose:
595
+ print(f"Warning: Failed to add download count: {e}")
596
+
463
597
  return download_dir
464
598