xiaoshiai-hub 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xiaoshiai_hub/client.py CHANGED
@@ -8,7 +8,7 @@ from typing import Dict, List, Optional
8
8
 
9
9
 
10
10
  import requests
11
- from xiaoshiai_hub.envelope_crypto import DataKey
11
+ from xiaoshiai_hub.envelope_crypto import DEFAULT_ALGORITHM, Algorithm, DataKey
12
12
 
13
13
  try:
14
14
  from tqdm.auto import tqdm
@@ -110,8 +110,186 @@ class HubClient:
110
110
  url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/encryption/set"
111
111
  response = self._make_request("PUT", url)
112
112
  if response.status_code != 200:
113
- raise HTTPError(f"Failed to set repository encrypted flag: {response.text}")
113
+ raise HTTPError(f"Failed to set repository encrypted flag: {response.text}")
114
+
115
# Create a branch (no-op when a ref with that name already exists).
def create_branch(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch_name: str,
    from_branch: str,
) -> None:
    """Create ``branch_name`` based on ``from_branch`` unless it already exists.

    The existence check (listing refs) and the create call are two separate
    requests, so this is not atomic with respect to concurrent writers.

    Args:
        organization: Organization name.
        repo_type: Repository type ("models" or "datasets").
        repo_name: Repository name.
        branch_name: Name of the branch to create.
        from_branch: Existing branch the new branch is created from.

    Raises:
        HTTPError: If the create request does not return HTTP 200.
    """
    refs = self.get_repository_refs(organization, repo_type, repo_name)
    if any(ref.name == branch_name for ref in refs):
        # A ref with this name is already present: nothing to do.
        return

    response = self._make_request(
        "POST",
        f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/refs/{branch_name}",
        json={"base": from_branch},
    )
    if response.status_code != 200:
        raise HTTPError(f"Failed to create branch: {response.text}")
149
+
150
# Delete a branch (no-op when no ref with that name exists).
def delete_branch(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
    branch_name: str,
) -> None:
    """Delete ``branch_name`` if the repository currently has a ref with that name.

    Args:
        organization: Organization name.
        repo_type: Repository type ("models" or "datasets").
        repo_name: Repository name.
        branch_name: Name of the branch to delete.

    Raises:
        HTTPError: If the delete request does not return HTTP 200.
    """
    known_refs = self.get_repository_refs(organization, repo_type, repo_name)
    if not any(ref.name == branch_name for ref in known_refs):
        # Branch is absent: treat deletion as already done.
        return

    target = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}/refs/{branch_name}"
    response = self._make_request("DELETE", target)
    if response.status_code != 200:
        raise HTTPError(f"Failed to delete branch: {response.text}")
184
+
185
+
186
# Update repository settings with a single PUT request.
def update_repository(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
    description: Optional[str] = None,
    visibility: str = "internal",
    annotations: Optional[Dict[str, str]] = None,
    metadata: Optional[Dict[str, List[str]]] = None,
    base_model: Optional[List[str]] = None,
    relationship: Optional[str] = None,
) -> None:
    """Update repository information via PUT.

    ``name``, ``organization`` and ``type`` are always sent. The optional
    fields (annotations, description, visibility, metadata) are sent only
    when truthy. ``base_model`` / ``relationship`` are nested under
    ``requestgenealogy`` only when at least one of them is truthy.

    NOTE(review): the current repository info is NOT fetched first. Fields
    left as None/empty are simply omitted from the PUT body. ``visibility``
    defaults to "internal", so it is always sent unless an empty value is
    passed explicitly. Confirm with the server API whether omitted fields are
    preserved, and whether callers must pass the current visibility to avoid
    changing it.

    Args:
        organization: Organization name.
        repo_type: Repository type ("models" or "datasets").
        repo_name: Repository name.
        description: New description.
        visibility: Visibility value to send.
        annotations: Repository annotations.
        metadata: Repository metadata.
        base_model: Base model list for the genealogy section.
        relationship: Relationship to the base model.

    Raises:
        HTTPError: If the server does not answer with HTTP 200.
    """
    payload: Dict = {
        "name": repo_name,
        "organization": organization,
        "type": repo_type,
    }
    # Insertion order matches the request layout: annotations, description,
    # visibility, metadata, then requestgenealogy.
    optional_fields = (
        ("annotations", annotations),
        ("description", description),
        ("visibility", visibility),
        ("metadata", metadata),
    )
    payload.update({key: value for key, value in optional_fields if value})

    genealogy = {
        key: value
        for key, value in (("baseModel", base_model), ("relationship", relationship))
        if value
    }
    if genealogy:
        payload["requestgenealogy"] = genealogy

    response = self._make_request(
        "PUT",
        f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}",
        json=payload,
    )
    if response.status_code != 200:
        raise HTTPError(f"Failed to update repository: {response.text}")
