xiaoshiai-hub 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xiaoshiai_hub/__init__.py +18 -3
- xiaoshiai_hub/auth.py +163 -0
- xiaoshiai_hub/cli.py +496 -0
- xiaoshiai_hub/client.py +57 -10
- xiaoshiai_hub/download.py +21 -3
- xiaoshiai_hub/envelope_crypto.py +237 -0
- xiaoshiai_hub/upload.py +47 -63
- {xiaoshiai_hub-1.0.0.dist-info → xiaoshiai_hub-1.1.0.dist-info}/METADATA +177 -50
- xiaoshiai_hub-1.1.0.dist-info/RECORD +15 -0
- xiaoshiai_hub-1.1.0.dist-info/entry_points.txt +2 -0
- xiaoshiai_hub-1.0.0.dist-info/RECORD +0 -11
- {xiaoshiai_hub-1.0.0.dist-info → xiaoshiai_hub-1.1.0.dist-info}/WHEEL +0 -0
- {xiaoshiai_hub-1.0.0.dist-info → xiaoshiai_hub-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {xiaoshiai_hub-1.0.0.dist-info → xiaoshiai_hub-1.1.0.dist-info}/top_level.txt +0 -0
xiaoshiai_hub/client.py
CHANGED
|
@@ -4,10 +4,11 @@ XiaoShi AI Hub Client
|
|
|
4
4
|
|
|
5
5
|
import base64
|
|
6
6
|
import os
|
|
7
|
-
from typing import List, Optional
|
|
7
|
+
from typing import Dict, List, Optional
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
import requests
|
|
11
|
+
from xiaoshiai_hub.envelope_crypto import DataKey
|
|
11
12
|
|
|
12
13
|
try:
|
|
13
14
|
from tqdm.auto import tqdm
|
|
@@ -25,10 +26,11 @@ from .types import Repository, Ref, GitContent
|
|
|
25
26
|
# 默认基础 URL,可通过环境变量 MOHA_ENDPOINT 覆盖
|
|
26
27
|
DEFAULT_BASE_URL = os.environ.get(
|
|
27
28
|
"MOHA_ENDPOINT",
|
|
28
|
-
"https://rune-api.develop.xiaoshiai.cn
|
|
29
|
+
"https://rune-api.develop.xiaoshiai.cn"
|
|
29
30
|
)
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
|
|
32
34
|
class HubClient:
|
|
33
35
|
"""Client for interacting with XiaoShi AI Hub API."""
|
|
34
36
|
|
|
@@ -85,6 +87,30 @@ class HubClient:
|
|
|
85
87
|
return response
|
|
86
88
|
except requests.RequestException as e:
|
|
87
89
|
raise HTTPError(f"Request failed: {str(e)}")
|
|
90
|
+
|
|
91
|
+
def cancel_repository_encrypted(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
) -> None:
    """Clear the "encrypted" flag on a repository.

    Sends ``PUT {base_url}/moha/v1/organizations/{organization}/{repo_type}/
    {repo_name}/encryption/cancel``.

    Args:
        organization: Organization that owns the repository.
        repo_type: Repository type segment of the URL (e.g. "models" or "datasets").
        repo_name: Repository name.

    Raises:
        HTTPError: if the server answers with any status other than 200, or if
            ``_make_request`` itself fails.
    """
    endpoint = (
        f"{self.base_url}/moha/v1/organizations/"
        f"{organization}/{repo_type}/{repo_name}/encryption/cancel"
    )
    resp = self._make_request("PUT", endpoint)
    if resp.status_code == 200:
        return
    raise HTTPError(f"Failed to cancel repository encrypted flag: {resp.text}")
|
|
102
|
+
|
|
103
|
+
def set_repository_encrypted(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
) -> None:
    """Mark a repository as containing encrypted files.

    Sends ``PUT {base_url}/moha/v1/organizations/{organization}/{repo_type}/
    {repo_name}/encryption/set``. The upload helpers call this after at least
    one file was encrypted.

    Args:
        organization: Organization that owns the repository.
        repo_type: Repository type segment of the URL (e.g. "models" or "datasets").
        repo_name: Repository name.

    Raises:
        HTTPError: if the server answers with any status other than 200, or if
            ``_make_request`` itself fails.
    """
    endpoint = (
        f"{self.base_url}/moha/v1/organizations/"
        f"{organization}/{repo_type}/{repo_name}/encryption/set"
    )
    resp = self._make_request("PUT", endpoint)
    if resp.status_code == 200:
        return
    raise HTTPError(f"Failed to set repository encrypted flag: {resp.text}")
|
|
88
114
|
|
|
89
115
|
def get_repository_info(
|
|
90
116
|
self,
|
|
@@ -103,17 +129,17 @@ class HubClient:
|
|
|
103
129
|
Returns:
|
|
104
130
|
Repository information
|
|
105
131
|
"""
|
|
106
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}"
|
|
132
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}"
|
|
107
133
|
response = self._make_request("GET", url)
|
|
108
134
|
data = response.json()
|
|
109
135
|
|
|
110
136
|
# Parse annotations if present
|
|
111
|
-
annotations = {}
|
|
137
|
+
annotations: Dict[str, str] = {}
|
|
112
138
|
if 'annotations' in data and isinstance(data['annotations'], dict):
|
|
113
139
|
annotations = data['annotations']
|
|
114
140
|
|
|
115
141
|
# Parse metadata if present
|
|
116
|
-
metadata = {}
|
|
142
|
+
metadata: Dict[str, List[str]] = {}
|
|
117
143
|
if 'metadata' in data and isinstance(data['metadata'], dict):
|
|
118
144
|
metadata = data['metadata']
|
|
119
145
|
|
|
@@ -143,7 +169,7 @@ class HubClient:
|
|
|
143
169
|
Returns:
|
|
144
170
|
List of references
|
|
145
171
|
"""
|
|
146
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/refs"
|
|
172
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/refs"
|
|
147
173
|
response = self._make_request("GET", url)
|
|
148
174
|
data = response.json()
|
|
149
175
|
|
|
@@ -204,9 +230,9 @@ class HubClient:
|
|
|
204
230
|
Git content information
|
|
205
231
|
"""
|
|
206
232
|
if path:
|
|
207
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/contents/{branch}/{path}"
|
|
233
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/contents/{branch}/{path}"
|
|
208
234
|
else:
|
|
209
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/contents/{branch}"
|
|
235
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/contents/{branch}"
|
|
210
236
|
|
|
211
237
|
response = self._make_request("GET", url)
|
|
212
238
|
data = response.json()
|
|
@@ -240,7 +266,7 @@ class HubClient:
|
|
|
240
266
|
repo_type: Repository type ("models" or "datasets")
|
|
241
267
|
repo_name: Repository name
|
|
242
268
|
"""
|
|
243
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/downloads"
|
|
269
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/downloads"
|
|
244
270
|
response = self._make_request("POST", url)
|
|
245
271
|
# 检查http code
|
|
246
272
|
if response.status_code != 200:
|
|
@@ -268,7 +294,7 @@ class HubClient:
|
|
|
268
294
|
local_path: Local path to save the file
|
|
269
295
|
show_progress: Whether to show download progress bar
|
|
270
296
|
"""
|
|
271
|
-
url = f"{self.base_url}/v1/organizations/{organization}/{repo_type}/{repo_name}/resolve/{branch}/{file_path}"
|
|
297
|
+
url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/resolve/{branch}/{file_path}"
|
|
272
298
|
response = self._make_request("GET", url, stream=True)
|
|
273
299
|
|
|
274
300
|
# Get file size from headers
|
|
@@ -311,3 +337,24 @@ class HubClient:
|
|
|
311
337
|
if progress_bar is not None:
|
|
312
338
|
progress_bar.close()
|
|
313
339
|
|
|
340
|
+
|
|
341
|
+
def generate_data_key(self, password: Optional[str] = None) -> DataKey:
    """Request a fresh data key from the hub's KMS endpoint.

    POSTs ``{"password": password}`` (JSON ``null`` when *password* is None)
    to ``{base_url}/api/kms/generate-data-key`` with a 30 second timeout. The
    response is expected to be JSON with ``plaintextKey`` (Base64-encoded raw
    key bytes) and ``encryptedKey``.

    NOTE(review): unlike the other methods of this client, this uses a bare
    ``requests.post`` instead of ``self._make_request``. Whatever headers or
    authentication ``_make_request`` adds (not fully visible here) are
    therefore not sent. Confirm that is intended for the KMS endpoint.

    Args:
        password: Password forwarded to the KMS in the request body.

    Returns:
        DataKey holding the decoded raw key bytes and the encrypted key string.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx status (via ``raise_for_status``). Invalid JSON raises a
            ``ValueError`` subclass.
        KeyError: if ``plaintextKey`` or ``encryptedKey`` is missing.
        binascii.Error: if ``plaintextKey`` is malformed Base64.
    """
    url = f"{self.base_url}/api/kms/generate-data-key"
    # Network errors propagate unchanged. The former
    # ``except RequestException as e: raise e`` wrapper was a no-op.
    resp = requests.post(url, json={"password": password}, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    return DataKey(
        plaintext_key=base64.b64decode(data["plaintextKey"]),
        encrypted_key=data["encryptedKey"],
    )
|
xiaoshiai_hub/download.py
CHANGED
|
@@ -7,14 +7,15 @@ import os
|
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import List, Optional, Union
|
|
9
9
|
|
|
10
|
+
from .exceptions import RepositoryNotFoundError
|
|
11
|
+
|
|
10
12
|
try:
|
|
11
13
|
from tqdm.auto import tqdm
|
|
12
14
|
except ImportError:
|
|
13
15
|
tqdm = None
|
|
14
16
|
|
|
15
|
-
from .client import HubClient
|
|
16
|
-
|
|
17
|
-
from .exceptions import RepositoryNotFoundError
|
|
17
|
+
from .client import HubClient
|
|
18
|
+
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
def _match_pattern(name: str, pattern: str) -> bool:
|
|
@@ -263,6 +264,16 @@ def moha_hub_download(
|
|
|
263
264
|
password=password,
|
|
264
265
|
token=token,
|
|
265
266
|
)
|
|
267
|
+
|
|
268
|
+
try:
|
|
269
|
+
client.get_repository_info(organization, repo_type, repo_name)
|
|
270
|
+
except RepositoryNotFoundError:
|
|
271
|
+
raise RepositoryNotFoundError(
|
|
272
|
+
f"Repository not found: {organization}/{repo_type}/{repo_name}. "
|
|
273
|
+
f"Please create the repository first."
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
|
|
266
277
|
# 获取默认分支
|
|
267
278
|
if revision is None:
|
|
268
279
|
revision = client.get_default_branch(organization, repo_type, repo_name)
|
|
@@ -348,6 +359,13 @@ def snapshot_download(
|
|
|
348
359
|
password=password,
|
|
349
360
|
token=token,
|
|
350
361
|
)
|
|
362
|
+
try:
|
|
363
|
+
client.get_repository_info(organization, repo_type, repo_name)
|
|
364
|
+
except RepositoryNotFoundError:
|
|
365
|
+
raise RepositoryNotFoundError(
|
|
366
|
+
f"Repository not found: {organization}/{repo_type}/{repo_name}. "
|
|
367
|
+
f"Please create the repository first."
|
|
368
|
+
)
|
|
351
369
|
if revision is None:
|
|
352
370
|
revision = client.get_default_branch(organization, repo_type, repo_name)
|
|
353
371
|
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
信封加密核心模块
|
|
3
|
+
|
|
4
|
+
实现基于 AES-256-CTR 的信封加密/解密逻辑。
|
|
5
|
+
支持随机访问和流式解密,适合大文件部分读取场景。
|
|
6
|
+
|
|
7
|
+
信封加密文件格式:
|
|
8
|
+
- 前 4 字节: 元数据长度 (big-endian)
|
|
9
|
+
- 元数据: JSON 格式,包含 encryptedKey 和 password
|
|
10
|
+
- 16 字节: IV (用于 AES-256-CTR)
|
|
11
|
+
- 剩余部分: 加密后的文件内容
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import io
|
|
15
|
+
import os
|
|
16
|
+
import json
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
import shutil
|
|
20
|
+
import tempfile
|
|
21
|
+
from typing import Optional, Union
|
|
22
|
+
|
|
23
|
+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
24
|
+
from cryptography.hazmat.backends import default_backend
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Size in bytes of the random AES-CTR initial counter block (IV). It is
# written right after the JSON metadata in every envelope-encrypted file.
IV_SIZE = 16
# Width in bytes of the big-endian length prefix placed before the metadata.
METADATA_LENGTH_SIZE = 4
# Default read size for streaming encryption: 64 KiB.
CHUNK_SIZE = 1 << 16
|
|
32
|
+
|
|
33
|
+
@dataclass
class DataKey:
    """Data key pair: raw key bytes for AES plus their KMS-encrypted form.

    ``__repr__`` is defined explicitly so the raw key material never appears
    in logs, debugger output or tracebacks. ``@dataclass`` keeps a
    ``__repr__`` that the class body already defines rather than generating
    its own.
    """

    plaintext_key: bytes  # raw key bytes (32 bytes for AES-256)
    encrypted_key: str  # encrypted key, stored in the envelope file header

    def __repr__(self) -> str:
        # Report only the key length, never its bytes. Use getattr because
        # some callers ``del`` the attribute after use.
        key = getattr(self, "plaintext_key", None)
        key_info = "<deleted>" if key is None else f"<{len(key)} bytes redacted>"
        return (
            f"{type(self).__name__}(plaintext_key={key_info}, "
            f"encrypted_key={self.encrypted_key!r})"
        )
|
|
38
|
+
|
|
39
|
+
def _raw_open(path, mode="rb"):
    """Open *path* through ``io.FileIO`` directly, bypassing ``builtins.open``.

    The original note says this avoids a ``decrypt_patch`` hook installed on
    ``builtins.open``. That hook is not visible in this module.

    Binary modes get a buffered wrapper matching the requested access:

    * ``BufferedRandom`` for update modes containing ``"+"`` (e.g. ``"r+b"``);
    * ``BufferedReader`` for read-only modes;
    * ``BufferedWriter`` for write/append/exclusive-create modes.

    Modes without ``"b"`` return the unbuffered ``io.FileIO`` itself. Note
    that it still reads and writes bytes, not text.
    """
    raw = io.FileIO(path, mode.replace("b", ""))
    if "b" not in mode:
        return raw
    try:
        if "+" in mode:
            return io.BufferedRandom(raw)
        if "r" in mode:
            return io.BufferedReader(raw)
        return io.BufferedWriter(raw)
    except BaseException:
        # Don't leak the OS file descriptor if wrapping fails (e.g. a
        # non-seekable file opened in an update mode).
        raw.close()
        raise
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def create_aes_ctr_cipher(key: bytes, iv: bytes):
    """Return a ``cryptography`` AES cipher configured for CTR mode.

    Args:
        key: Raw AES key. Callers in this module pass
            ``DataKey.plaintext_key``. AES-256 needs 32 bytes. The length is
            not checked here, and ``algorithms.AES`` also accepts 16- and
            24-byte keys.
        iv: Initial counter block for CTR mode (``IV_SIZE`` bytes in this
            module's callers).

    Returns:
        A ``Cipher`` object; call ``.encryptor()`` or ``.decryptor()`` on it.
    """
    aes = algorithms.AES(key)
    ctr_mode = modes.CTR(iv)
    return Cipher(aes, ctr_mode, backend=default_backend())
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class EnvelopeMetadata:
    """JSON header stored at the front of every envelope-encrypted file.

    Serialized as UTF-8 JSON of the form
    ``{"encryptedKey": ..., "password": ...}``.

    NOTE(review): ``password`` is written into the header in clear text next
    to the encrypted data key, so anyone holding the file can read it.
    Confirm with the KMS design whether this value is meant to be secret.
    """

    encrypted_key: str  # encrypted data key (Base64 text)
    password: str  # password used to decrypt the data key (DEK)

    def to_bytes(self) -> bytes:
        """Serialize to UTF-8 encoded JSON with camelCase keys."""
        payload = {"encryptedKey": self.encrypted_key, "password": self.password}
        text = json.dumps(payload)
        return text.encode("utf-8")

    @classmethod
    def from_bytes(cls, data: bytes) -> "EnvelopeMetadata":
        """Parse the JSON produced by ``to_bytes``.

        Raises:
            UnicodeDecodeError: if *data* is not valid UTF-8.
            json.JSONDecodeError: if the text is not valid JSON.
            KeyError: if ``encryptedKey`` or ``password`` is missing.
        """
        decoded = json.loads(data.decode("utf-8"))
        key_text = decoded["encryptedKey"]
        secret = decoded["password"]
        return cls(encrypted_key=key_text, password=secret)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _envelope_encrypt_file(
    input_path: Union[str, Path],
    output_path: Union[str, Path],
    password: str,
    data_key: DataKey,
) -> None:
    """Encrypt a whole file in memory with AES-256-CTR and write an envelope file.

    Output layout:

    * a ``METADATA_LENGTH_SIZE``-byte big-endian length of the metadata;
    * the JSON ``EnvelopeMetadata`` (encrypted key and password);
    * an ``IV_SIZE``-byte random IV;
    * the ciphertext.

    The entire plaintext is read into memory. ``_envelope_encrypt_file_streaming``
    writes the same layout in chunks.

    Args:
        input_path: Plaintext file to read.
        output_path: Encrypted file to write. Parent directories are created
            if missing.
        password: Value stored in the metadata header.
        data_key: Its ``plaintext_key`` encrypts the content and its
            ``encrypted_key`` goes into the header.

    The data key is NOT modified. ``upload_folder`` encrypts many files with a
    single DataKey, and the previous ``del data_key.plaintext_key`` made every
    file after the first fail with AttributeError. Discarding the key is the
    caller's responsibility.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)

    iv = os.urandom(IV_SIZE)
    encryptor = create_aes_ctr_cipher(data_key.plaintext_key, iv).encryptor()
    with _raw_open(input_path, "rb") as f:
        plaintext = f.read()
    ciphertext = encryptor.update(plaintext) + encryptor.finalize()

    metadata_bytes = EnvelopeMetadata(
        encrypted_key=data_key.encrypted_key,
        password=password,
    ).to_bytes()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with _raw_open(output_path, "wb") as f:
        f.write(len(metadata_bytes).to_bytes(METADATA_LENGTH_SIZE, "big"))
        f.write(metadata_bytes)
        f.write(iv)
        f.write(ciphertext)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def envelope_enc_file(
    source: Union[Path, str],
    *,
    password: str,
    data_key: DataKey,
    dest: Optional[Union[Path, str]] = None,
    replace: bool = False,
    chunked: bool = True,
    chunk_size: int = CHUNK_SIZE,  # 64 KiB by default
) -> Path:
    """Encrypt a single file with envelope encryption (AES-256-CTR).

    The content is encrypted with the raw bytes of *data_key*. The encrypted
    form of that key, together with *password*, is written into the file
    header.

    Args:
        source: Plaintext file to encrypt.
        password: Value recorded in the header metadata.
        data_key: Data key to encrypt with (e.g. from
            ``HubClient.generate_data_key``).
        dest: Output path. Defaults to *source* with ``.encrypted`` appended
            to its suffix. Ignored when *replace* is true.
        replace: Encrypt in place. The ciphertext is first written to a
            temporary sibling file, which then replaces *source* via
            ``os.replace``. Same directory, so this is a rename.
        chunked: Use streaming encryption (bounded memory) instead of reading
            the whole file at once.
        chunk_size: Read size used by streaming encryption.

    Returns:
        Path of the encrypted file (*source* itself when *replace* is true).

    Raises:
        FileNotFoundError: if *source* does not exist or is not a regular file.
    """
    src = Path(source)
    if not src.is_file():  # is_file() is False for missing paths too
        raise FileNotFoundError(f"Source file not found: {src}")

    if replace:
        # Write to a temp sibling so the original is only overwritten once
        # encryption has fully succeeded. NOTE: mkstemp creates the file
        # owner-only, and those permissions carry over to *source*.
        dst = _make_temp_path(src.parent, ".encrypted.tmp")
    elif dest:
        dst = Path(dest)
    else:
        dst = src.with_suffix(src.suffix + ".encrypted")

    dst.parent.mkdir(parents=True, exist_ok=True)

    try:
        if chunked:
            _envelope_encrypt_file_streaming(src, dst, password, data_key, chunk_size)
        else:
            _envelope_encrypt_file(src, dst, password, data_key)
        if replace:
            # Move the ciphertext OVER the source. The previous
            # shutil.move(src, dst) went the other way: it clobbered the
            # ciphertext with the plaintext and left *source* missing.
            os.replace(dst, src)
    except BaseException:
        if replace:
            # Don't leave the temp file behind on failure.
            dst.unlink(missing_ok=True)
        raise

    return src if replace else dst
|
|
174
|
+
|
|
175
|
+
def _make_temp_path(parent: Path, suffix: str) -> Path:
    """Reserve a fresh, empty file inside *parent* and return its path.

    ``tempfile.mkstemp`` creates the file atomically, so the name cannot
    collide with another process. The descriptor it opens is closed right
    away because callers reopen the path themselves.
    """
    descriptor, raw_path = tempfile.mkstemp(dir=parent, suffix=suffix)
    try:
        return Path(raw_path)
    finally:
        os.close(descriptor)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _envelope_encrypt_file_streaming(
    input_path: Union[str, Path],
    output_path: Union[str, Path],
    password: str,
    data_key: DataKey,
    chunk_size: int = CHUNK_SIZE,
) -> None:
    """Encrypt a file chunk by chunk with AES-256-CTR (bounded memory).

    CTR is a stream mode, so each chunk's ciphertext has the same length as
    its plaintext. The output is byte-for-byte the same layout as
    ``_envelope_encrypt_file``: length prefix, JSON metadata, IV, ciphertext.

    Args:
        input_path: Plaintext file to read.
        output_path: Encrypted file to write. Parent directories are created
            if missing.
        password: Value stored in the metadata header.
        data_key: Its ``plaintext_key`` encrypts the content and its
            ``encrypted_key`` goes into the header.
        chunk_size: Bytes read per iteration. Must be positive.

    Raises:
        ValueError: if *chunk_size* is not positive. ``read(0)`` would end the
            loop immediately and silently produce an empty ciphertext.

    The data key is NOT modified. ``upload_folder`` reuses one DataKey for all
    files, and the previous ``del data_key.plaintext_key`` broke every file
    after the first.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    input_path = Path(input_path)
    output_path = Path(output_path)

    iv = os.urandom(IV_SIZE)
    metadata_bytes = EnvelopeMetadata(
        encrypted_key=data_key.encrypted_key,
        password=password,
    ).to_bytes()
    encryptor = create_aes_ctr_cipher(data_key.plaintext_key, iv).encryptor()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with _raw_open(input_path, "rb") as fin, _raw_open(output_path, "wb") as fout:
        # Header: metadata length, metadata, IV.
        fout.write(len(metadata_bytes).to_bytes(METADATA_LENGTH_SIZE, "big"))
        fout.write(metadata_bytes)
        fout.write(iv)

        # Body: encrypt and write one chunk at a time.
        while chunk := fin.read(chunk_size):
            fout.write(encryptor.update(chunk))

        tail = encryptor.finalize()
        if tail:
            fout.write(tail)
|
|
237
|
+
|
xiaoshiai_hub/upload.py
CHANGED
|
@@ -12,10 +12,10 @@ from typing import List, Optional, Union, Dict
|
|
|
12
12
|
import requests
|
|
13
13
|
|
|
14
14
|
from xiaoshiai_hub.client import DEFAULT_BASE_URL, HubClient
|
|
15
|
+
from xiaoshiai_hub.envelope_crypto import DataKey, envelope_enc_file
|
|
15
16
|
from .exceptions import HubException, AuthenticationError, RepositoryNotFoundError
|
|
16
17
|
|
|
17
|
-
|
|
18
|
-
from key_manager import KeyManager
|
|
18
|
+
|
|
19
19
|
|
|
20
20
|
try:
|
|
21
21
|
from tqdm.auto import tqdm
|
|
@@ -40,7 +40,7 @@ def _build_api_url(
|
|
|
40
40
|
) -> str:
|
|
41
41
|
"""Build API upload URL."""
|
|
42
42
|
base_url = (base_url or DEFAULT_BASE_URL).rstrip('/')
|
|
43
|
-
return f"{base_url}/{organization}/{repo_type}/{repo_name}/api/upload"
|
|
43
|
+
return f"{base_url}/moha/{organization}/{repo_type}/{repo_name}/api/upload"
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
def _create_session(
|
|
@@ -96,7 +96,8 @@ class _TqdmUploadWrapper:
|
|
|
96
96
|
def __enter__(self):
|
|
97
97
|
return self
|
|
98
98
|
|
|
99
|
-
def __exit__(self,
|
|
99
|
+
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def]
|
|
100
|
+
del exc_type, exc_val, exc_tb # Unused but required by context manager protocol
|
|
100
101
|
if self.pbar is not None:
|
|
101
102
|
self.pbar.close()
|
|
102
103
|
self.file_obj.close()
|
|
@@ -123,6 +124,7 @@ def _upload_file_with_progress(
|
|
|
123
124
|
def _encrypt_file_if_needed(
|
|
124
125
|
file_path: Path,
|
|
125
126
|
encryption_password: Optional[str] = None,
|
|
127
|
+
data_key: Optional[DataKey] = None,
|
|
126
128
|
) -> tuple[Optional[Path], Optional[Path]]:
|
|
127
129
|
"""
|
|
128
130
|
Encrypt file if encryption_password is provided, file is large enough, and file extension is encryptable.
|
|
@@ -130,12 +132,13 @@ def _encrypt_file_if_needed(
|
|
|
130
132
|
Args:
|
|
131
133
|
file_path: Path to the file to encrypt
|
|
132
134
|
encryption_password: Password for encryption
|
|
135
|
+
data_key: DataKey for encryption (required if encryption_password is provided)
|
|
133
136
|
|
|
134
137
|
Returns:
|
|
135
138
|
Tuple of (encrypted_file_path, temp_dir_path)
|
|
136
139
|
If no encryption, returns (None, None)
|
|
137
140
|
"""
|
|
138
|
-
if not encryption_password:
|
|
141
|
+
if not encryption_password or not data_key:
|
|
139
142
|
return None, None
|
|
140
143
|
|
|
141
144
|
# Check file size first (only encrypt files >= 5MB)
|
|
@@ -147,21 +150,18 @@ def _encrypt_file_if_needed(
|
|
|
147
150
|
if file_path.suffix.lower() not in ENCRYPTABLE_EXTENSIONS:
|
|
148
151
|
return None, None
|
|
149
152
|
|
|
150
|
-
# Generate encryption key from password
|
|
151
|
-
encryption_key = KeyManager.generate_key(encryption_password)
|
|
152
|
-
|
|
153
|
-
# Create temporary directory for encrypted file
|
|
154
153
|
temp_dir = Path(tempfile.mkdtemp())
|
|
155
154
|
encrypted_file = temp_dir / file_path.name
|
|
156
|
-
manifest_path = temp_dir / "xpai_encryption_manifest.enc"
|
|
157
155
|
|
|
158
156
|
# Encrypt the file
|
|
159
|
-
|
|
157
|
+
envelope_enc_file(
|
|
160
158
|
source=file_path,
|
|
159
|
+
password=encryption_password,
|
|
160
|
+
data_key=data_key,
|
|
161
161
|
dest=encrypted_file,
|
|
162
|
-
|
|
163
|
-
manifest_path=manifest_path,
|
|
162
|
+
chunked=True,
|
|
164
163
|
)
|
|
164
|
+
print(f"Encrypted file path: {encrypted_file}")
|
|
165
165
|
|
|
166
166
|
return encrypted_file, temp_dir
|
|
167
167
|
|
|
@@ -349,15 +349,10 @@ def upload_folder(
|
|
|
349
349
|
if '.git' not in ignore_patterns and '.git/' not in ignore_patterns:
|
|
350
350
|
ignore_patterns.append('.git')
|
|
351
351
|
ignore_patterns.append('.gitattributes')
|
|
352
|
-
|
|
353
|
-
# Prepare encryption if needed
|
|
354
|
-
encryption_key = None
|
|
355
|
-
if encryption_password:
|
|
356
|
-
encryption_key = KeyManager.generate_key(encryption_password)
|
|
357
|
-
|
|
358
352
|
# Create or use temporary directory for encrypted files
|
|
359
353
|
temp_dir_path: Optional[Path] = None
|
|
360
|
-
|
|
354
|
+
data_key: Optional[DataKey] = None
|
|
355
|
+
if encryption_password:
|
|
361
356
|
if temp_dir:
|
|
362
357
|
# Use user-specified temp directory
|
|
363
358
|
temp_dir_path = Path(temp_dir)
|
|
@@ -365,6 +360,8 @@ def upload_folder(
|
|
|
365
360
|
else:
|
|
366
361
|
# Auto-create temp directory
|
|
367
362
|
temp_dir_path = Path(tempfile.mkdtemp())
|
|
363
|
+
# Generate data key for encryption
|
|
364
|
+
data_key = client.generate_data_key(encryption_password)
|
|
368
365
|
|
|
369
366
|
try:
|
|
370
367
|
# Create session and API URL
|
|
@@ -381,6 +378,7 @@ def upload_folder(
|
|
|
381
378
|
files_to_upload = []
|
|
382
379
|
large_files = [] # Files >= 5MB
|
|
383
380
|
small_files = [] # Files < 5MB
|
|
381
|
+
has_encrypted_files = False # Track if any file was encrypted
|
|
384
382
|
for root, dirs, files in os.walk(folder_path):
|
|
385
383
|
# Filter directories
|
|
386
384
|
dirs[:] = [d for d in dirs if not any(
|
|
@@ -405,19 +403,22 @@ def upload_folder(
|
|
|
405
403
|
|
|
406
404
|
# Encrypt file if needed (only for large files with specific extensions)
|
|
407
405
|
actual_file = local_file
|
|
408
|
-
if (
|
|
406
|
+
if (encryption_password and temp_dir_path and data_key and
|
|
409
407
|
file_size >= 5 * 1024 * 1024 and # Only encrypt files >= 5MB
|
|
410
408
|
local_file.suffix.lower() in ENCRYPTABLE_EXTENSIONS):
|
|
411
409
|
# Preserve directory structure in temp dir
|
|
412
410
|
encrypted_file = temp_dir_path / rel_file_path
|
|
413
411
|
encrypted_file.parent.mkdir(parents=True, exist_ok=True)
|
|
414
|
-
|
|
412
|
+
envelope_enc_file(
|
|
415
413
|
source=local_file,
|
|
414
|
+
password=encryption_password,
|
|
415
|
+
data_key=data_key,
|
|
416
416
|
dest=encrypted_file,
|
|
417
|
-
|
|
418
|
-
manifest_path=temp_dir_path / "xpai_encryption_manifest.enc",
|
|
417
|
+
chunked=True,
|
|
419
418
|
)
|
|
419
|
+
print(f"Encrypted file path: {encrypted_file}")
|
|
420
420
|
actual_file = encrypted_file
|
|
421
|
+
has_encrypted_files = True
|
|
421
422
|
|
|
422
423
|
# Determine upload method based on size
|
|
423
424
|
if file_size >= 5 * 1024 * 1024: # 5MB
|
|
@@ -441,24 +442,6 @@ def upload_folder(
|
|
|
441
442
|
})
|
|
442
443
|
small_files.append((rel_file_path, file_size))
|
|
443
444
|
|
|
444
|
-
# Add encryption manifest file if it exists
|
|
445
|
-
if temp_dir_path:
|
|
446
|
-
manifest_file = temp_dir_path / "xpai_encryption_manifest.enc"
|
|
447
|
-
if manifest_file.exists():
|
|
448
|
-
manifest_size = manifest_file.stat().st_size
|
|
449
|
-
# Read manifest file content
|
|
450
|
-
with open(manifest_file, 'rb') as f:
|
|
451
|
-
manifest_content = f.read()
|
|
452
|
-
manifest_b64 = base64.b64encode(manifest_content).decode('utf-8')
|
|
453
|
-
|
|
454
|
-
# Add to files to upload
|
|
455
|
-
files_to_upload.append({
|
|
456
|
-
"path": "xpai_encryption_manifest.enc",
|
|
457
|
-
"content": manifest_b64,
|
|
458
|
-
"size": manifest_size,
|
|
459
|
-
})
|
|
460
|
-
small_files.append(("xpai_encryption_manifest.enc", manifest_size))
|
|
461
|
-
|
|
462
445
|
print(f"Found {len(files_to_upload)} files to upload ({len(large_files)} large files, {len(small_files)} small files)")
|
|
463
446
|
# Show progress for small files upload
|
|
464
447
|
small_files_pbar = None
|
|
@@ -475,7 +458,6 @@ def upload_folder(
|
|
|
475
458
|
# Upload files via API
|
|
476
459
|
try:
|
|
477
460
|
result = _upload_files_via_api(session, api_url, files_to_upload, commit_message, revision)
|
|
478
|
-
|
|
479
461
|
# Update progress bar for small files
|
|
480
462
|
if small_files_pbar is not None:
|
|
481
463
|
for _, size in small_files:
|
|
@@ -499,6 +481,14 @@ def upload_folder(
|
|
|
499
481
|
)
|
|
500
482
|
print(f"Upload completed: {remote_path}")
|
|
501
483
|
|
|
484
|
+
# Set repository encrypted flag if any file was encrypted
|
|
485
|
+
if has_encrypted_files:
|
|
486
|
+
try:
|
|
487
|
+
client.set_repository_encrypted(organization, repo_type, repo_name)
|
|
488
|
+
print(f"Set repository encrypted flag for {repo_id}")
|
|
489
|
+
except Exception as e:
|
|
490
|
+
print(f"Warning: Failed to set repository encrypted flag: {e}")
|
|
491
|
+
|
|
502
492
|
print(f"Successfully uploaded to {repo_id}")
|
|
503
493
|
return result
|
|
504
494
|
finally:
|
|
@@ -565,9 +555,15 @@ def upload_file(
|
|
|
565
555
|
if not path_file.is_file():
|
|
566
556
|
raise ValueError(f"Path is not a file: {path_file}")
|
|
567
557
|
|
|
558
|
+
# Generate data key if encryption is needed
|
|
559
|
+
data_key: Optional[DataKey] = None
|
|
560
|
+
if encryption_password:
|
|
561
|
+
data_key = client.generate_data_key(encryption_password)
|
|
562
|
+
|
|
568
563
|
# Encrypt file if needed
|
|
569
|
-
encrypted_file, temp_dir = _encrypt_file_if_needed(path_file, encryption_password)
|
|
564
|
+
encrypted_file, temp_dir = _encrypt_file_if_needed(path_file, encryption_password, data_key)
|
|
570
565
|
actual_file = encrypted_file if encrypted_file else path_file
|
|
566
|
+
file_was_encrypted = encrypted_file is not None
|
|
571
567
|
|
|
572
568
|
try:
|
|
573
569
|
# Create session and API URL
|
|
@@ -582,15 +578,6 @@ def upload_file(
|
|
|
582
578
|
|
|
583
579
|
# Check file size
|
|
584
580
|
file_size = actual_file.stat().st_size
|
|
585
|
-
|
|
586
|
-
# Check if encryption manifest file exists
|
|
587
|
-
manifest_file = None
|
|
588
|
-
if temp_dir:
|
|
589
|
-
manifest_path = temp_dir / "xpai_encryption_manifest.enc"
|
|
590
|
-
if manifest_path.exists():
|
|
591
|
-
manifest_file = manifest_path
|
|
592
|
-
print(f"Found encryption manifest file: xpai_encryption_manifest.enc")
|
|
593
|
-
|
|
594
581
|
# Upload main file
|
|
595
582
|
if file_size >= 5 * 1024 * 1024: # 5MB
|
|
596
583
|
# Large file
|
|
@@ -599,16 +586,13 @@ def upload_file(
|
|
|
599
586
|
# Small file
|
|
600
587
|
result = _upload_small_file(session, api_url, actual_file, path_in_repo, commit_message, revision)
|
|
601
588
|
|
|
602
|
-
#
|
|
603
|
-
if
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
"
|
|
609
|
-
f"{commit_message} (manifest)",
|
|
610
|
-
revision
|
|
611
|
-
)
|
|
589
|
+
# Set repository encrypted flag if file was encrypted
|
|
590
|
+
if file_was_encrypted:
|
|
591
|
+
try:
|
|
592
|
+
client.set_repository_encrypted(organization, repo_type, repo_name)
|
|
593
|
+
print(f"Set repository encrypted flag for {repo_id}")
|
|
594
|
+
except Exception as e:
|
|
595
|
+
print(f"Warning: Failed to set repository encrypted flag: {e}")
|
|
612
596
|
|
|
613
597
|
print(f"Successfully uploaded {path_in_repo} to {repo_id}")
|
|
614
598
|
return result
|