xiaoshiai-hub 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,598 @@
1
+ """
2
+ Download utilities for XiaoShi AI Hub SDK
3
+ """
4
+
5
+ import fnmatch
6
+ import os
7
+ from pathlib import Path
8
+ from typing import List, Optional, Union
9
+
10
+ try:
11
+ from tqdm.auto import tqdm
12
+ except ImportError:
13
+ tqdm = None
14
+
15
+ from .client import HubClient, DEFAULT_BASE_URL
16
+ from .types import GitContent
17
+ from .exceptions import EncryptionError, RepositoryNotFoundError
18
+
19
+
20
+ def _should_decrypt_file(filename: str, decryption_exclude: Optional[List[str]] = None) -> bool:
21
+ """
22
+ Check if a file should be decrypted based on exclude patterns.
23
+
24
+ Args:
25
+ filename: Name of the file
26
+ decryption_exclude: List of patterns to exclude from decryption
27
+
28
+ Returns:
29
+ True if the file should be decrypted
30
+ """
31
+ import fnmatch
32
+
33
+ if not decryption_exclude:
34
+ return True
35
+
36
+ for pattern in decryption_exclude:
37
+ if fnmatch.fnmatch(filename, pattern):
38
+ return False
39
+
40
+ return True
41
+
42
+
43
+ def _match_pattern(name: str, pattern: str) -> bool:
44
+ """
45
+ Match a filename against a pattern.
46
+
47
+ Supports wildcards:
48
+ - * matches any characters
49
+ - *.ext matches files with extension
50
+ - prefix* matches files starting with prefix
51
+
52
+ Args:
53
+ name: Filename to match
54
+ pattern: Pattern to match against
55
+
56
+ Returns:
57
+ True if the name matches the pattern
58
+ """
59
+ return fnmatch.fnmatch(name, pattern)
60
+
61
+
62
def _should_download_file(
    file_path: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> bool:
    """
    Decide whether a repository file passes the allow/ignore filters.

    Each pattern is tried against both the bare filename and the full
    path; ignore patterns take precedence over allow patterns.

    Args:
        file_path: Path of the file
        allow_patterns: Patterns to allow (None or empty allows everything)
        ignore_patterns: Patterns to ignore

    Returns:
        True if the file should be downloaded
    """
    base_name = os.path.basename(file_path)

    def _hits_any(globs: List[str]) -> bool:
        # Filename is tried first, then the full path, for each pattern in order.
        return any(
            _match_pattern(base_name, glob) or _match_pattern(file_path, glob)
            for glob in globs
        )

    # An ignore match wins regardless of the allow list.
    if ignore_patterns and _hits_any(ignore_patterns):
        return False

    # Without an allow list every non-ignored file is accepted.
    if not allow_patterns:
        return True

    return _hits_any(allow_patterns)
96
+
97
+
98
def _count_files_to_download(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> int:
    """
    Count the files under a repository path that pass the download filters.

    Lists ``path`` through the client and descends into every directory
    entry, issuing one listing request per directory.

    Args:
        client: Hub client instance
        organization: Organization name
        repo_type: Repository type
        repo_name: Repository name
        branch: Branch name
        path: Current path in the repository
        allow_patterns: Patterns to allow
        ignore_patterns: Patterns to ignore

    Returns:
        Number of files at or below ``path`` that should be downloaded
    """
    listing = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    total = 0
    # A missing or empty entry list contributes nothing.
    for item in listing.entries or ():
        kind = item.type
        if kind == "dir":
            total += _count_files_to_download(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=item.path,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
            continue
        if kind != "file":
            # Entries that are neither files nor directories are not counted.
            continue
        if _should_download_file(item.path, allow_patterns, ignore_patterns):
            total += 1

    return total
151
+
152
+
153
def _download_repository_recursively(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    local_dir: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
    verbose: bool = True,
    progress_bar = None,
    encryption_metadata = None,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> None:
    """
    Recursively download repository contents.

    Lists ``path`` through the client, downloads every file accepted by the
    allow/ignore patterns to ``local_dir`` joined with the file's repository
    path, hands files listed in ``encryption_metadata`` to the decryption
    routine, and recurses into sub-directories.

    Args:
        client: Hub client instance
        organization: Organization name
        repo_type: Repository type
        repo_name: Repository name
        branch: Branch name
        path: Current path in the repository
        local_dir: Local directory to save files
        allow_patterns: Patterns to allow
        ignore_patterns: Patterns to ignore
        verbose: Print progress messages (only used when no progress bar is given)
        progress_bar: Optional tqdm progress bar for overall progress
        encryption_metadata: Encryption metadata for the repository; its
            ``files`` entries carry the ``path`` and ``algorithm`` of each
            encrypted file
        decryption_key: Key to decrypt files
        decryption_algorithm: Algorithm to use for decryption; must equal the
            algorithm recorded for the file in ``encryption_metadata``

    Raises:
        EncryptionError: If a selected file is listed as encrypted and the
            key or algorithm is missing, the algorithm differs from the
            recorded one, or it is not a valid ``EncryptionAlgorithm`` value.
            The check runs before the file is downloaded.
    """
    content = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    if not content.entries:
        return

    # Console messages are only emitted when no overall progress bar is shown.
    log_steps = verbose and progress_bar is None

    for entry in content.entries:
        if entry.type == "dir":
            if log_steps:
                print(f"Entering directory: {entry.path}")
            # Recurse into the sub-directory with the same settings.
            _download_repository_recursively(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=entry.path,
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                verbose=verbose,
                progress_bar=progress_bar,
                encryption_metadata=encryption_metadata,
                decryption_key=decryption_key,
                decryption_algorithm=decryption_algorithm,
            )
            continue

        if entry.type != "file":
            if log_steps:
                print(f"Skipping {entry.type}: {entry.path}")
            continue

        # Check whether the file passes the allow/ignore filters.
        if not _should_download_file(entry.path, allow_patterns, ignore_patterns):
            if log_steps:
                print(f"Skipping file: {entry.path}")
            continue

        # Look the file up in the encryption metadata (first matching path wins).
        file_meta = None
        if encryption_metadata and encryption_metadata.files:
            file_meta = next(
                (meta for meta in encryption_metadata.files if meta.path == entry.path),
                None,
            )
        file_is_encrypted = file_meta is not None

        # Validate the decryption arguments BEFORE downloading, so a missing or
        # wrong key/algorithm does not cost a transfer that cannot be decrypted.
        algorithm = None
        decrypt_file_func = None
        if file_is_encrypted:
            file_encryption_algorithm = file_meta.algorithm
            if not decryption_key:
                raise EncryptionError(
                    f"File '{entry.path}' is encrypted, but no decryption_key was provided. "
                    "Please provide a decryption_key parameter."
                )
            if not decryption_algorithm:
                raise EncryptionError(
                    f"File '{entry.path}' is encrypted, but no decryption_algorithm was provided. "
                    f"The file was encrypted with '{file_encryption_algorithm}'. "
                    "Please provide a decryption_algorithm parameter."
                )
            if decryption_algorithm != file_encryption_algorithm:
                raise EncryptionError(
                    f"File '{entry.path}' is encrypted with '{file_encryption_algorithm}', "
                    f"but decryption_algorithm '{decryption_algorithm}' was provided. "
                    "Please use the correct decryption algorithm."
                )
            # Imported lazily: only needed when an encrypted file is encountered.
            from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
            try:
                algorithm = EncryptionAlgorithm(decryption_algorithm)
            except ValueError as exc:
                raise EncryptionError(
                    f"Invalid decryption algorithm: {decryption_algorithm}. "
                    f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
                ) from exc

        if log_steps:
            print(f"Downloading file: {entry.path}")

        local_path = os.path.join(local_dir, entry.path)

        # Update progress bar description if available
        if progress_bar is not None:
            progress_bar.set_description(f"Downloading {entry.path}")

        client.download_file(
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=branch,
            file_path=entry.path,
            local_path=local_path,
            show_progress=progress_bar is None,  # Show individual progress only if no overall progress
        )

        if file_is_encrypted:
            if log_steps:
                print(f"Decrypting file: {entry.path}")
            decrypt_file_func(Path(local_path), decryption_key, algorithm)

        if progress_bar is not None:
            progress_bar.update(1)
289
+
290
+
291
def moha_hub_download(
    repo_id: str,
    filename: str,
    *,
    repo_type: str = "models",
    revision: Optional[str] = None,
    local_dir: Optional[Union[str, Path]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    show_progress: bool = True,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> str:
    """
    Download a single file from a repository.

    Similar to huggingface_hub.hf_hub_download().

    Args:
        repo_id: Repository ID in the format "organization/repo_name"
        filename: Path to the file in the repository
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from (default: the
            repository's default branch as reported by the client)
        local_dir: Directory to save the file; when omitted the file is
            written to ``filename`` relative to the current directory
        base_url: Base URL of the Hub API, passed through to HubClient
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        show_progress: Whether to show download progress bar
        decryption_key: Key to decrypt the file; required when the file is
            listed in the repository's encryption metadata
        decryption_algorithm: Decryption algorithm to use; required when the
            file is encrypted. There is no default: it must equal the
            algorithm recorded for the file and be a valid
            ``EncryptionAlgorithm`` value.

    Returns:
        Path to the downloaded file

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name"
        EncryptionError: If the file is encrypted and the key or algorithm is
            missing, the algorithm differs from the recorded one, or it is
            not a valid ``EncryptionAlgorithm`` value. Raised before any
            download takes place.

    Example:
        >>> file_path = moha_hub_download(
        ...     repo_id="demo/demo",
        ...     filename="data/config.yaml",
        ...     username="your-username",
        ...     password="your-password",
        ... )

        >>> # Download from encrypted repository
        >>> file_path = moha_hub_download(
        ...     repo_id="demo/encrypted-model",
        ...     filename="model.bin",
        ...     decryption_key="my-secret-key",
        ...     decryption_algorithm="aes-256-cbc",
        ...     token="your-token",
        ... )
    """
    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")
    organization, repo_name = parts
    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )
    # Resolve the default branch when no revision was requested.
    if revision is None:
        revision = client.get_default_branch(organization, repo_type, repo_name)

    # Fetch the repository's encryption metadata.
    encryption_metadata = client.get_moha_encryption(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        reference=revision,
    )

    # Find out whether the requested file is listed as encrypted
    # (first metadata entry with a matching path wins).
    file_is_encrypted = False
    file_encryption_algorithm = None
    if encryption_metadata and encryption_metadata.files:
        for file_meta in encryption_metadata.files:
            if file_meta.path == filename:
                file_is_encrypted = True
                file_encryption_algorithm = file_meta.algorithm
                break

    # For an encrypted file, validate all decryption arguments up front so
    # nothing is downloaded that cannot be decrypted afterwards.
    algorithm = None
    decrypt_file_func = None
    if file_is_encrypted:
        if not decryption_key:
            raise EncryptionError(
                f"File '{filename}' is encrypted, but no decryption_key was provided. "
                "Please provide a decryption_key parameter."
            )

        if not decryption_algorithm:
            raise EncryptionError(
                f"File '{filename}' is encrypted, but no decryption_algorithm was provided. "
                f"The file was encrypted with '{file_encryption_algorithm}'. "
                "Please provide a decryption_algorithm parameter."
            )

        if decryption_algorithm != file_encryption_algorithm:
            raise EncryptionError(
                f"File '{filename}' is encrypted with '{file_encryption_algorithm}', "
                f"but decryption_algorithm '{decryption_algorithm}' was provided. "
                "Please use the correct decryption algorithm."
            )

        # Imported lazily: only needed when the file is actually encrypted.
        from .encryption import EncryptionAlgorithm, decrypt_file as decrypt_file_func
        try:
            algorithm = EncryptionAlgorithm(decryption_algorithm)
        except ValueError as exc:
            raise EncryptionError(
                f"Invalid decryption algorithm: {decryption_algorithm}. "
                f"Supported algorithms: {', '.join([a.value for a in EncryptionAlgorithm])}"
            ) from exc

    if local_dir:
        local_path = os.path.join(local_dir, filename)
    else:
        local_path = filename

    # Download the file.
    client.download_file(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=revision,
        file_path=filename,
        local_path=local_path,
        show_progress=show_progress,
    )

    # Decrypt the file; the key was already verified to be present above.
    if file_is_encrypted:
        decrypt_file_func(Path(local_path), decryption_key, algorithm)

    return local_path
431
+
432
+
433
def snapshot_download(
    repo_id: str,
    repo_type: str = "models",
    revision: Optional[str] = None,
    local_dir: Optional[Union[str, Path]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    verbose: bool = True,
    show_progress: bool = True,
    decryption_key: Optional[Union[str, bytes]] = None,
    decryption_algorithm: Optional[str] = None,
) -> str:
    """
    Download an entire repository snapshot.

    Similar to huggingface_hub.snapshot_download().

    Args:
        repo_id: Repository ID in the format "organization/repo_name"
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from (default: the
            repository's default branch as reported by the client)
        local_dir: Directory to save files (default:
            "./downloads/{organization}_{repo_type}_{repo_name}")
        allow_patterns: Pattern or list of patterns to allow (e.g., "*.yaml", "*.yml")
        ignore_patterns: Pattern or list of patterns to ignore (e.g., ".git*")
        base_url: Base URL of the Hub API, passed through to HubClient
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        verbose: Print progress messages
        show_progress: Whether to show overall progress bar (requires tqdm;
            without it the console messages are used instead)
        decryption_key: Key to decrypt files; required for every downloaded
            file listed in the repository's encryption metadata
        decryption_algorithm: Decryption algorithm to use; required for
            encrypted files. There is no default: it must equal the
            algorithm recorded for each encrypted file.

    Returns:
        Path to the downloaded repository

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name"

    Example:
        >>> repo_path = snapshot_download(
        ...     repo_id="demo/demo",
        ...     repo_type="models",
        ...     allow_patterns=["*.yaml", "*.yml"],
        ...     ignore_patterns=[".git*"],
        ...     username="your-username",
        ...     password="your-password",
        ... )

        >>> # Download from encrypted repository
        >>> repo_path = snapshot_download(
        ...     repo_id="demo/encrypted-model",
        ...     repo_type="models",
        ...     decryption_key="my-secret-key",
        ...     decryption_algorithm="aes-256-cbc",
        ...     token="your-token",
        ... )
    """
    parts = repo_id.split('/')
    if len(parts) != 2:
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts

    # Accept a single pattern string as shorthand for a one-element list.
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )
    if revision is None:
        revision = client.get_default_branch(organization, repo_type, repo_name)
    encryption_metadata = client.get_moha_encryption(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        reference=revision,
    )

    if encryption_metadata and encryption_metadata.files:
        if verbose:
            print(f"Repository has {len(encryption_metadata.files)} encrypted file(s). Files will be decrypted after download.")

    # Determine local directory
    if local_dir:
        download_dir = str(local_dir)
    else:
        # Default to downloads directory
        download_dir = f"./downloads/{organization}_{repo_type}_{repo_name}"

    # A progress bar is only possible when tqdm could be imported; otherwise
    # fall back to the plain console messages even if show_progress is set.
    use_progress_bar = show_progress and tqdm is not None

    if verbose and not use_progress_bar:
        print(f"Downloading repository: {repo_id}")
        print(f"Repository type: {repo_type}")
        print(f"Revision: {revision}")
        print(f"Destination: {download_dir}")

    progress_bar = None
    if use_progress_bar:
        if verbose:
            print("Fetching repository info...")
        # Count the files to download so the bar has a total.
        total_files = _count_files_to_download(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

        if total_files > 0:
            progress_bar = tqdm(
                total=total_files,
                unit='file',
                desc=f"Downloading {repo_id}",
                leave=True,
            )

    # Download by walking the repository listing recursively (not via git).
    try:
        _download_repository_recursively(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            local_dir=download_dir,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            verbose=verbose,
            progress_bar=progress_bar,
            encryption_metadata=encryption_metadata,
            decryption_key=decryption_key,
            decryption_algorithm=decryption_algorithm,
        )
    finally:
        # Always release the bar, including when the download raises.
        if progress_bar is not None:
            progress_bar.close()

    if verbose and not use_progress_bar:
        print(f"Download completed to: {download_dir}")

    # Add download count
    try:
        client.add_download_count(
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
        )
    except Exception as e:
        # Deliberately best-effort: don't fail the download if adding count fails
        if verbose:
            print(f"Warning: Failed to add download count: {e}")

    return download_dir
598
+