xiaoshiai-hub 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,83 @@
"""
XiaoShi AI Hub Python SDK

A Python library for interacting with XiaoShi AI Hub repositories.
"""

from .client import HubClient, DEFAULT_BASE_URL
from .download import (
    moha_hub_download,
    snapshot_download,
)
# NOTE: ``FileNotFoundError`` here is the SDK's own exception class; importing
# it from this package shadows the builtin of the same name in the importer.
from .exceptions import (
    HubException,
    HTTPError,
    RepositoryNotFoundError,
    FileNotFoundError,
    AuthenticationError,
    EncryptionError,
)
from .types import (
    Repository,
    Ref,
    GitContent,
    Commit,
)

# Upload functionality (requires GitPython).
# When the import fails, the public names are still defined (as ``None``) so
# that ``from <package> import upload_file`` keeps working; check
# ``_upload_available`` before calling them.
try:
    from .upload import (
        upload_file,
        upload_folder,
        UploadError,
    )
    _upload_available = True
except ImportError:
    _upload_available = False
    upload_file = None
    upload_folder = None
    UploadError = None

# Encryption functionality.
# Same pattern as uploads: names fall back to ``None`` and
# ``_encryption_available`` records whether the import succeeded.
try:
    from .encryption import (
        EncryptionAlgorithm,
        encrypt_file,
        decrypt_file,
    )
    _encryption_available = True
except ImportError:
    _encryption_available = False
    EncryptionAlgorithm = None
    encrypt_file = None
    decrypt_file = None

# Kept in sync with the released distribution version (0.1.1).
__version__ = "0.1.1"

