xiaoshiai-hub 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,464 @@
1
+ """
2
+ Download utilities for XiaoShi AI Hub SDK
3
+ """
4
+
5
+ import fnmatch
6
+ import os
7
+ from pathlib import Path
8
+ from typing import List, Optional, Union
9
+
10
+ try:
11
+ from tqdm.auto import tqdm
12
+ except ImportError:
13
+ tqdm = None
14
+
15
+ from .client import HubClient, DEFAULT_BASE_URL
16
+ from .types import GitContent
17
+
18
+
19
+ def _match_pattern(name: str, pattern: str) -> bool:
20
+ """
21
+ Match a filename against a pattern.
22
+
23
+ Supports wildcards:
24
+ - * matches any characters
25
+ - *.ext matches files with extension
26
+ - prefix* matches files starting with prefix
27
+
28
+ Args:
29
+ name: Filename to match
30
+ pattern: Pattern to match against
31
+
32
+ Returns:
33
+ True if the name matches the pattern
34
+ """
35
+ return fnmatch.fnmatch(name, pattern)
36
+
37
+
38
def _should_download_file(
    file_path: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> bool:
    """
    Decide whether a repository file passes the allow/ignore filters.

    Every pattern is tried against both the file's basename and its full
    repository path; a hit on either counts. Ignore patterns are evaluated
    first and always win. When ``allow_patterns`` is None or empty, every
    file that is not ignored is accepted; otherwise at least one allow
    pattern must match.

    Args:
        file_path: Repository-relative path of the file.
        allow_patterns: Patterns to accept; None or empty accepts everything.
        ignore_patterns: Patterns to reject; None or empty rejects nothing.

    Returns:
        True if the file should be downloaded, False otherwise.
    """
    candidates = (os.path.basename(file_path), file_path)

    def _any_match(patterns: List[str]) -> bool:
        # Pattern-major order, basename before full path, stopping at the
        # first hit.
        return any(
            _match_pattern(candidate, pattern)
            for pattern in patterns
            for candidate in candidates
        )

    if ignore_patterns and _any_match(ignore_patterns):
        return False
    if not allow_patterns:
        return True
    return _any_match(allow_patterns)
72
+
73
+
74
def _count_files_to_download(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> int:
    """
    Tally the files beneath ``path`` that pass the allow/ignore filters.

    ``path`` is listed with ``client.get_repository_content`` and every
    ``"dir"`` entry is descended into depth-first, so one listing request is
    made per directory. Entries of any other type are not counted.

    Args:
        client: Hub client used to list repository contents.
        organization: Organization that owns the repository.
        repo_type: Repository type (e.g. "models" or "datasets").
        repo_name: Repository name.
        branch: Branch/revision to list.
        path: Repository path to start from ("" for the root).
        allow_patterns: Patterns a file must match (None/empty: all files).
        ignore_patterns: Patterns that exclude a file.

    Returns:
        Number of ``"file"`` entries accepted by ``_should_download_file``.
    """
    listing = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    tally = 0
    for item in listing.entries or ():
        if item.type == "dir":
            tally += _count_files_to_download(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=item.path,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
        elif item.type == "file" and _should_download_file(
            item.path, allow_patterns, ignore_patterns
        ):
            tally += 1
    return tally
127
+
128
+
129
def _resolve_local_path(local_dir: str, repo_path: str) -> str:
    """Join ``repo_path`` onto ``local_dir``, refusing results outside ``local_dir``."""
    base = os.path.abspath(local_dir)
    target = os.path.abspath(os.path.join(base, repo_path))
    try:
        inside = os.path.commonpath([base, target]) == base
    except ValueError:
        # commonpath raises when the two paths are on different drives.
        inside = False
    if not inside:
        raise ValueError(
            f"Refusing to write repository path {repo_path!r} outside of {local_dir!r}"
        )
    return os.path.join(local_dir, repo_path)


def _download_repository_recursively(
    client: HubClient,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch: str,
    path: str,
    local_dir: str,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
    verbose: bool = True,
    progress_bar = None,
) -> None:
    """
    Recursively download repository contents under ``path`` into ``local_dir``.

    Each ``"file"`` entry that passes the allow/ignore filters is written to
    ``os.path.join(local_dir, entry.path)``; each ``"dir"`` entry is listed
    and processed recursively. Other entry types are skipped.

    Args:
        client: Hub client instance
        organization: Organization name
        repo_type: Repository type
        repo_name: Repository name
        branch: Branch name
        path: Current path in the repository ("" for the root)
        local_dir: Local directory to save files
        allow_patterns: Patterns to allow
        ignore_patterns: Patterns to ignore
        verbose: Print per-entry messages (only when no progress bar is given)
        progress_bar: Optional tqdm-like bar for overall progress. When given,
            its description is set per file, it is advanced once per
            downloaded file, and the client's per-file progress is disabled.

    Raises:
        ValueError: If a server-supplied entry path would resolve outside
            ``local_dir`` (e.g. an absolute path or one containing ``..``).
    """
    content = client.get_repository_content(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=branch,
        path=path,
    )

    # Text messages are suppressed while an overall progress bar is active so
    # they do not interleave with the bar's output.
    log = verbose and progress_bar is None

    for entry in content.entries or ():
        if entry.type == "file":
            if not _should_download_file(entry.path, allow_patterns, ignore_patterns):
                if log:
                    print(f"Skipping file: {entry.path}")
                continue

            # entry.path is supplied by the server; make sure it cannot be
            # used to write outside the destination directory.
            local_path = _resolve_local_path(local_dir, entry.path)

            if log:
                print(f"Downloading file: {entry.path}")
            if progress_bar is not None:
                progress_bar.set_description(f"Downloading {entry.path}")

            client.download_file(
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                file_path=entry.path,
                local_path=local_path,
                # Show individual progress only if no overall progress bar.
                show_progress=progress_bar is None,
            )

            if progress_bar is not None:
                progress_bar.update(1)

        elif entry.type == "dir":
            if log:
                print(f"Entering directory: {entry.path}")
            _download_repository_recursively(
                client=client,
                organization=organization,
                repo_type=repo_type,
                repo_name=repo_name,
                branch=branch,
                path=entry.path,
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                verbose=verbose,
                progress_bar=progress_bar,
            )

        elif log:
            print(f"Skipping {entry.type}: {entry.path}")
221
+
222
+
223
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    repo_type: str = "models",
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    local_dir: Optional[Union[str, Path]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    show_progress: bool = True,
) -> str:
    """
    Download a single file from a repository.

    Similar to huggingface_hub.hf_hub_download().

    Args:
        repo_id: Repository ID in the format "organization/repo_name"; both
            parts must be non-empty.
        filename: Path to the file in the repository
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from. Defaults to the
            repository's default branch, or "main" if none is reported.
        cache_dir: Root of a cache layout
            ``<cache_dir>/<repo_type>--<organization>--<repo_name>/snapshots/<revision>/<filename>``.
            Ignored when ``local_dir`` is given.
        local_dir: Directory to save the file into; takes precedence over
            ``cache_dir``. With neither, ``filename`` is used as a path
            relative to the current working directory.
        base_url: Base URL of the Hub API; None defers to HubClient's default
            (documented as coming from the MOHA_ENDPOINT env var).
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        show_progress: Whether to show download progress bar

    Returns:
        Path to the downloaded file

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name".

    Example:
        >>> file_path = hf_hub_download(
        ...     repo_id="demo/demo",
        ...     filename="config.yaml",
        ...     username="your-username",
        ...     password="your-password",
        ... )
    """
    # Parse repo_id. Empty halves ("demo/", "/demo") are rejected up front;
    # otherwise they would reach the API and yield paths like "models--demo--".
    parts = repo_id.split('/')
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts

    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )

    # Only query repository info when the caller did not pin a revision.
    if revision is None:
        repo_info = client.get_repository_info(organization, repo_type, repo_name)
        revision = repo_info.default_branch or "main"

    if local_dir:
        local_path = os.path.join(local_dir, filename)
    elif cache_dir:
        # Cache structure similar to huggingface_hub.
        local_path = os.path.join(
            cache_dir,
            f"{repo_type}--{organization}--{repo_name}",
            "snapshots",
            revision,
            filename,
        )
    else:
        # Default to the current directory.
        local_path = filename

    client.download_file(
        organization=organization,
        repo_type=repo_type,
        repo_name=repo_name,
        branch=revision,
        file_path=filename,
        local_path=local_path,
        show_progress=show_progress,
    )

    return local_path
315
+
316
+
317
def snapshot_download(
    repo_id: str,
    *,
    repo_type: str = "models",
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    local_dir: Optional[Union[str, Path]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    base_url: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    token: Optional[str] = None,
    verbose: bool = True,
    show_progress: bool = True,
) -> str:
    """
    Download an entire repository snapshot.

    Similar to huggingface_hub.snapshot_download(). The repository tree is
    walked recursively and every file that passes the pattern filters is
    written under the destination directory at its repository path.

    Args:
        repo_id: Repository ID in the format "organization/repo_name"; both
            parts must be non-empty.
        repo_type: Type of repository ("models" or "datasets")
        revision: Branch/tag/commit to download from. Defaults to the
            repository's default branch, or "main" if none is reported.
        cache_dir: Root of a cache layout
            ``<cache_dir>/<repo_type>--<organization>--<repo_name>/snapshots/<revision>``.
            Ignored when ``local_dir`` is given.
        local_dir: Directory to save files into; takes precedence over
            ``cache_dir``. With neither, files go to
            ``./downloads/<organization>_<repo_type>_<repo_name>``.
        allow_patterns: Glob pattern or list of patterns to allow (e.g.,
            "*.yaml", "*.yml"), tested against each file's basename and full
            repository path. None or empty allows every file.
        ignore_patterns: Glob pattern or list of patterns to ignore (e.g.,
            ".git*"); checked before ``allow_patterns`` and always wins.
        base_url: Base URL of the Hub API; None defers to HubClient's default
            (documented as coming from the MOHA_ENDPOINT env var).
        username: Username for authentication
        password: Password for authentication
        token: Token for authentication
        verbose: Print progress messages
        show_progress: Whether to show an overall progress bar. Requires
            ``tqdm``; without it, the text messages are used instead.

    Returns:
        Path to the downloaded repository

    Raises:
        ValueError: If ``repo_id`` is not of the form "organization/repo_name".

    Example:
        >>> repo_path = snapshot_download(
        ...     repo_id="demo/demo",
        ...     repo_type="models",
        ...     allow_patterns=["*.yaml", "*.yml"],
        ...     ignore_patterns=[".git*"],
        ...     username="your-username",
        ...     password="your-password",
        ... )
    """
    # Parse repo_id. Empty halves ("demo/", "/demo") are rejected up front;
    # otherwise they would reach the API and yield paths like "models--demo--".
    parts = repo_id.split('/')
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"Invalid repo_id format: {repo_id}. Expected 'organization/repo_name'")

    organization, repo_name = parts

    # Normalize patterns to lists.
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]

    client = HubClient(
        base_url=base_url,
        username=username,
        password=password,
        token=token,
    )

    # Fetched even when a revision is given; its default_branch fills in a
    # missing revision.
    repo_info = client.get_repository_info(organization, repo_type, repo_name)
    if revision is None:
        revision = repo_info.default_branch or "main"

    if local_dir:
        download_dir = str(local_dir)
    elif cache_dir:
        download_dir = os.path.join(
            cache_dir,
            f"{repo_type}--{organization}--{repo_name}",
            "snapshots",
            revision,
        )
    else:
        # Default to a downloads directory under the current directory.
        download_dir = f"./downloads/{organization}_{repo_type}_{repo_name}"

    # An overall bar needs tqdm. Without it the recursive helper prints
    # per-file messages, so the header/footer are printed as well to match.
    use_progress_bar = show_progress and tqdm is not None

    if verbose and not use_progress_bar:
        print(f"Downloading repository: {repo_id}")
        print(f"Repository type: {repo_type}")
        print(f"Revision: {revision}")
        print(f"Destination: {download_dir}")

    progress_bar = None
    if use_progress_bar:
        if verbose:
            print("Fetching repository info...")

        # A separate traversal is needed up front to size the bar.
        total_files = _count_files_to_download(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

        if total_files > 0:
            progress_bar = tqdm(
                total=total_files,
                unit='file',
                desc=f"Downloading {repo_id}",
                leave=True,
            )

    try:
        _download_repository_recursively(
            client=client,
            organization=organization,
            repo_type=repo_type,
            repo_name=repo_name,
            branch=revision,
            path="",
            local_dir=download_dir,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            verbose=verbose,
            progress_bar=progress_bar,
        )
    finally:
        if progress_bar is not None:
            progress_bar.close()

    if verbose and not use_progress_bar:
        print(f"Download completed to: {download_dir}")

    return download_dir
464
+
@@ -0,0 +1,32 @@
1
+ """
2
+ Exceptions for XiaoShi AI Hub SDK
3
+ """
4
+
5
+
6
class HubException(Exception):
    """
    Root of the XiaoShi AI Hub SDK exception hierarchy.

    Every other SDK exception derives from this class, so catching it
    handles any Hub-specific error in one place.
    """
9
+
10
+
11
class RepositoryNotFoundError(HubException):
    """Error indicating that a requested repository could not be found."""
14
+
15
+
16
class FileNotFoundError(HubException):
    """
    Error indicating that a requested file is absent from a repository.

    Note: this name shadows the built-in ``FileNotFoundError`` in this module
    and in any module that imports it by name. It derives only from
    ``HubException``, not from ``OSError``, so ``except OSError`` or the
    built-in ``FileNotFoundError`` will not catch it.
    """
19
+
20
+
21
class AuthenticationError(HubException):
    """Error indicating that authenticating against the Hub did not succeed."""
24
+
25
+
26
class HTTPError(HubException):
    """
    Raised when an HTTP error occurs.

    Attributes:
        status_code: HTTP status code associated with the failure, or None
            when the caller did not supply one.
    """

    # The annotation is quoted so the module stays importable on Python < 3.10
    # without adding a typing import; `= None` requires an optional type.
    def __init__(self, message: str, status_code: "int | None" = None) -> None:
        """
        Args:
            message: Human-readable error description (becomes ``str(self)``).
            status_code: Optional HTTP status code of the failing response.
        """
        super().__init__(message)
        self.status_code = status_code
32
+
xiaoshiai_hub/types.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Type definitions for XiaoShi AI Hub SDK
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from typing import List, Optional, Literal
8
+
9
+
10
@dataclass
class Repository:
    """Repository information: identity, kind, and optional metadata."""
    name: str  # repository name (the part after "organization/")
    organization: str  # owning organization
    type: str  # "models" or "datasets"
    default_branch: Optional[str] = None  # name of the default branch, if known
    description: Optional[str] = None  # free-form description, if provided
18
+
19
+
20
@dataclass
class Signature:
    """Git signature identifying an author or committer."""
    name: str  # person's display name
    email: str  # person's e-mail address
    when: datetime  # time recorded in the signature (timezone handling not verified here)
26
+
27
+
28
@dataclass
class Commit:
    """Git commit information."""
    hash: str  # commit hash
    message: str  # commit message
    author: Signature  # who wrote the change
    committer: Signature  # who recorded the commit
    timestamp: datetime  # commit time as reported by the server
    tree_hash: Optional[str] = None  # hash of the commit's tree, if provided
    parents: Optional[List[str]] = None  # parent commit hashes, if provided
38
+
39
+
40
@dataclass
class Ref:
    """Git reference (branch or tag)."""
    name: str  # short ref name, presumably e.g. "main" — confirm against the API
    ref: str  # NOTE(review): exact form (short vs. qualified) not visible here — confirm
    fully_name: str  # presumably the fully-qualified name, e.g. "refs/heads/main" — confirm
    type: str  # "branch" or "tag"
    hash: str  # hash the reference points to
    is_default: bool = False  # True for the repository's default branch
    last_commit: Optional[Commit] = None  # latest commit on the ref, if provided
50
+
51
+
52
@dataclass
class GitLFSMeta:
    """Git LFS metadata for an LFS-tracked file."""
    oid: str  # LFS object id (presumably the content's sha256 — confirm)
    size: int  # object size as reported by the server (assumed bytes — confirm)
57
+
58
+
59
@dataclass
class Symlink:
    """Symlink information for a repository entry."""
    target: str  # path the symlink points to
63
+
64
+
65
@dataclass
class Submodule:
    """Git submodule information for a repository entry."""
    url: str  # remote URL of the submodule
    path: str  # path of the submodule inside the repository
    branch: str  # branch tracked by the submodule
71
+
72
+
73
# Kinds of repository tree entries. The download helpers only act on "file"
# (downloaded) and "dir" (descended into); other kinds are skipped.
FileType = Literal["file", "dir", "symlink", "submodule"]
74
+
75
+
76
@dataclass
class GitContent:
    """
    A repository tree entry: a file, directory, symlink, or submodule.

    For a directory listing, ``entries`` carries the child entries; the
    download helpers iterate it and branch on each child's ``type``.
    """
    name: str  # entry name (presumably the last path component — confirm)
    path: str  # path of the entry within the repository
    type: FileType  # "file", "dir", "symlink" or "submodule"
    size: int = 0  # size reported by the server (assumed bytes — confirm)
    hash: Optional[str] = None  # object hash reported by the server, if any
    content_type: Optional[str] = None  # content type reported by the server, if any
    content: Optional[str] = None  # inline content, when the server includes it
    content_omitted: bool = False  # presumably True when the server left `content` out — confirm
    last_commit: Optional[Commit] = None  # last commit touching this entry, if provided
    symlink: Optional[Symlink] = None  # symlink details (presumably set only for symlinks)
    submodule: Optional[Submodule] = None  # submodule details (presumably set only for submodules)
    lfs: Optional[GitLFSMeta] = None  # Git LFS metadata (presumably set only for LFS files)
    entries: Optional[List['GitContent']] = None  # child entries of a directory listing
92
+