xiaoshiai-hub 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +4 -0
- tests/test_download.py +142 -0
- xiaoshiai_hub/__init__.py +41 -0
- xiaoshiai_hub/client.py +259 -0
- xiaoshiai_hub/download.py +464 -0
- xiaoshiai_hub/exceptions.py +32 -0
- xiaoshiai_hub/types.py +92 -0
- xiaoshiai_hub-0.1.0.dist-info/METADATA +321 -0
- xiaoshiai_hub-0.1.0.dist-info/RECORD +12 -0
- xiaoshiai_hub-0.1.0.dist-info/WHEEL +5 -0
- xiaoshiai_hub-0.1.0.dist-info/licenses/LICENSE +56 -0
- xiaoshiai_hub-0.1.0.dist-info/top_level.txt +2 -0
tests/__init__.py
ADDED
tests/test_download.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for download functionality
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import tempfile
|
|
7
|
+
import unittest
|
|
8
|
+
from unittest.mock import Mock, patch, MagicMock
|
|
9
|
+
|
|
10
|
+
from xiaoshiai_hub.download import (
|
|
11
|
+
_match_pattern,
|
|
12
|
+
_should_download_file,
|
|
13
|
+
hf_hub_download,
|
|
14
|
+
snapshot_download,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestPatternMatching(unittest.TestCase):
    """Tests for the filename filters ``_match_pattern`` and ``_should_download_file``."""

    def _assert_matches(self, cases):
        """Check every ``(filename, pattern, expected)`` triple against ``_match_pattern``."""
        for filename, pattern, expected in cases:
            outcome = _match_pattern(filename, pattern)
            if expected:
                self.assertTrue(outcome)
            else:
                self.assertFalse(outcome)

    def _assert_decisions(self, expectations, *filters, **keyword_filters):
        """Check every ``(filename, expected)`` pair against ``_should_download_file``.

        Extra positional/keyword arguments are forwarded unchanged after the
        filename, so callers can exercise either calling convention.
        """
        for filename, expected in expectations:
            decision = _should_download_file(filename, *filters, **keyword_filters)
            if expected:
                self.assertTrue(decision)
            else:
                self.assertFalse(decision)

    def test_match_pattern_exact(self):
        """A pattern without wildcards only matches the identical name."""
        self._assert_matches([
            ("config.yaml", "config.yaml", True),
            ("config.yml", "config.yaml", False),
        ])

    def test_match_pattern_wildcard(self):
        """A leading ``*`` matches any name with the given suffix."""
        self._assert_matches([
            ("config.yaml", "*.yaml", True),
            ("model.yml", "*.yml", True),
            ("config.txt", "*.yaml", False),
        ])

    def test_match_pattern_prefix(self):
        """A trailing ``*`` matches any name with the given prefix."""
        self._assert_matches([
            ("config.yaml", "config*", True),
            ("config_v2.yaml", "config*", True),
            ("model.yaml", "config*", False),
        ])

    def test_should_download_file_no_patterns(self):
        """Without any filters every file is downloaded."""
        self._assert_decisions([
            ("config.yaml", True),
            ("model.bin", True),
        ])

    def test_should_download_file_allow_patterns(self):
        """With an allow-list only matching files are downloaded."""
        self._assert_decisions(
            [
                ("config.yaml", True),
                ("model.yml", True),
                ("model.bin", False),
            ],
            allow_patterns=["*.yaml", "*.yml"],
        )

    def test_should_download_file_ignore_patterns(self):
        """With an ignore-list matching files are skipped."""
        self._assert_decisions(
            [
                (".gitignore", False),
                ("temp.tmp", False),
                ("config.yaml", True),
            ],
            ignore_patterns=[".git*", "*.tmp"],
        )

    def test_should_download_file_both_patterns(self):
        """Allow and ignore lists passed positionally are both honoured."""
        self._assert_decisions(
            [
                ("config.yaml", True),
                (".gitignore", False),
                ("model.bin", False),
            ],
            ["*.yaml", "*.yml"],
            [".git*"],
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TestDownloadFunctions(unittest.TestCase):
    """Tests for ``hf_hub_download`` and ``snapshot_download`` with ``HubClient`` patched out."""

    @staticmethod
    def _install_fake_client(mock_client_class):
        """Make the patched ``HubClient`` class return a fake whose repo's default branch is ``main``."""
        fake_client = Mock()
        mock_client_class.return_value = fake_client

        repo_info = Mock()
        repo_info.default_branch = "main"
        fake_client.get_repository_info.return_value = repo_info
        return fake_client

    @patch('xiaoshiai_hub.download.HubClient')
    def test_hf_hub_download(self, mock_client_class):
        """A single-file download builds one client and fetches the file exactly once."""
        fake_client = self._install_fake_client(mock_client_class)

        with tempfile.TemporaryDirectory() as tmpdir:
            downloaded_path = hf_hub_download(
                repo_id="demo/demo",
                filename="config.yaml",
                local_dir=tmpdir,
                username="test",
                password="test",
            )

            # One client, one download call, and the returned path names the file.
            mock_client_class.assert_called_once()
            fake_client.download_file.assert_called_once()
            self.assertIn("config.yaml", downloaded_path)

    @patch('xiaoshiai_hub.download.HubClient')
    def test_snapshot_download(self, mock_client_class):
        """A snapshot download walks the listing, downloads files and returns ``local_dir``."""
        fake_client = self._install_fake_client(mock_client_class)

        # Repository root listing containing a single regular file.
        only_file = Mock()
        only_file.type = "file"
        only_file.path = "config.yaml"
        root_listing = Mock()
        root_listing.entries = [only_file]
        fake_client.get_repository_content.return_value = root_listing

        with tempfile.TemporaryDirectory() as tmpdir:
            snapshot_dir = snapshot_download(
                repo_id="demo/demo",
                local_dir=tmpdir,
                username="test",
                password="test",
                verbose=False,
            )

            mock_client_class.assert_called_once()
            fake_client.download_file.assert_called()
            self.assertEqual(snapshot_dir, tmpdir)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python tests/test_download.py``.
    unittest.main()
|
|
142
|
+
|
|
xiaoshiai_hub/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XiaoShi AI Hub Python SDK
|
|
3
|
+
|
|
4
|
+
A Python library for interacting with XiaoShi AI Hub repositories.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .client import HubClient, DEFAULT_BASE_URL
|
|
8
|
+
from .download import (
|
|
9
|
+
hf_hub_download,
|
|
10
|
+
snapshot_download,
|
|
11
|
+
)
|
|
12
|
+
from .exceptions import (
|
|
13
|
+
HubException,
|
|
14
|
+
RepositoryNotFoundError,
|
|
15
|
+
FileNotFoundError,
|
|
16
|
+
AuthenticationError,
|
|
17
|
+
)
|
|
18
|
+
from .types import (
|
|
19
|
+
Repository,
|
|
20
|
+
Ref,
|
|
21
|
+
GitContent,
|
|
22
|
+
Commit,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Package version; keep in sync with the version in the distribution metadata
# (the released wheel is xiaoshiai_hub-0.1.0).
__version__ = "0.1.0"
|
|
26
|
+
|
|
27
|
+
# Public API exported by ``from xiaoshiai_hub import *``.
__all__ = [
    # Client
    "HubClient", "DEFAULT_BASE_URL",
    # Download helpers
    "hf_hub_download", "snapshot_download",
    # Exceptions. ``FileNotFoundError`` is the package's own class from
    # ``.exceptions``; a star-import shadows the builtin of the same name.
    "HubException", "RepositoryNotFoundError", "FileNotFoundError", "AuthenticationError",
    # Data types
    "Repository", "Ref", "GitContent", "Commit",
]
|
|
41
|
+
|
xiaoshiai_hub/client.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XiaoShi AI Hub Client
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from typing import List, Optional
|
|
9
|
+
from urllib.parse import urljoin
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from tqdm.auto import tqdm
|
|
15
|
+
except ImportError:
|
|
16
|
+
tqdm = None
|
|
17
|
+
|
|
18
|
+
from .exceptions import (
|
|
19
|
+
AuthenticationError,
|
|
20
|
+
HTTPError,
|
|
21
|
+
RepositoryNotFoundError,
|
|
22
|
+
)
|
|
23
|
+
from .types import Repository, Ref, GitContent
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Default API root. Override it with the MOHA_ENDPOINT environment variable;
# the variable is read once, when this module is imported.
DEFAULT_BASE_URL = os.getenv("MOHA_ENDPOINT", "https://rune.develop.xiaoshiai.cn/api/moha")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class HubClient:
    """Client for interacting with XiaoShi AI Hub API.

    All requests go through one ``requests.Session``, so the Authorization
    header configured in ``__init__`` is sent with every call.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
        timeout: Optional[float] = None,
    ):
        """
        Initialize the Hub client.

        Args:
            base_url: Base URL of the Hub API. Falls back to ``DEFAULT_BASE_URL``
                (from the MOHA_ENDPOINT env var). A trailing ``/`` is stripped.
            username: Username for HTTP Basic authentication
            password: Password for HTTP Basic authentication
            token: Bearer token (alternative to username/password); takes
                precedence when both are given
            timeout: Optional timeout in seconds applied to every request
                unless a call passes its own. ``None`` (default) means no
                timeout, which is also the ``requests`` default.
        """
        self.base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
        self.username = username
        self.password = password
        self.token = token
        self.timeout = timeout
        self.session = requests.Session()

        # Basic auth is only configured when BOTH username and password are
        # set; otherwise requests are sent without credentials.
        if token:
            self.session.headers['Authorization'] = f'Bearer {token}'
        elif username and password:
            encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
            self.session.headers['Authorization'] = f'Basic {encoded}'

    def _repo_url(self, organization: str, repo_type: str, repo_name: str) -> str:
        """Return the API URL of a repository (without trailing slash)."""
        return f"{self.base_url}/organizations/{organization}/{repo_type}/{repo_name}"

    def _make_request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> requests.Response:
        """
        Make an HTTP request and translate failures into hub exceptions.

        Args:
            method: HTTP method, e.g. ``"GET"``
            url: Absolute request URL
            **kwargs: Forwarded to ``requests.Session.request`` (e.g. ``stream``).
                ``timeout`` defaults to the client's ``timeout``.

        Returns:
            The response, for any status code below 400.

        Raises:
            AuthenticationError: on HTTP 401
            RepositoryNotFoundError: on HTTP 404
            HTTPError: on any other status >= 400, or when the request itself
                fails (connection error, timeout, ...)
        """
        kwargs.setdefault('timeout', self.timeout)
        try:
            response = self.session.request(method, url, **kwargs)
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

        if response.status_code < 400:
            return response

        # The caller never sees an error response, so release its connection
        # now; this matters for ``stream=True`` requests.
        status_code, reason = response.status_code, response.reason
        response.close()

        if status_code == 401:
            raise AuthenticationError("Authentication failed")
        if status_code == 404:
            raise RepositoryNotFoundError("Resource not found")
        raise HTTPError(
            f"HTTP {status_code}: {reason}",
            status_code=status_code
        )

    def get_repository_info(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> Repository:
        """
        Get repository information.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Repository information. ``name`` falls back to ``repo_name`` when
            the response has no ``name`` field.

        Raises:
            See ``_make_request``.
        """
        response = self._make_request("GET", self._repo_url(organization, repo_type, repo_name))
        data = response.json()

        return Repository(
            name=data.get('name', repo_name),
            organization=organization,
            type=repo_type,
            default_branch=data.get('defaultBranch'),
            description=data.get('description'),
        )

    def get_repository_refs(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> List[Ref]:
        """
        Get repository references (branches and tags).

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            List of references, one per item of the JSON response array

        Raises:
            See ``_make_request``.
        """
        url = f"{self._repo_url(organization, repo_type, repo_name)}/refs"
        data = self._make_request("GET", url).json()

        return [
            Ref(
                name=ref_data.get('name', ''),
                ref=ref_data.get('ref', ''),
                fully_name=ref_data.get('fullyName', ''),
                type=ref_data.get('type', ''),
                hash=ref_data.get('hash', ''),
                is_default=ref_data.get('isDefault', False),
            )
            for ref_data in data
        ]

    def get_repository_content(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        path: str = "",
    ) -> GitContent:
        """
        Get repository content at a specific path.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            path: Path within the repository (empty for root)

        Returns:
            Git content information, with nested ``entries`` parsed recursively

        Raises:
            See ``_make_request``.
        """
        url = f"{self._repo_url(organization, repo_type, repo_name)}/contents/{branch}"
        if path:
            url = f"{url}/{path}"

        data = self._make_request("GET", url).json()
        return self._parse_git_content(data)

    def _parse_git_content(self, data: dict) -> GitContent:
        """Parse GitContent from an API response dict, recursing into ``entries``."""
        raw_entries = data.get('entries')
        # Missing or empty entry lists are normalised to None.
        entries = [self._parse_git_content(entry) for entry in raw_entries] if raw_entries else None

        return GitContent(
            name=data.get('name', ''),
            path=data.get('path', ''),
            type=data.get('type', 'file'),
            size=data.get('size', 0),
            hash=data.get('hash'),
            content_type=data.get('contentType'),
            content=data.get('content'),
            content_omitted=data.get('contentOmitted', False),
            entries=entries,
        )

    def download_file(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        file_path: str,
        local_path: str,
        show_progress: bool = True,
    ) -> None:
        """
        Download a single file from the repository.

        The body is streamed into ``<local_path>.incomplete`` and moved onto
        ``local_path`` only after the transfer finishes. An interrupted
        download therefore never leaves a truncated file at ``local_path``.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            file_path: Path to the file in the repository
            local_path: Local path to save the file; parent directories are created
            show_progress: Whether to show a download progress bar. A bar is only
                shown when tqdm is installed and the server sends a positive
                Content-Length.

        Raises:
            AuthenticationError, RepositoryNotFoundError, HTTPError: see ``_make_request``
            OSError: if the file cannot be written
        """
        url = f"{self._repo_url(organization, repo_type, repo_name)}/resolve/{branch}/{file_path}"
        response = self._make_request("GET", url, stream=True)

        progress_bar = None
        tmp_path = f"{local_path}.incomplete"
        try:
            # Size for the progress bar only; a missing or malformed header
            # disables the bar instead of aborting the download.
            try:
                total_size = int(response.headers.get('content-length', 0))
            except (TypeError, ValueError):
                total_size = 0

            os.makedirs(os.path.dirname(local_path) or '.', exist_ok=True)

            if show_progress and tqdm is not None and total_size > 0:
                progress_bar = tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=os.path.basename(file_path),
                    leave=True,
                )

            try:
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:  # skip empty chunks
                            continue
                        f.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
                # Publish the complete file in a single step.
                os.replace(tmp_path, local_path)
            except BaseException:
                # Best-effort removal of the partial file; the original error
                # is re-raised regardless of whether the cleanup succeeds.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise
        finally:
            if progress_bar is not None:
                progress_bar.close()
            # A streamed response holds its connection until it is closed.
            response.close()
|
|
259
|
+
|