xiaoshiai-hub 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xiaoshiai_hub/__init__.py +83 -0
- xiaoshiai_hub/client.py +360 -0
- xiaoshiai_hub/download.py +598 -0
- xiaoshiai_hub/encryption.py +777 -0
- xiaoshiai_hub/exceptions.py +37 -0
- xiaoshiai_hub/types.py +109 -0
- xiaoshiai_hub/upload.py +875 -0
- xiaoshiai_hub-0.1.1.dist-info/METADATA +560 -0
- xiaoshiai_hub-0.1.1.dist-info/RECORD +12 -0
- xiaoshiai_hub-0.1.1.dist-info/WHEEL +5 -0
- xiaoshiai_hub-0.1.1.dist-info/licenses/LICENSE +56 -0
- xiaoshiai_hub-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XiaoShi AI Hub Python SDK
|
|
3
|
+
|
|
4
|
+
A Python library for interacting with XiaoShi AI Hub repositories.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .client import HubClient, DEFAULT_BASE_URL
|
|
8
|
+
from .download import (
|
|
9
|
+
moha_hub_download,
|
|
10
|
+
snapshot_download,
|
|
11
|
+
)
|
|
12
|
+
from .exceptions import (
|
|
13
|
+
HubException,
|
|
14
|
+
RepositoryNotFoundError,
|
|
15
|
+
FileNotFoundError,
|
|
16
|
+
AuthenticationError,
|
|
17
|
+
EncryptionError,
|
|
18
|
+
)
|
|
19
|
+
from .types import (
|
|
20
|
+
Repository,
|
|
21
|
+
Ref,
|
|
22
|
+
GitContent,
|
|
23
|
+
Commit,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Upload functionality (requires GitPython). The upload helpers are optional:
# if the upload module cannot be imported, the public names stay defined but
# are bound to None and the availability flag is cleared.
try:
    from .upload import UploadError, upload_file, upload_folder
except ImportError:
    upload_file = upload_folder = UploadError = None
    _upload_available = False
else:
    _upload_available = True
|
|
39
|
+
|
|
40
|
+
# Encryption functionality. The encryption helpers are optional: if the
# encryption module cannot be imported, the public names stay defined but are
# bound to None and the availability flag is cleared.
try:
    from .encryption import EncryptionAlgorithm, decrypt_file, encrypt_file
except ImportError:
    EncryptionAlgorithm = encrypt_file = decrypt_file = None
    _encryption_available = False
else:
    _encryption_available = True
|
|
53
|
+
|
|
54
|
+
# Must match the version of the published distribution (xiaoshiai-hub 0.1.1).
__version__ = "0.1.1"
|
|
55
|
+
|
|
56
|
+
# Public API of the package: the names exported by ``from xiaoshiai_hub import *``.
# Kept as a plain literal list, grouped by area.
__all__ = [
    # Client
    "HubClient",
    "DEFAULT_BASE_URL",
    # Download
    "moha_hub_download",
    "snapshot_download",
    # Upload (bound to None when the optional upload import fails)
    "upload_file",
    "upload_folder",
    # Exceptions (UploadError is bound to None when the upload import fails)
    "HubException",
    "RepositoryNotFoundError",
    "FileNotFoundError",
    "AuthenticationError",
    "EncryptionError",
    "UploadError",
    # Types
    "Repository",
    "Ref",
    "GitContent",
    "Commit",
    # Encryption (bound to None when the optional encryption import fails)
    "EncryptionAlgorithm",
    "encrypt_file",
    "decrypt_file",
]
|
|
83
|
+
|
xiaoshiai_hub/client.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XiaoShi AI Hub Client
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import os
|
|
7
|
+
from typing import List, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import requests
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from tqdm.auto import tqdm
|
|
14
|
+
except ImportError:
|
|
15
|
+
tqdm = None
|
|
16
|
+
|
|
17
|
+
from .exceptions import (
|
|
18
|
+
AuthenticationError,
|
|
19
|
+
HTTPError,
|
|
20
|
+
RepositoryNotFoundError,
|
|
21
|
+
)
|
|
22
|
+
from .types import EncryptionMetadata, Repository, Ref, GitContent
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Default base URL of the Hub API. It can be overridden with the MOHA_ENDPOINT
# environment variable; an empty value is treated as unset so the client never
# ends up with an empty base URL.
DEFAULT_BASE_URL = (
    os.environ.get("MOHA_ENDPOINT")
    or "https://rune-api.develop.xiaoshiai.cn/moha"
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class HubClient:
    """Client for interacting with XiaoShi AI Hub API.

    Every request is sent through one ``requests.Session`` that carries the
    ``Authorization`` header configured in ``__init__``.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
    ):
        """
        Initialize the Hub client.

        Args:
            base_url: Base URL of the Hub API (default: DEFAULT_BASE_URL, which
                is read from the MOHA_ENDPOINT env var). Trailing slashes are
                stripped.
            username: Username for authentication
            password: Password for authentication
            token: Token for authentication (alternative to username/password)
        """
        self.base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
        self.username = username
        self.password = password
        self.token = token
        self.session = requests.Session()

        # A token takes precedence; basic auth is only used when both the
        # username and the password are given.
        if token:
            self.session.headers['Authorization'] = f'Bearer {token}'
        elif username and password:
            auth_string = f"{username}:{password}"
            encoded = base64.b64encode(auth_string.encode()).decode()
            self.session.headers['Authorization'] = f'Basic {encoded}'

    def _make_request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> requests.Response:
        """
        Make an HTTP request with error handling.

        Args:
            method: HTTP method name
            url: Absolute URL to request
            **kwargs: Passed through to ``requests.Session.request``

        Returns:
            The response, for any status code below 400

        Raises:
            AuthenticationError: If the server answers 401
            RepositoryNotFoundError: If the server answers 404
            HTTPError: For any other status >= 400, or if the request itself fails
        """
        try:
            response = self.session.request(method, url, **kwargs)
        except requests.RequestException as e:
            # Keep the original exception as the cause for easier debugging.
            raise HTTPError(f"Request failed: {str(e)}") from e

        status = response.status_code
        if status < 400:
            return response

        reason = response.reason
        # The caller never sees this response, so release its connection here
        # (this matters for stream=True requests).
        response.close()

        if status == 401:
            raise AuthenticationError("Authentication failed")
        if status == 404:
            raise RepositoryNotFoundError("Resource not found")
        raise HTTPError(f"HTTP {status}: {reason}", status_code=status)

    def get_moha_encryption(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        reference: str,
    ) -> Optional[EncryptionMetadata]:
        """
        Get .moha_encryption file content from the repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            reference: Branch/tag/commit reference

        Returns:
            Encryption metadata if the repository has encrypted files, None
            otherwise (204 response, empty body, non-JSON body or empty JSON).
            When the payload has no ``createAt`` value, the current UTC time
            is used.

        Raises:
            AuthenticationError: If authentication fails
            RepositoryNotFoundError: If the repository is not found
            HTTPError: If the request fails
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/encryption/{reference}"
        response = self._make_request("GET", url)
        if response.status_code == 204 or not response.content:
            return None
        try:
            data = response.json()
        except ValueError:
            return None
        if not data:
            return None

        from datetime import datetime, timezone

        files = None
        if 'files' in data and data['files']:
            from .types import FileEncryptionMetadata
            files = [
                FileEncryptionMetadata(
                    path=f.get('path', ''),
                    algorithm=f.get('algorithm', ''),
                    encryptedHash=f.get('encryptedHash', ''),
                    encryptedSize=f.get('encryptedSize', 0),
                )
                for f in data['files']
            ]

        created_raw = data.get('createAt')
        if created_raw:
            # fromisoformat() on older Pythons rejects a trailing "Z".
            created_at = datetime.fromisoformat(created_raw.replace('Z', '+00:00'))
        else:
            # Timezone-aware fallback, consistent with parsed "Z" timestamps.
            created_at = datetime.now(timezone.utc)

        return EncryptionMetadata(
            version=data.get('version', '1.0'),
            createAt=created_at,
            files=files,
        )

    def get_repository_info(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> Repository:
        """
        Get repository information.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Repository information. ``metadata`` and ``annotations`` default
            to empty dicts when missing or not a JSON object.
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}"
        response = self._make_request("GET", url)
        data = response.json()

        annotations = data.get('annotations')
        if not isinstance(annotations, dict):
            annotations = {}

        metadata = data.get('metadata')
        if not isinstance(metadata, dict):
            metadata = {}

        return Repository(
            name=data.get('name', repo_name),
            organization=organization,
            type=repo_type,
            description=data.get('description'),
            metadata=metadata,
            annotations=annotations,
        )

    def get_repository_refs(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> List[Ref]:
        """
        Get repository references (branches and tags).

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            List of references (empty when the server returns no entries)
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/refs"
        response = self._make_request("GET", url)
        data = response.json()

        # A null/empty payload is treated as "no references".
        return [
            Ref(
                name=ref_data.get('name', ''),
                ref=ref_data.get('ref', ''),
                type=ref_data.get('type', ''),
                hash=ref_data.get('hash', ''),
                is_default=ref_data.get('isDefault', False),
            )
            for ref_data in (data or [])
        ]

    def get_default_branch(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> str:
        """
        Get the default branch name for a repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Name of the first branch ref flagged as default, or "main" if
            there is none
        """
        refs = self.get_repository_refs(organization, repo_type, repo_name)
        for ref in refs:
            if ref.is_default and ref.type == "branch":
                return ref.name
        return "main"

    def get_repository_content(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        path: str = "",
    ) -> GitContent:
        """
        Get repository content at a specific path.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            path: Path within the repository (empty for root)

        Returns:
            Git content information
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/contents/{branch}"
        if path:
            url = f"{url}/{path}"

        response = self._make_request("GET", url)
        return self._parse_git_content(response.json())

    def _parse_git_content(self, data: dict) -> GitContent:
        """
        Parse GitContent from API response.

        Nested ``entries`` are parsed recursively; a missing or empty
        ``entries`` value results in ``entries=None``.
        """
        entries = None
        if 'entries' in data and data['entries']:
            entries = [self._parse_git_content(entry) for entry in data['entries']]

        return GitContent(
            name=data.get('name', ''),
            path=data.get('path', ''),
            type=data.get('type', 'file'),
            size=data.get('size', 0),
            hash=data.get('hash'),
            content_type=data.get('contentType'),
            content=data.get('content'),
            content_omitted=data.get('contentOmitted', False),
            entries=entries,
        )

    def add_download_count(self, organization: str, repo_type: str, repo_name: str) -> None:
        """
        Add download count for a repository.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Raises:
            HTTPError: If the server does not answer with a 2xx status
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/downloads"
        response = self._make_request("POST", url)
        # Check the HTTP status code: any 2xx answer (e.g. 201 or 204) means
        # success; statuses >= 400 were already raised by _make_request.
        if not 200 <= response.status_code < 300:
            raise HTTPError(
                f"Failed to add download count: {response.text}",
                status_code=response.status_code,
            )

    def download_file(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        file_path: str,
        local_path: str,
        show_progress: bool = True,
    ) -> None:
        """
        Download a single file from the repository.

        The response body is streamed to ``local_path`` in 8 KiB chunks.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            file_path: Path to the file in the repository
            local_path: Local path to save the file
            show_progress: Whether to show download progress bar (only shown
                when tqdm is installed and the server reports a content length)
        """
        url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/resolve/{branch}/{file_path}"
        response = self._make_request("GET", url, stream=True)

        progress_bar = None
        try:
            # Get file size from headers (0 when the header is absent)
            total_size = int(response.headers.get('content-length', 0))

            # Create parent directories if needed
            parent_dir = os.path.dirname(local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Prepare progress bar
            if show_progress and tqdm is not None and total_size > 0:
                progress_bar = tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=os.path.basename(file_path),
                    leave=True,
                )

            # Write file with progress
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
        finally:
            if progress_bar is not None:
                progress_bar.close()
            # A streamed response holds its connection until it is closed.
            response.close()
|
|
360
|
+
|