xiaoshiai-hub 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ Tests for XiaoShi AI Hub Python SDK
3
+ """
4
+
tests/test_download.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ Tests for download functionality
3
+ """
4
+
5
+ import os
6
+ import tempfile
7
+ import unittest
8
+ from unittest.mock import Mock, patch, MagicMock
9
+
10
+ from xiaoshiai_hub.download import (
11
+ _match_pattern,
12
+ _should_download_file,
13
+ hf_hub_download,
14
+ snapshot_download,
15
+ )
16
+
17
+
18
+ class TestPatternMatching(unittest.TestCase):
19
+ """Test pattern matching functions."""
20
+
21
+ def test_match_pattern_exact(self):
22
+ """Test exact pattern matching."""
23
+ self.assertTrue(_match_pattern("config.yaml", "config.yaml"))
24
+ self.assertFalse(_match_pattern("config.yml", "config.yaml"))
25
+
26
+ def test_match_pattern_wildcard(self):
27
+ """Test wildcard pattern matching."""
28
+ self.assertTrue(_match_pattern("config.yaml", "*.yaml"))
29
+ self.assertTrue(_match_pattern("model.yml", "*.yml"))
30
+ self.assertFalse(_match_pattern("config.txt", "*.yaml"))
31
+
32
+ def test_match_pattern_prefix(self):
33
+ """Test prefix pattern matching."""
34
+ self.assertTrue(_match_pattern("config.yaml", "config*"))
35
+ self.assertTrue(_match_pattern("config_v2.yaml", "config*"))
36
+ self.assertFalse(_match_pattern("model.yaml", "config*"))
37
+
38
+ def test_should_download_file_no_patterns(self):
39
+ """Test file download decision with no patterns."""
40
+ self.assertTrue(_should_download_file("config.yaml"))
41
+ self.assertTrue(_should_download_file("model.bin"))
42
+
43
+ def test_should_download_file_allow_patterns(self):
44
+ """Test file download decision with allow patterns."""
45
+ allow = ["*.yaml", "*.yml"]
46
+ self.assertTrue(_should_download_file("config.yaml", allow_patterns=allow))
47
+ self.assertTrue(_should_download_file("model.yml", allow_patterns=allow))
48
+ self.assertFalse(_should_download_file("model.bin", allow_patterns=allow))
49
+
50
+ def test_should_download_file_ignore_patterns(self):
51
+ """Test file download decision with ignore patterns."""
52
+ ignore = [".git*", "*.tmp"]
53
+ self.assertFalse(_should_download_file(".gitignore", ignore_patterns=ignore))
54
+ self.assertFalse(_should_download_file("temp.tmp", ignore_patterns=ignore))
55
+ self.assertTrue(_should_download_file("config.yaml", ignore_patterns=ignore))
56
+
57
+ def test_should_download_file_both_patterns(self):
58
+ """Test file download decision with both allow and ignore patterns."""
59
+ allow = ["*.yaml", "*.yml"]
60
+ ignore = [".git*"]
61
+
62
+ self.assertTrue(_should_download_file("config.yaml", allow, ignore))
63
+ self.assertFalse(_should_download_file(".gitignore", allow, ignore))
64
+ self.assertFalse(_should_download_file("model.bin", allow, ignore))
65
+
66
+
67
+ class TestDownloadFunctions(unittest.TestCase):
68
+ """Test download functions."""
69
+
70
+ @patch('xiaoshiai_hub.download.HubClient')
71
+ def test_hf_hub_download(self, mock_client_class):
72
+ """Test hf_hub_download function."""
73
+ # Setup mock
74
+ mock_client = Mock()
75
+ mock_client_class.return_value = mock_client
76
+
77
+ mock_repo_info = Mock()
78
+ mock_repo_info.default_branch = "main"
79
+ mock_client.get_repository_info.return_value = mock_repo_info
80
+
81
+ # Test download
82
+ with tempfile.TemporaryDirectory() as tmpdir:
83
+ result = hf_hub_download(
84
+ repo_id="demo/demo",
85
+ filename="config.yaml",
86
+ local_dir=tmpdir,
87
+ username="test",
88
+ password="test",
89
+ )
90
+
91
+ # Verify client was created with correct params
92
+ mock_client_class.assert_called_once()
93
+
94
+ # Verify download was called
95
+ mock_client.download_file.assert_called_once()
96
+
97
+ # Verify result path
98
+ self.assertIn("config.yaml", result)
99
+
100
+ @patch('xiaoshiai_hub.download.HubClient')
101
+ def test_snapshot_download(self, mock_client_class):
102
+ """Test snapshot_download function."""
103
+ # Setup mock
104
+ mock_client = Mock()
105
+ mock_client_class.return_value = mock_client
106
+
107
+ mock_repo_info = Mock()
108
+ mock_repo_info.default_branch = "main"
109
+ mock_client.get_repository_info.return_value = mock_repo_info
110
+
111
+ # Mock content structure
112
+ mock_file = Mock()
113
+ mock_file.type = "file"
114
+ mock_file.path = "config.yaml"
115
+
116
+ mock_content = Mock()
117
+ mock_content.entries = [mock_file]
118
+ mock_client.get_repository_content.return_value = mock_content
119
+
120
+ # Test download
121
+ with tempfile.TemporaryDirectory() as tmpdir:
122
+ result = snapshot_download(
123
+ repo_id="demo/demo",
124
+ local_dir=tmpdir,
125
+ username="test",
126
+ password="test",
127
+ verbose=False,
128
+ )
129
+
130
+ # Verify client was created
131
+ mock_client_class.assert_called_once()
132
+
133
+ # Verify download was called
134
+ mock_client.download_file.assert_called()
135
+
136
+ # Verify result path
137
+ self.assertEqual(result, tmpdir)
138
+
139
+
140
+ if __name__ == '__main__':
141
+ unittest.main()
142
+
@@ -0,0 +1,41 @@
1
+ """
2
+ XiaoShi AI Hub Python SDK
3
+
4
+ A Python library for interacting with XiaoShi AI Hub repositories.
5
+ """
6
+
7
+ from .client import HubClient, DEFAULT_BASE_URL
8
+ from .download import (
9
+ hf_hub_download,
10
+ snapshot_download,
11
+ )
12
+ from .exceptions import (
13
+ HubException,
14
+ RepositoryNotFoundError,
15
+ FileNotFoundError,
16
+ AuthenticationError,
17
+ )
18
+ from .types import (
19
+ Repository,
20
+ Ref,
21
+ GitContent,
22
+ Commit,
23
+ )
24
+
25
+ __version__ = "0.2.0"
26
+
27
+ __all__ = [
28
+ "HubClient",
29
+ "DEFAULT_BASE_URL",
30
+ "hf_hub_download",
31
+ "snapshot_download",
32
+ "HubException",
33
+ "RepositoryNotFoundError",
34
+ "FileNotFoundError",
35
+ "AuthenticationError",
36
+ "Repository",
37
+ "Ref",
38
+ "GitContent",
39
+ "Commit",
40
+ ]
41
+
@@ -0,0 +1,259 @@
1
+ """
2
+ XiaoShi AI Hub Client
3
+ """
4
+
5
+ import base64
6
+ import json
7
+ import os
8
+ from typing import List, Optional
9
+ from urllib.parse import urljoin
10
+
11
+ import requests
12
+
13
+ try:
14
+ from tqdm.auto import tqdm
15
+ except ImportError:
16
+ tqdm = None
17
+
18
+ from .exceptions import (
19
+ AuthenticationError,
20
+ HTTPError,
21
+ RepositoryNotFoundError,
22
+ )
23
+ from .types import Repository, Ref, GitContent
24
+
25
+
26
# Default base URL of the Hub API. It is read once, at import time, from the
# MOHA_ENDPOINT environment variable. An unset *or empty* variable falls back to
# the built-in endpoint: an empty string would otherwise become the client's
# base URL and every request URL would lack a scheme and host.
# NOTE(review): the fallback host name contains "develop" — confirm this is the
# intended default for released builds.
DEFAULT_BASE_URL = (
    os.environ.get("MOHA_ENDPOINT")
    or "https://rune.develop.xiaoshiai.cn/api/moha"
)
31
+
32
+
33
class HubClient:
    """Client for interacting with the XiaoShi AI Hub HTTP API.

    All requests go through one ``requests.Session`` so the ``Authorization``
    header configured in ``__init__`` is sent with every call.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        token: Optional[str] = None,
    ):
        """
        Initialize the Hub client.

        Args:
            base_url: Base URL of the Hub API. Falls back to the module-level
                ``DEFAULT_BASE_URL`` when omitted or empty. Trailing slashes
                are stripped.
            username: Username for HTTP Basic authentication.
            password: Password for HTTP Basic authentication.
            token: Bearer token. Takes precedence over username/password when
                both are supplied.
        """
        self.base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
        self.username = username
        self.password = password
        self.token = token
        self.session = requests.Session()

        # Token auth wins; Basic auth is only used when both parts are given.
        if token:
            self.session.headers['Authorization'] = f'Bearer {token}'
        elif username and password:
            auth_string = f"{username}:{password}"
            encoded = base64.b64encode(auth_string.encode()).decode()
            self.session.headers['Authorization'] = f'Basic {encoded}'

    def _repo_url(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        *segments: str,
    ) -> str:
        """Build ``{base_url}/organizations/{org}/{type}/{name}[/segment...]``.

        Every component is percent-encoded with ``/`` kept as-is, so branch
        names and file paths containing spaces, ``#``, ``?`` or ``%`` are sent
        as path text instead of being read as a fragment, query or escape.
        """
        from urllib.parse import quote

        parts = (organization, repo_type, repo_name) + segments
        encoded = "/".join(quote(str(part), safe="/") for part in parts)
        return f"{self.base_url}/organizations/{encoded}"

    def _make_request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> requests.Response:
        """Send a request through the session and map failures to SDK errors.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            url: Absolute request URL.
            **kwargs: Forwarded unchanged to ``requests.Session.request``
                (e.g. ``stream=True``).

        Returns:
            The response, for any status code below 400.

        Raises:
            AuthenticationError: on HTTP 401.
            RepositoryNotFoundError: on HTTP 404 (for any resource, not only
                repositories).
            HTTPError: on any other status >= 400, or when ``requests`` itself
                fails (connection error, invalid URL, ...).
        """
        # Only the network call belongs in the try: the SDK errors raised below
        # must not be caught and re-wrapped here.
        try:
            response = self.session.request(method, url, **kwargs)
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

        status = response.status_code
        if status < 400:
            return response

        # Release the connection before raising. With stream=True the unread
        # body would otherwise keep it checked out of the session's pool.
        response.close()
        if status == 401:
            raise AuthenticationError("Authentication failed")
        if status == 404:
            raise RepositoryNotFoundError("Resource not found")
        raise HTTPError(
            f"HTTP {status}: {response.reason}",
            status_code=status,
        )

    def get_repository_info(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> Repository:
        """
        Get repository information.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            Repository information. ``default_branch`` and ``description`` are
            ``None`` when the response omits ``defaultBranch`` /
            ``description``; ``name`` falls back to ``repo_name``.
        """
        url = self._repo_url(organization, repo_type, repo_name)
        data = self._make_request("GET", url).json()

        return Repository(
            name=data.get('name', repo_name),
            organization=organization,
            type=repo_type,
            default_branch=data.get('defaultBranch'),
            description=data.get('description'),
        )

    def get_repository_refs(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
    ) -> List[Ref]:
        """
        Get repository references (branches and tags).

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name

        Returns:
            One ``Ref`` per item of the JSON array returned by the ``refs``
            endpoint; missing fields default to ``''`` (``is_default`` to
            ``False``).
        """
        url = self._repo_url(organization, repo_type, repo_name, "refs")
        data = self._make_request("GET", url).json()

        return [
            Ref(
                name=ref_data.get('name', ''),
                ref=ref_data.get('ref', ''),
                fully_name=ref_data.get('fullyName', ''),
                type=ref_data.get('type', ''),
                hash=ref_data.get('hash', ''),
                is_default=ref_data.get('isDefault', False),
            )
            for ref_data in data
        ]

    def get_repository_content(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        path: str = "",
    ) -> GitContent:
        """
        Get repository content at a specific path.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            path: Path within the repository (empty for root)

        Returns:
            Git content information, with nested ``entries`` parsed recursively.
        """
        segments = ["contents", branch]
        if path:
            segments.append(path)
        url = self._repo_url(organization, repo_type, repo_name, *segments)
        data = self._make_request("GET", url).json()

        return self._parse_git_content(data)

    def _parse_git_content(self, data: dict) -> GitContent:
        """Convert one content node of the API response into ``GitContent``.

        Child nodes under ``entries`` are converted recursively; a missing or
        empty ``entries`` list becomes ``None``.
        """
        raw_entries = data.get('entries')
        entries = (
            [self._parse_git_content(entry) for entry in raw_entries]
            if raw_entries
            else None
        )

        return GitContent(
            name=data.get('name', ''),
            path=data.get('path', ''),
            type=data.get('type', 'file'),
            size=data.get('size', 0),
            hash=data.get('hash'),
            content_type=data.get('contentType'),
            content=data.get('content'),
            content_omitted=data.get('contentOmitted', False),
            entries=entries,
        )

    @staticmethod
    def _stream_to_file(response: requests.Response, local_path: str, progress_bar=None) -> None:
        """Write the response body to ``local_path`` via a temporary sibling file.

        The body goes to ``local_path + ".incomplete"`` and is moved onto
        ``local_path`` with ``os.replace`` only after the last chunk, so a
        failed or interrupted transfer never leaves a truncated file there.
        """
        tmp_path = f"{local_path}.incomplete"
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
            os.replace(tmp_path, local_path)
        except BaseException:
            # Best-effort cleanup of the partial file; the original error is
            # what the caller needs to see, so it is always re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def download_file(
        self,
        organization: str,
        repo_type: str,
        repo_name: str,
        branch: str,
        file_path: str,
        local_path: str,
        show_progress: bool = True,
    ) -> None:
        """
        Download a single file from the repository.

        The body is streamed in 8 KiB chunks and only appears at
        ``local_path`` once complete. Parent directories are created as needed,
        and the HTTP response is always closed, even on error.

        Args:
            organization: Organization name
            repo_type: Repository type ("models" or "datasets")
            repo_name: Repository name
            branch: Branch name
            file_path: Path to the file in the repository
            local_path: Local path to save the file
            show_progress: Whether to show a tqdm progress bar (only when tqdm
                is installed and the server sends a positive Content-Length)

        Raises:
            AuthenticationError, RepositoryNotFoundError, HTTPError: as raised
                by ``_make_request``.
            OSError: if the file cannot be written or moved into place.
        """
        url = self._repo_url(
            organization, repo_type, repo_name, "resolve", branch, file_path
        )
        response = self._make_request("GET", url, stream=True)
        try:
            try:
                total_size = int(response.headers.get('content-length', 0))
            except ValueError:
                # Malformed Content-Length: still download, just without a bar.
                total_size = 0

            parent_dir = os.path.dirname(local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            progress_bar = None
            if show_progress and tqdm is not None and total_size > 0:
                progress_bar = tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=os.path.basename(file_path),
                    leave=True,
                )
            try:
                self._stream_to_file(response, local_path, progress_bar)
            finally:
                if progress_bar is not None:
                    progress_bar.close()
        finally:
            # With stream=True the connection stays checked out of the pool
            # until the response is closed.
            response.close()
259
+