114
223
 
224
# Delete a repository.
def delete_repository(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
) -> None:
    """Delete repository ``repo_name`` of type ``repo_type`` in ``organization``.

    Raises:
        HTTPError: If the server does not answer with HTTP 200.
    """
    response = self._make_request(
        "DELETE",
        f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}/{repo_name}",
    )
    if response.status_code == 200:
        return
    raise HTTPError(f"Failed to delete repository: {response.text}")
236
+
237
# Create a repository.
def create_repository(
    self,
    organization: str,
    repo_type: str,
    repo_name: str,
    description: Optional[str] = None,
    visibility: str = "internal",
    metadata: Optional[Dict[str, List[str]]] = None,
    base_model: Optional[List[str]] = None,
    relationship: Optional[str] = None,
    annotations: Optional[Dict[str, str]] = None,
) -> None:
    """Create a new repository with a single POST request.

    Args:
        organization: Organization name.
        repo_type: Repository type ("models" or "datasets").
        repo_name: Repository name.
        description: Repository description; omitted from the request when empty.
        visibility: Visibility ("public", "internal" or "private"); always sent.
        metadata: Metadata such as license, tasks, languages, tags and
            frameworks; omitted when empty.
        base_model: Base model list (e.g. ["demo/yyyy"]); sent as
            ``requestgenealogy.baseModel`` when non-empty.
        relationship: Relationship to the base model ("adapter", "finetune",
            "quantized", "merge", "repackage", ...); sent as
            ``requestgenealogy.relationship`` when non-empty.
        annotations: Repository annotations; omitted when empty. This was
            documented before but not accepted. It is appended as the last
            parameter so existing positional callers are unaffected.

    Returns:
        None. The response body is not parsed.

    Raises:
        HTTPError: If the server does not answer with HTTP 200.
    """
    url = f"{self.base_url}/moha/v1/organizations/{organization}/{repo_type}"

    # Build the request body; optional fields are only included when truthy.
    body: Dict = {
        "name": repo_name,
        "organization": organization,
        "visibility": visibility,
    }
    if description:
        body["description"] = description
    if metadata:
        body["metadata"] = metadata
    if annotations:
        # Same key that update_repository uses for annotations.
        body["annotations"] = annotations

    # Model genealogy (lineage) section, present only if either part is given.
    genealogy: Dict = {}
    if base_model:
        genealogy["baseModel"] = base_model
    if relationship:
        genealogy["relationship"] = relationship
    if genealogy:
        body["requestgenealogy"] = genealogy

    response = self._make_request("POST", url, json=body)
    if response.status_code != 200:
        raise HTTPError(f"Failed to create repository: {response.text}")