__all__ = [
    # Client
    "HubClient",
    "DEFAULT_BASE_URL",
    # Download
    "moha_hub_download",
    "snapshot_download",
    # Upload
    "upload_file",
    "upload_folder",
    # Exceptions
    "HubException",
    "HTTPError",
    "RepositoryNotFoundError",
    "FileNotFoundError",
    "AuthenticationError",
    "EncryptionError",
    "UploadError",
    # Types
    "Repository",
    "Ref",
    "GitContent",
    "Commit",
    # Encryption
    "EncryptionAlgorithm",
    "encrypt_file",
    "decrypt_file",
]
83
+
@@ -0,0 +1,360 @@
1
+ """
2
+ XiaoShi AI Hub Client
3
+ """
4
+
5
+ import base64
6
+ import os
7
+ from typing import List, Optional
8
+
9
+
10
+ import requests
11
+
12
+ try:
13
+ from tqdm.auto import tqdm
14
+ except ImportError:
15
+ tqdm = None
16
+
17
+ from .exceptions import (
18
+ AuthenticationError,
19
+ HTTPError,
20
+ RepositoryNotFoundError,
21
+ )
22
+ from .types import EncryptionMetadata, Repository, Ref, GitContent
23
+
24
+
# Default base URL; can be overridden with the MOHA_ENDPOINT environment
# variable. An empty or whitespace-only MOHA_ENDPOINT is treated as unset so
# the client never ends up with a blank base URL.
DEFAULT_BASE_URL = (
    (os.environ.get("MOHA_ENDPOINT") or "").strip()
    or "https://rune-api.develop.xiaoshiai.cn/moha"
)
30
+
31
+
class HubClient:
    """Client for interacting with XiaoShi AI Hub API.

    Every request is sent through one ``requests.Session`` (``self.session``).
    When credentials are supplied, the ``Authorization`` header is set on that
    session once, in ``__init__``.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
    ):
        """
        Initialize the Hub client.

        Args:
            base_url: Base URL of the Hub API (default: ``DEFAULT_BASE_URL``,
                which is read from the MOHA_ENDPOINT env var). A trailing
                slash is stripped.
            username: Username for authentication
            password: Password for authentication
            token: Token for authentication (alternative to username/password).
                When given, it takes precedence over username/password.
        """
        self.base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
        self.username = username
        self.password = password
        self.token = token
        self.session = requests.Session()

        # Set up authentication. Basic auth is only configured when BOTH
        # username and password are provided.
        if token:
            self.session.headers['Authorization'] = f'Bearer {token}'
        elif username and password:
            auth_string = f"{username}:{password}"
            encoded = base64.b64encode(auth_string.encode()).decode()
            self.session.headers['Authorization'] = f'Basic {encoded}'

    def _repo_url(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        suffix: str = "",
    ) -> str:
        """Build the API URL of a repository, optionally followed by ``suffix``."""
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}"
        return f"{url}/{suffix}" if suffix else url

    def _make_request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> requests.Response:
        """
        Make an HTTP request with error handling.

        Args:
            method: HTTP method name (e.g. "GET", "POST")
            url: Absolute URL to request
            **kwargs: Passed through unchanged to ``Session.request``

        Returns:
            The response, for any status code below 400

        Raises:
            AuthenticationError: On HTTP 401
            RepositoryNotFoundError: On HTTP 404
            HTTPError: On any other status >= 400, or when the request itself
                fails (the original ``requests`` exception is chained)
        """
        # Only the network call belongs in the ``try``: the status-based
        # errors raised below must never be re-wrapped as "Request failed".
        try:
            response = self.session.request(method, url, **kwargs)
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

        status = response.status_code
        if status < 400:
            return response

        # The caller never sees an error response, so release its connection
        # here (this matters for ``stream=True`` requests).
        response.close()
        if status == 401:
            raise AuthenticationError("Authentication failed")
        if status == 404:
            raise RepositoryNotFoundError("Resource not found")
        raise HTTPError(
            f"HTTP {status}: {response.reason}",
            status_code=status
        )

    def get_moha_encryption(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        reference: str,
    ) -> Optional[EncryptionMetadata]:
        """
        Get .moha_encryption file content from the repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            reference: Branch/tag/commit reference

        Returns:
            Encryption metadata if the repository has encrypted files, None
            otherwise (HTTP 204, empty body, non-JSON body, or an empty /
            non-object JSON payload)

        Raises:
            AuthenticationError: If authentication fails
            RepositoryNotFoundError: If the repository is not found
            HTTPError: If the request fails
        """
        from datetime import datetime, timezone

        url = self._repo_url(organization, repo_type, repo_name, f"encryption/{reference}")
        response = self._make_request("GET", url)
        if response.status_code == 204 or not response.content:
            return None
        try:
            data = response.json()
        except ValueError:
            return None
        # Anything other than a non-empty JSON object carries no metadata.
        if not data or not isinstance(data, dict):
            return None

        files = None
        raw_files = data.get('files')
        if raw_files:
            from .types import FileEncryptionMetadata
            files = [
                FileEncryptionMetadata(
                    path=f.get('path', ''),
                    algorithm=f.get('algorithm', ''),
                    encryptedHash=f.get('encryptedHash', ''),
                    encryptedSize=f.get('encryptedSize', 0),
                )
                for f in raw_files
            ]

        raw_created = data.get('createAt')
        if isinstance(raw_created, str) and raw_created:
            # ``fromisoformat`` does not accept a trailing "Z" on older
            # Python versions, so spell the UTC offset out explicitly.
            created_at = datetime.fromisoformat(raw_created.replace('Z', '+00:00'))
        else:
            # Missing or null timestamp: fall back to "now" as a
            # timezone-aware UTC value, matching the aware datetimes that the
            # parsed branch yields for "Z"-suffixed timestamps.
            created_at = datetime.now(timezone.utc)

        return EncryptionMetadata(
            version=data.get('version', '1.0'),
            createAt=created_at,
            files=files,
        )

    def get_repository_info(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> Repository:
        """
        Get repository information.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Repository information. ``metadata`` and ``annotations`` default
            to empty dicts when the response omits them or they are not dicts.
        """
        response = self._make_request(
            "GET", self._repo_url(organization, repo_type, repo_name)
        )
        data = response.json()

        # Only accept dict-shaped annotations/metadata; anything else is
        # replaced by an empty dict.
        annotations = data.get('annotations')
        if not isinstance(annotations, dict):
            annotations = {}

        metadata = data.get('metadata')
        if not isinstance(metadata, dict):
            metadata = {}

        return Repository(
            name=data.get('name', repo_name),
            organization=organization,
            type=repo_type,
            description=data.get('description'),
            metadata=metadata,
            annotations=annotations,
        )

    def get_repository_refs(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> List[Ref]:
        """
        Get repository references (branches and tags).

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            List of references (empty when the API returns no entries)
        """
        url = self._repo_url(organization, repo_type, repo_name, "refs")
        response = self._make_request("GET", url)
        # A JSON ``null`` body is treated the same as an empty list.
        data = response.json() or []

        return [
            Ref(
                name=ref_data.get('name', ''),
                ref=ref_data.get('ref', ''),
                type=ref_data.get('type', ''),
                hash=ref_data.get('hash', ''),
                is_default=ref_data.get('isDefault', False),
            )
            for ref_data in data
        ]

    def get_default_branch(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> str:
        """
        Get the default branch name for a repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Default branch name (defaults to "main" if not found)
        """
        refs = self.get_repository_refs(organization, repo_type, repo_name)
        # First ref that is both flagged as default and of type "branch".
        return next(
            (ref.name for ref in refs if ref.is_default and ref.type == "branch"),
            "main",
        )

    def get_repository_content(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        path: str = "",
    ) -> GitContent:
        """
        Get repository content at a specific path.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            path: Path within the repository (empty for root)

        Returns:
            Git content information
        """
        suffix = f"contents/{branch}/{path}" if path else f"contents/{branch}"
        url = self._repo_url(organization, repo_type, repo_name, suffix)

        response = self._make_request("GET", url)
        return self._parse_git_content(response.json())

    def _parse_git_content(self, data: dict) -> GitContent:
        """Parse GitContent from API response, recursing into ``entries``."""
        raw_entries = data.get('entries')
        # An absent or empty ``entries`` value is reported as None, not [].
        entries = (
            [self._parse_git_content(entry) for entry in raw_entries]
            if raw_entries
            else None
        )

        return GitContent(
            name=data.get('name', ''),
            path=data.get('path', ''),
            type=data.get('type', 'file'),
            size=data.get('size', 0),
            hash=data.get('hash'),
            content_type=data.get('contentType'),
            content=data.get('content'),
            content_omitted=data.get('contentOmitted', False),
            entries=entries,
        )

    def add_download_count(self, organization: str, repo_type: str, repo_name: str) -> None:
        """
        Add download count for a repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Raises:
            HTTPError: If the server does not answer with a 2xx status
        """
        url = self._repo_url(organization, repo_type, repo_name, "downloads")
        response = self._make_request("POST", url)
        # ``_make_request`` already raised for status >= 400. Accept every
        # 2xx answer (e.g. 201/204), not just 200, as success.
        if not 200 <= response.status_code < 300:
            raise HTTPError(
                f"Failed to add download count: {response.text}",
                status_code=response.status_code
            )

    def download_file(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        file_path: str,
        local_path: str,
        show_progress: bool = True,
    ) -> None:
        """
        Download a single file from the repository.

        The body is streamed to ``local_path`` in 8 KiB chunks. The HTTP
        response is always closed when the method returns or raises.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            file_path: Path to the file in the repository
            local_path: Local path to save the file
            show_progress: Whether to show download progress bar (only shown
                when tqdm is installed and the server reports a size)
        """
        url = self._repo_url(
            organization, repo_type, repo_name, f"resolve/{branch}/{file_path}"
        )
        response = self._make_request("GET", url, stream=True)

        progress_bar = None
        try:
            # Get file size from headers; a malformed value just disables the
            # progress bar instead of aborting the download.
            try:
                total_size = int(response.headers.get('content-length', 0))
            except ValueError:
                total_size = 0

            # Create parent directories if needed
            parent_dir = os.path.dirname(local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Prepare progress bar
            if show_progress and tqdm is not None and total_size > 0:
                progress_bar = tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=os.path.basename(file_path),
                    leave=True,
                )

            # Write file with progress
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
        finally:
            if progress_bar is not None:
                progress_bar.close()
            # Streamed responses hold their connection until closed.
            response.close()
360
+