291
+
292
+ # 获取仓库信息
115
293
  def get_repository_info(
116
294
  self,
117
295
  organization: str,
@@ -143,15 +321,25 @@ class HubClient:
143
321
  if 'metadata' in data and isinstance(data['metadata'], dict):
144
322
  metadata = data['metadata']
145
323
 
324
+ # Parse genealogy if present
325
+ genealogy: Optional[Dict] = None
326
+ if 'genealogy' in data and isinstance(data['genealogy'], dict):
327
+ genealogy = data['genealogy']
328
+
146
329
  return Repository(
147
330
  name=data.get('name', repo_name),
148
- organization=organization,
149
- type=repo_type,
331
+ organization=data.get('organization', organization),
332
+ owner=data.get('owner', ''),
333
+ creator=data.get('creator', ''),
334
+ type=data.get('type', repo_type),
335
+ visibility=data.get('visibility', 'internal'),
336
+ genealogy=genealogy,
150
337
  description=data.get('description'),
151
338
  metadata=metadata,
152
339
  annotations=annotations,
153
340
  )
154
341
 
342
+ # 获取仓库的所有分支
155
343
  def get_repository_refs(
156
344
  self,
157
345
  organization: str,
@@ -338,12 +526,18 @@ class HubClient:
338
526
  progress_bar.close()
339
527
 
340
528
 
341
- def generate_data_key(self, password: Optional[str] = None) -> DataKey:
529
+ def generate_data_key(
530
+ self,
531
+ algorithm : Optional[str] = DEFAULT_ALGORITHM,
532
+ password: Optional[str] = None) -> DataKey:
342
533
  url = f"{self.base_url}/api/kms/generate-data-key"
343
534
  try:
344
535
  resp = requests.post(
345
536
  url,
346
- json={"password": password},
537
+ json={
538
+ "algorithm": algorithm,
539
+ "password": password,
540
+ },
347
541
  timeout=30
348
542
  )
349
543
  resp.raise_for_status()
@@ -1,20 +1,25 @@
1
1
  """
2
2
  信封加密核心模块
3
3
 
4
- 实现基于 AES-256-CTR 的信封加密/解密逻辑。
4
+ 实现基于 AES-256-CTR 和 SM4-CTR 的信封加密/解密逻辑。
5
5
  支持随机访问和流式解密,适合大文件部分读取场景。
6
6
 
7
7
  信封加密文件格式:
8
8
  - 前 4 字节: 元数据长度 (big-endian)
9
- - 元数据: JSON 格式,包含 encryptedKey 和 password
10
- - 16 字节: IV (用于 AES-256-CTR)
9
+ - 元数据: JSON 格式,包含 encryptedKey 和 algorithm
10
+ - 16 字节: IV (用于 CTR 模式)
11
11
  - 剩余部分: 加密后的文件内容
12
+
13
+ 支持的算法:
14
+ - AES: 使用 32 字节密钥
15
+ - SM4: 使用 16 字节密钥
12
16
  """
13
17
 
14
18
  import io
15
19
  import os
16
20
  import json
17
21
  from dataclasses import dataclass
22
+ from enum import Enum
18
23
  from pathlib import Path
19
24
  import shutil
20
25
  import tempfile
@@ -24,17 +29,29 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
24
29
  from cryptography.hazmat.backends import default_backend
25
30
 
26
31
 
27
# IV length in bytes: AES-256-CTR and SM4-CTR both take a 16-byte IV.
IV_SIZE = 16
# Size in bytes of the big-endian length prefix that precedes the JSON metadata.
METADATA_LENGTH_SIZE = 4
# Default chunk size (64 KiB) for streaming encryption.
CHUNK_SIZE = 1 << 16
32
37
 
38
+
39
class Algorithm(str, Enum):
    """Symmetric cipher used to encrypt the file payload.

    Members subclass ``str``, so ``Algorithm.AES == "AES"`` holds. The
    member's ``.value`` is the name recorded in the envelope metadata.
    """

    AES = "AES"  # AES-256 in CTR mode, 32-byte data key
    SM4 = "SM4"  # SM4 in CTR mode, 16-byte data key


# Algorithm used when callers do not pick one explicitly.
DEFAULT_ALGORITHM = Algorithm.AES
47
+
48
+
33
49
@dataclass
class DataKey:
    """Data key pair: the plaintext key and its encrypted (wrapped) form.

    ``plaintext_key`` is the raw symmetric key (32 bytes for AES-256, 16 bytes
    for SM4). ``encrypted_key`` is the encrypted form of that key, which is
    stored in the envelope file header.
    """

    plaintext_key: bytes  # raw key material: must never be logged
    encrypted_key: str

    def __repr__(self) -> str:
        # The dataclass-generated repr would print the plaintext key verbatim,
        # leaking secret material into logs and tracebacks. @dataclass keeps
        # a __repr__ that the class defines itself, so this one is used.
        return (
            f"{type(self).__name__}(plaintext_key=<redacted>, "
            f"encrypted_key={self.encrypted_key!r})"
        )
54
+
38
55
 
39
56
  def _raw_open(path, mode="rb"):
40
57
  """
@@ -49,26 +66,51 @@ def _raw_open(path, mode="rb"):
49
66
  return file_io
50
67
 
51
68
 
52
- def create_aes_ctr_cipher(key: bytes, iv: bytes):
53
- """创建 AES-256-CTR cipher"""
54
- return Cipher(
55
- algorithms.AES(key),
56
- modes.CTR(iv),
57
- backend=default_backend()
58
- )
69
def create_cipher(key: bytes, iv: bytes, algorithm: Algorithm = DEFAULT_ALGORITHM):
    """Build a CTR-mode ``Cipher`` for the requested algorithm.

    Args:
        key: Raw symmetric key; must be 32 bytes for AES (AES-256) and
            16 bytes for SM4.
        iv: Initialization vector used as the CTR nonce (``IV_SIZE`` bytes).
        algorithm: An ``Algorithm`` member, or its equal string value.

    Returns:
        A ``cryptography`` ``Cipher`` in CTR mode; call ``.encryptor()`` or
        ``.decryptor()`` on it.

    Raises:
        ValueError: If ``algorithm`` is unsupported or ``key`` has the wrong length.
    """
    # Pick the (name, required key length) pair first; the name doubles as
    # the attribute name on ``algorithms``.
    if algorithm == Algorithm.AES:
        cipher_name, key_length = "AES", 32
    elif algorithm == Algorithm.SM4:
        cipher_name, key_length = "SM4", 16
    else:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    if len(key) != key_length:
        raise ValueError(f"{cipher_name} requires {key_length}-byte key, got {len(key)} bytes")

    # Resolve the primitive only after validation, so the key-length check
    # runs before the primitive class is looked up.
    primitive = getattr(algorithms, cipher_name)
    return Cipher(primitive(key), modes.CTR(iv), backend=default_backend())
99
+
100
+
59
101
 
60
102
 
61
103
  @dataclass
62
104
  class EnvelopeMetadata:
63
105
  """加密文件元数据"""
64
106
  encrypted_key: str # 加密后的数据密钥(Base64)
65
- password: str # 用于解密 DEK 的密码
107
+ algorithm: str # 加密算法名称
66
108
 
67
109
  def to_bytes(self) -> bytes:
68
110
  """序列化为 JSON 字节"""
69
111
  return json.dumps({
70
112
  "encryptedKey": self.encrypted_key,
71
- "password": self.password
113
+ "algorithm": self.algorithm
72
114
  }).encode("utf-8")
73
115
 
74
116
  @classmethod
@@ -77,30 +119,30 @@ class EnvelopeMetadata:
77
119
  obj = json.loads(data.decode("utf-8"))
78
120
  return cls(
79
121
  encrypted_key=obj["encryptedKey"],
80
- password=obj["password"]
122
+ algorithm=obj.get("algorithm", Algorithm.AES.value) # 兼容旧格式
81
123
  )
82
124
 
83
125
 
84
126
  def _envelope_encrypt_file(
85
127
  input_path: Union[str, Path],
86
128
  output_path: Union[str, Path],
87
- password: str,
88
129
  data_key: DataKey,
130
+ algorithm: Algorithm = DEFAULT_ALGORITHM,
89
131
  ) -> None:
90
132
  """
91
- 使用信封加密对文件进行加密 (AES-256-CTR)
133
+ 使用信封加密对文件进行加密
92
134
 
93
135
  Args:
94
136
  input_path: 明文文件路径
95
137
  output_path: 加密文件输出路径
96
- password: 用于 KMS 密钥派生的密码
97
- kms_client: KMS 客户端实例(可选,不提供则使用默认客户端)
138
+ data_key: 数据密钥对象
139
+ algorithm: 加密算法(默认 AES)
98
140
  """
99
141
  input_path = Path(input_path)
100
142
  output_path = Path(output_path)
101
143
 
102
144
  iv = os.urandom(IV_SIZE)
103
- cipher = create_aes_ctr_cipher(data_key.plaintext_key, iv)
145
+ cipher = create_cipher(data_key.plaintext_key, iv, algorithm)
104
146
  encryptor = cipher.encryptor()
105
147
  with _raw_open(input_path, "rb") as f:
106
148
  plaintext = f.read()
@@ -108,7 +150,7 @@ def _envelope_encrypt_file(
108
150
  ciphertext = encryptor.update(plaintext) + encryptor.finalize()
109
151
  metadata = EnvelopeMetadata(
110
152
  encrypted_key=data_key.encrypted_key,
111
- password=password
153
+ algorithm=algorithm.value
112
154
  )
113
155
  metadata_bytes = metadata.to_bytes()
114
156
  output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -123,25 +165,26 @@ def _envelope_encrypt_file(
123
165
  def envelope_enc_file(
124
166
  source: Union[Path, str],
125
167
  *,
126
- password: str,
127
168
  data_key: DataKey,
128
169
  dest: Optional[Union[Path, str]] = None,
129
170
  replace: bool = False,
130
171
  chunked: bool = True,
131
- chunk_size: int = CHUNK_SIZE, # 默认 64KB,与 envelope_crypto 一致
172
+ chunk_size: int = CHUNK_SIZE,
173
+ algorithm: Algorithm = DEFAULT_ALGORITHM,
132
174
  ) -> Path:
133
- """使用信封加密模式加密单个文件 (AES-256-CTR)。
175
+ """使用信封加密模式加密单个文件。
134
176
 
135
177
  信封加密使用 KMS 服务生成数据密钥,数据密钥用于加密文件内容,
136
178
  加密后的数据密钥存储在文件头部。
137
179
 
138
180
  Args:
139
181
  source: 源文件路径
140
- password: 用于 KMS 密钥派生的密码
182
+ data_key: 数据密钥对象
141
183
  dest: 目标文件路径(默认添加 .encrypted 后缀)
142
184
  replace: 是否原地加密(替换原文件)
143
185
  chunked: 是否使用流式加密(适用于大文件,减少内存占用)
144
186
  chunk_size: 流式加密时每次处理的块大小
187
+ algorithm: 加密算法(默认 AES,可选 SM4)
145
188
 
146
189
  Returns:
147
190
  加密后的文件路径
@@ -161,9 +204,9 @@ def envelope_enc_file(
161
204
  dst.parent.mkdir(parents=True, exist_ok=True)
162
205
 
163
206
  if chunked:
164
- _envelope_encrypt_file_streaming(src, dst, password, data_key, chunk_size)
207
+ _envelope_encrypt_file_streaming(src, dst, data_key, chunk_size, algorithm)
165
208
  else:
166
- _envelope_encrypt_file(src, dst, password, data_key)
209
+ _envelope_encrypt_file(src, dst, data_key, algorithm)
167
210
 
168
211
  # 替换模式:移动加密文件到原位置
169
212
  if replace:
@@ -185,32 +228,32 @@ def _make_temp_path(parent: Path, suffix: str) -> Path:
185
228
  def _envelope_encrypt_file_streaming(
186
229
  input_path: Union[str, Path],
187
230
  output_path: Union[str, Path],
188
- password: str,
189
231
  data_key: DataKey,
190
232
  chunk_size: int = CHUNK_SIZE,
233
+ algorithm: Algorithm = DEFAULT_ALGORITHM,
191
234
  ) -> None:
192
235
  """
193
- 使用信封加密对大文件进行流式加密 (AES-256-CTR)
236
+ 使用信封加密对大文件进行流式加密
194
237
 
195
- AES-CTR 模式天然支持流式加密,不需要将整个文件加载到内存。
238
+ CTR 模式天然支持流式加密,不需要将整个文件加载到内存。
196
239
  加密后的文件格式与普通模式完全相同,可以用普通模式解密。
197
240
 
198
241
  Args:
199
242
  input_path: 明文文件路径
200
243
  output_path: 加密文件输出路径
201
- password: 用于 KMS 密钥派生的密码
202
- kms_client: KMS 客户端实例
244
+ data_key: 数据密钥对象
203
245
  chunk_size: 每次读取的块大小
246
+ algorithm: 加密算法(默认 AES)
204
247
  """
205
248
  input_path = Path(input_path)
206
249
  output_path = Path(output_path)
207
250
  iv = os.urandom(IV_SIZE)
208
251
  metadata = EnvelopeMetadata(
209
252
  encrypted_key=data_key.encrypted_key,
210
- password=password
253
+ algorithm=algorithm.value
211
254
  )
212
255
  metadata_bytes = metadata.to_bytes()
213
- cipher = create_aes_ctr_cipher(data_key.plaintext_key, iv)
256
+ cipher = create_cipher(data_key.plaintext_key, iv, algorithm)
214
257
  encryptor = cipher.encryptor()
215
258
  output_path.parent.mkdir(parents=True, exist_ok=True)
216
259
  with _raw_open(input_path, "rb") as fin, _raw_open(output_path, "wb") as fout:
xiaoshiai_hub/types.py CHANGED
@@ -12,7 +12,11 @@ class Repository:
12
12
  """Repository information."""
13
13
  name: str
14
14
  organization: str
15
+ owner: str
16
+ creator: str
15
17
  type: str # "models" or "datasets"
18
+ visibility: str
19
+ genealogy: Optional[Dict] = None
16
20
  description: Optional[str] = None
17
21
  metadata: Dict[str, List[str]] = field(default_factory=dict) # Repository metadata
18
22
  annotations: Dict[str, str] = field(default_factory=dict) # Repository annotations/metadata