intellif-aihub 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of intellif-aihub might be problematic. Click here for more details.
- aihub/__init__.py +1 -1
- aihub/client.py +35 -12
- aihub/models/artifact.py +1 -1
- aihub/models/data_warehouse.py +95 -0
- aihub/models/dataset_management.py +99 -61
- aihub/models/document_center.py +26 -18
- aihub/models/eval.py +20 -11
- aihub/models/labelfree.py +12 -38
- aihub/models/model_center.py +141 -0
- aihub/models/model_training_platform.py +183 -149
- aihub/models/quota_schedule_management.py +201 -150
- aihub/models/tag_resource_management.py +30 -24
- aihub/models/task_center.py +39 -36
- aihub/models/user_system.py +159 -125
- aihub/models/workflow_center.py +461 -0
- aihub/services/artifact.py +22 -15
- aihub/services/data_warehouse.py +97 -0
- aihub/services/dataset_management.py +142 -23
- aihub/services/document_center.py +24 -5
- aihub/services/eval.py +14 -7
- aihub/services/labelfree.py +11 -0
- aihub/services/model_center.py +183 -0
- aihub/services/model_training_platform.py +99 -29
- aihub/services/quota_schedule_management.py +104 -7
- aihub/services/tag_resource_management.py +33 -2
- aihub/services/task_center.py +23 -9
- aihub/services/user_system.py +237 -2
- aihub/services/workflow_center.py +522 -0
- aihub/utils/download.py +19 -3
- {intellif_aihub-0.1.4.dist-info → intellif_aihub-0.1.6.dist-info}/METADATA +3 -3
- intellif_aihub-0.1.6.dist-info/RECORD +42 -0
- intellif_aihub-0.1.4.dist-info/RECORD +0 -36
- {intellif_aihub-0.1.4.dist-info → intellif_aihub-0.1.6.dist-info}/WHEEL +0 -0
- {intellif_aihub-0.1.4.dist-info → intellif_aihub-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {intellif_aihub-0.1.4.dist-info → intellif_aihub-0.1.6.dist-info}/top_level.txt +0 -0
|
@@ -1,28 +1,48 @@
|
|
|
1
|
+
# !/usr/bin/env python
|
|
2
|
+
# -*- coding:utf-8 -*-
|
|
3
|
+
"""数据集管理服务模块
|
|
4
|
+
|
|
5
|
+
本模块围绕 **“数据集生命周期管理”** 提供以下能力:
|
|
6
|
+
|
|
7
|
+
- **创建数据集及其版本**(支持本地上传和服务器现有文件两种方式)
|
|
8
|
+
- **上传文件到对象存储**(大文件自动分片)
|
|
9
|
+
- **查询数据集/数据集版本详情**
|
|
10
|
+
- **按版本名称或ID下载数据集文件**
|
|
11
|
+
"""
|
|
12
|
+
|
|
1
13
|
from __future__ import annotations
|
|
2
14
|
|
|
3
15
|
import mimetypes
|
|
4
16
|
import os
|
|
5
17
|
import pathlib
|
|
18
|
+
import tempfile
|
|
19
|
+
import time
|
|
20
|
+
import uuid
|
|
6
21
|
|
|
7
22
|
import httpx
|
|
23
|
+
from loguru import logger
|
|
8
24
|
|
|
9
25
|
from ..exceptions import APIError
|
|
10
26
|
from ..models.common import APIWrapper
|
|
11
|
-
from ..models.dataset_management import
|
|
12
|
-
|
|
27
|
+
from ..models.dataset_management import (
|
|
28
|
+
CreateDatasetRequest,
|
|
29
|
+
CreateDatasetResponse,
|
|
30
|
+
DatasetDetail,
|
|
31
|
+
CreateDatasetVersionRequest,
|
|
32
|
+
CreateDatasetVersionResponse,
|
|
33
|
+
UploadDatasetVersionRequest,
|
|
34
|
+
DatasetVersionDetail,
|
|
35
|
+
UploadDatasetVersionResponse,
|
|
36
|
+
FileUploadData,
|
|
37
|
+
)
|
|
38
|
+
from ..models.dataset_management import DatasetVersionStatus
|
|
39
|
+
from ..utils.download import dataset_download, zip_dir
|
|
13
40
|
|
|
14
41
|
_BASE = "/dataset-mng/api/v2"
|
|
15
42
|
|
|
16
43
|
|
|
17
44
|
class DatasetManagementService:
|
|
18
|
-
"""数据集管理服务,用于数据集的上传、下载
|
|
19
|
-
|
|
20
|
-
Methods:
|
|
21
|
-
create_dataset_and_version: 创建数据集版本
|
|
22
|
-
run_download: 下载
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
"""
|
|
45
|
+
"""数据集管理服务,用于数据集的上传、下载"""
|
|
26
46
|
|
|
27
47
|
def __init__(self, http: httpx.Client):
|
|
28
48
|
self._dataset = _Dataset(http)
|
|
@@ -31,24 +51,80 @@ class DatasetManagementService:
|
|
|
31
51
|
|
|
32
52
|
# 直接把常用方法抛到一级,调用体验简单
|
|
33
53
|
def create_dataset(self, payload: CreateDatasetRequest) -> int:
    """Create a dataset.

    Args:
        payload (CreateDatasetRequest): Information needed to create the
            dataset, e.g. name, description and visibility.

    Returns:
        int: The ``dataset_id`` of the newly created dataset.
    """
    dataset_api = self._dataset
    new_dataset_id = dataset_api.create(payload)
    return new_dataset_id
|
|
35
63
|
|
|
36
64
|
def get_dataset(self, dataset_id: int) -> DatasetDetail:
    """Fetch the details of a dataset.

    Args:
        dataset_id (int): ID of the dataset.

    Returns:
        DatasetDetail: Full dataset information, including the metadata of
            all of its versions.
    """
    dataset_api = self._dataset
    detail = dataset_api.get(dataset_id)
    return detail
|
|
38
74
|
|
|
39
75
|
def create_dataset_version(self, payload: CreateDatasetVersionRequest) -> int:
    """Create a new version of a dataset.

    Args:
        payload (CreateDatasetVersionRequest): Metadata for the version.

    Returns:
        int: The ``version_id`` of the newly created version.
    """
    version_api = self._dataset_version
    new_version_id = version_api.create(payload)
    return new_version_id
|
|
41
85
|
|
|
42
86
|
def upload_dataset_version(self, payload: UploadDatasetVersionRequest) -> int:
    """Create a dataset version from already-uploaded data.

    Args:
        payload (UploadDatasetVersionRequest): Upload request; it must carry
            information such as the OSS path of the previously uploaded
            local file.

    Returns:
        int: The ``version_id`` of the newly created version.
    """
    version_api = self._dataset_version
    new_version_id = version_api.upload(payload)
    return new_version_id
|
|
44
96
|
|
|
45
97
|
def get_dataset_version(self, version_id: int) -> DatasetVersionDetail:
    """Fetch the details of a dataset version by its ID.

    Args:
        version_id (int): ID of the dataset version.

    Returns:
        DatasetVersionDetail: Detailed information about the version.
    """
    version_api = self._dataset_version
    detail = version_api.get(version_id)
    return detail
|
|
47
107
|
|
|
48
108
|
def get_dataset_version_by_name(self, version_name: str) -> DatasetVersionDetail:
    """Fetch the details of a dataset version by its "dataset/version" tag.

    Args:
        version_name (str): Unique identifier of the form
            ``<dataset_name>/V<version>``.

    Returns:
        DatasetVersionDetail: Detailed information about the version.
    """
    version_api = self._dataset_version
    detail = version_api.get_by_name(version_name)
    return detail
|
|
50
118
|
|
|
51
119
|
def upload_file(self, file_path: str) -> FileUploadData:
    """Upload a local file to object storage.

    Args:
        file_path (str): Absolute path of the local file.

    Returns:
        FileUploadData: Upload result (OSS path, download URL, etc.).
    """
    uploader = self._upload
    result = uploader.upload_file(file_path)
    return result
|
|
53
129
|
|
|
54
130
|
# 如果想要访问子对象,也保留属性
|
|
@@ -61,16 +137,17 @@ class DatasetManagementService:
|
|
|
61
137
|
return self._dataset_version
|
|
62
138
|
|
|
63
139
|
def create_dataset_and_version(
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
140
|
+
self,
|
|
141
|
+
*,
|
|
142
|
+
dataset_name: str,
|
|
143
|
+
dataset_description: str = "",
|
|
144
|
+
is_local_upload: bool,
|
|
145
|
+
local_file_path: str | None = None,
|
|
146
|
+
server_file_path: str | None = None,
|
|
147
|
+
version_description: str = "",
|
|
148
|
+
timeout: int = 1_800,
|
|
72
149
|
) -> tuple[int, int, str]:
|
|
73
|
-
"""
|
|
150
|
+
"""创建数据集及其版本,并等待版本状态变为 *Success*。
|
|
74
151
|
|
|
75
152
|
根据参数创建数据集,并根据上传类型(本地或服务器路径)创建对应的数据集版本。
|
|
76
153
|
|
|
@@ -82,6 +159,7 @@ class DatasetManagementService:
|
|
|
82
159
|
local_file_path: 本地文件路径,当 is_local_upload=True 时必须提供。
|
|
83
160
|
server_file_path: 服务器已有文件路径,当 is_local_upload=False 时必须提供。
|
|
84
161
|
version_description: 版本描述,默认为空。
|
|
162
|
+
timeout: 最大等待秒数(默认1800s)。超过后仍未成功则引发 ``TimeoutError``。
|
|
85
163
|
|
|
86
164
|
Returns:
|
|
87
165
|
tuple[int, int, str]: 一个三元组,包含:[数据集 ID,数据集版本 ID, 数据集版本标签(格式为 <dataset_name>/V<version_number>)]
|
|
@@ -100,14 +178,34 @@ class DatasetManagementService:
|
|
|
100
178
|
name=dataset_name,
|
|
101
179
|
description=dataset_description,
|
|
102
180
|
tags=[],
|
|
181
|
+
cover_img=None,
|
|
182
|
+
create_by=None,
|
|
183
|
+
is_private=None,
|
|
184
|
+
access_user_ids=None,
|
|
103
185
|
)
|
|
104
186
|
)
|
|
187
|
+
logger.info(
|
|
188
|
+
f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据"
|
|
189
|
+
)
|
|
105
190
|
|
|
106
191
|
if is_local_upload:
|
|
107
|
-
|
|
192
|
+
# 上传文件,检查是文件夹还是zip
|
|
193
|
+
local_file_path = pathlib.Path(local_file_path)
|
|
194
|
+
if local_file_path.is_dir():
|
|
195
|
+
# 把文件夹打包为一个 zip
|
|
196
|
+
temp_zip_path = (
|
|
197
|
+
pathlib.Path(tempfile.mkdtemp()) / f" {uuid.uuid4().hex}.zip"
|
|
198
|
+
)
|
|
199
|
+
zip_dir(local_file_path, temp_zip_path)
|
|
200
|
+
upload_data = self._upload.upload_file(temp_zip_path)
|
|
201
|
+
os.remove(temp_zip_path)
|
|
202
|
+
else:
|
|
203
|
+
upload_data = self._upload.upload_file(local_file_path)
|
|
204
|
+
|
|
108
205
|
upload_path = upload_data.path
|
|
109
206
|
else:
|
|
110
207
|
upload_path = server_file_path
|
|
208
|
+
logger.info(f"文件上传成功:{local_file_path}")
|
|
111
209
|
|
|
112
210
|
version_id = self._dataset_version.upload(
|
|
113
211
|
UploadDatasetVersionRequest(
|
|
@@ -115,6 +213,7 @@ class DatasetManagementService:
|
|
|
115
213
|
upload_type=upload_type,
|
|
116
214
|
dataset_id=dataset_id,
|
|
117
215
|
description=version_description,
|
|
216
|
+
parent_version_id=0,
|
|
118
217
|
)
|
|
119
218
|
)
|
|
120
219
|
|
|
@@ -127,12 +226,32 @@ class DatasetManagementService:
|
|
|
127
226
|
ver_num = 1
|
|
128
227
|
|
|
129
228
|
version_tag = f"{detail.name}/V{ver_num}"
|
|
229
|
+
logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
|
|
230
|
+
|
|
231
|
+
start_ts = time.time()
|
|
232
|
+
poll_interval = 10
|
|
233
|
+
|
|
234
|
+
while True:
|
|
235
|
+
ver_detail = self._dataset_version.get(version_id)
|
|
236
|
+
status = ver_detail.status
|
|
237
|
+
|
|
238
|
+
if status == DatasetVersionStatus.Success:
|
|
239
|
+
logger.info("版本状态已成功")
|
|
240
|
+
break
|
|
241
|
+
|
|
242
|
+
if status == DatasetVersionStatus.Fail:
|
|
243
|
+
raise APIError(f"版本构建失败:{ver_detail.message}")
|
|
244
|
+
|
|
245
|
+
elapsed = time.time() - start_ts
|
|
246
|
+
if elapsed > timeout:
|
|
247
|
+
raise TimeoutError(f"等待版本成功超时({timeout}s),当前状态:{status}")
|
|
248
|
+
|
|
249
|
+
logger.debug(f"已等待 {elapsed:.0f}s,继续轮询…")
|
|
250
|
+
time.sleep(poll_interval)
|
|
130
251
|
|
|
131
252
|
return dataset_id, version_id, version_tag
|
|
132
253
|
|
|
133
|
-
def run_download(
|
|
134
|
-
self, dataset_version_name: str, local_dir: str, worker: int = 4
|
|
135
|
-
) -> None:
|
|
254
|
+
def run_download(self, dataset_version_name: str, local_dir: str, worker: int = 4) -> None:
|
|
136
255
|
"""根据数据集版本名称下载对应的数据集文件。
|
|
137
256
|
|
|
138
257
|
Args:
|
|
@@ -1,23 +1,42 @@
|
|
|
1
1
|
# !/usr/bin/env python
|
|
2
|
-
# -*-coding:utf-8 -*-
|
|
2
|
+
# -*- coding:utf-8 -*-
|
|
3
|
+
"""文档中心服务模块
|
|
4
|
+
|
|
5
|
+
本模块围绕 **“文档检索”** 提供以下能力:
|
|
6
|
+
|
|
7
|
+
- **分页查询文档列表**
|
|
8
|
+
"""
|
|
9
|
+
|
|
3
10
|
from __future__ import annotations
|
|
4
11
|
|
|
12
|
+
from typing import List
|
|
13
|
+
|
|
5
14
|
import httpx
|
|
6
15
|
|
|
7
16
|
from ..exceptions import APIError
|
|
8
17
|
from ..models.common import APIWrapper
|
|
9
|
-
from ..models.document_center import
|
|
18
|
+
from ..models.document_center import Document, GetDocumentsResponse
|
|
10
19
|
|
|
11
20
|
_BASE = "/document-center/api/v1"
|
|
12
21
|
|
|
13
22
|
|
|
14
23
|
class DocumentCenterService:
|
|
24
|
+
"""文档中心服务封装"""
|
|
25
|
+
|
|
15
26
|
def __init__(self, http: httpx.Client):
    """Build the service around a shared HTTP client.

    Args:
        http: Client handed to the private ``_Document`` helper.
    """
    document_api = _Document(http)
    self._document = document_api
|
|
17
28
|
|
|
18
|
-
def get_documents(self, page_size: int = 9999, page_num: int = 1, name: str = "") -> List[Document]:
    """Query documents page by page.

    Args:
        page_size: Number of entries per page (default 9999).
        page_num: Page number to fetch (default 1, the first page).
        name: Optional name filter (default empty, i.e. no filter).

    Returns:
        List[Document]: The matching document objects.
    """
    document_api = self._document
    # Arguments are forwarded positionally, in the same order as received.
    return document_api.get_documents(page_size, page_num, name)
|
|
22
41
|
|
|
23
42
|
@property
|
aihub/services/eval.py
CHANGED
|
@@ -1,10 +1,17 @@
|
|
|
1
1
|
# !/usr/bin/env python
|
|
2
|
-
# -*-coding:utf-8 -*-
|
|
2
|
+
# -*- coding:utf-8 -*-
|
|
3
|
+
"""评测平台服务模块
|
|
4
|
+
|
|
5
|
+
本模块围绕 **“模型评测(Run → Report)”** 提供能力:
|
|
6
|
+
|
|
7
|
+
- **创建评测任务 / 评测报告**
|
|
8
|
+
"""
|
|
9
|
+
|
|
3
10
|
import httpx
|
|
4
11
|
|
|
5
12
|
from ..exceptions import APIError
|
|
6
13
|
from ..models.common import APIWrapper
|
|
7
|
-
from ..models.eval import
|
|
14
|
+
from ..models.eval import CreateEvalReq, CreateEvalResp
|
|
8
15
|
|
|
9
16
|
_BASE = "/eval-platform/api/v1"
|
|
10
17
|
|
|
@@ -44,7 +51,7 @@ class EvalService:
|
|
|
44
51
|
dataset_version = dataset_service.get_dataset_version_by_name(
|
|
45
52
|
dataset_version_name
|
|
46
53
|
)
|
|
47
|
-
payload =
|
|
54
|
+
payload = CreateEvalReq(
|
|
48
55
|
dataset_id=dataset_version.dataset_id,
|
|
49
56
|
dataset_version_id=dataset_version.id,
|
|
50
57
|
evaled_artifact_path=evaled_artifact_path,
|
|
@@ -60,9 +67,9 @@ class _Eval:
|
|
|
60
67
|
def __init__(self, http: httpx.Client):
    """Store the HTTP client used to call the evaluation endpoints.

    Args:
        http: Client kept as ``self._http``.
    """
    client = http
    self._http = client
|
|
62
69
|
|
|
63
|
-
def create(self, payload: CreateEvalReq) -> int:
    """Submit an evaluation run and return the ID of the created run.

    Args:
        payload: Evaluation request; serialized with ``model_dump()`` and
            posted as the JSON body to ``{_BASE}/run/``.

    Returns:
        int: The ``eval_run.id`` field of the backend response data.

    Raises:
        APIError: If the backend envelope reports a non-zero ``code``.
    """
    response = self._http.post(f"{_BASE}/run/", json=payload.model_dump())
    envelope = APIWrapper[CreateEvalResp].model_validate(response.json())
    if envelope.code == 0:
        return envelope.data.eval_run.id
    raise APIError(f"backend code {envelope.code}: {envelope.msg}")
|
aihub/services/labelfree.py
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
1
|
+
# !/usr/bin/env python
|
|
2
|
+
# -*- coding:utf-8 -*-
|
|
3
|
+
"""标注服务模块
|
|
4
|
+
|
|
5
|
+
本模块用于对接 **标注平台**,提供以下能力:
|
|
6
|
+
|
|
7
|
+
- **获取指定项目的整体标注 / 审核完成度等统计信息**
|
|
8
|
+
"""
|
|
9
|
+
|
|
1
10
|
from __future__ import annotations
|
|
2
11
|
|
|
3
12
|
import httpx
|
|
@@ -10,6 +19,8 @@ _BASE = "/labelfree/api/v2"
|
|
|
10
19
|
|
|
11
20
|
|
|
12
21
|
class LabelfreeService:
|
|
22
|
+
"""标注服务"""
|
|
23
|
+
|
|
13
24
|
def __init__(self, http: httpx.Client):
    """Build the service around a shared HTTP client.

    Args:
        http: Client handed to the private ``_Project`` helper.
    """
    project_api = _Project(http)
    self._project = project_api
|
|
15
26
|
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
# !/usr/bin/env python
|
|
2
|
+
# -*- coding:utf-8 -*-
|
|
3
|
+
"""模型中心服务模块
|
|
4
|
+
|
|
5
|
+
封装与 **Model‑Center** 后端交互的常用能力,主要涉及模型的增、删、改、查,以及模型元数据(类型 / 部署平台 / 量化等级)的查询功能:
|
|
6
|
+
|
|
7
|
+
- **分页查询模型列表**
|
|
8
|
+
- **获取单个模型详情**
|
|
9
|
+
- **新建模型**
|
|
10
|
+
- **编辑模型**
|
|
11
|
+
- **删除模型**
|
|
12
|
+
- **查询模型类型下拉**
|
|
13
|
+
- **查询部署平台下拉**
|
|
14
|
+
- **查询量化等级下拉**
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import httpx
|
|
20
|
+
|
|
21
|
+
from ..exceptions import APIError
|
|
22
|
+
from ..models.common import APIWrapper
|
|
23
|
+
from ..models.model_center import (
|
|
24
|
+
ListModelsRequest,
|
|
25
|
+
ListModelsResponse,
|
|
26
|
+
ListModelTypesRequest,
|
|
27
|
+
ListModelTypesResponse,
|
|
28
|
+
ListDeployPlatformsRequest,
|
|
29
|
+
ListDeployPlatformsResponse,
|
|
30
|
+
ListQuantLevelsRequest,
|
|
31
|
+
ListQuantLevelsResponse,
|
|
32
|
+
CreateModelRequest,
|
|
33
|
+
CreateModelResponse,
|
|
34
|
+
EditModelRequest,
|
|
35
|
+
Model,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
_BASE = "/model-center/api/v1"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelCenterService:
    """Facade over the Model-Center backend.

    Every public method forwards to the private ``_Model`` helper stored in
    ``self._model``. Apart from building default request objects for the
    lookup endpoints, the facade adds no logic of its own.
    """

    def __init__(self, http: httpx.Client):
        self._model = _Model(http)

    def list_models(self, payload: ListModelsRequest) -> ListModelsResponse:
        """Query one page of models.

        Args:
            payload: Query parameters such as pagination and name filtering.

        Returns:
            ListModelsResponse: Pagination info together with the model data.
        """
        backend = self._model
        return backend.list(payload)

    def get_model(self, model_id: int) -> Model:
        """Fetch the full record of a single model.

        Args:
            model_id: ID of the model.

        Returns:
            Model: The model's complete information.
        """
        backend = self._model
        return backend.get(model_id)

    def create_model(self, payload: CreateModelRequest) -> int:
        """Create a model.

        Args:
            payload: Fields required to create the model.

        Returns:
            int: ID the backend assigned to the new model.
        """
        backend = self._model
        new_id = backend.create(payload)
        return new_id

    def edit_model(self, payload: EditModelRequest) -> None:
        """Update an existing model.

        Args:
            payload: Fields to update; ``payload.id`` selects the target model.
        """
        backend = self._model
        backend.edit(payload)

    def delete_model(self, model_id: int) -> None:
        """Delete a model.

        Args:
            model_id: ID of the model to delete.
        """
        backend = self._model
        backend.delete(model_id)

    def list_model_types(self) -> ListModelTypesResponse:
        """List the available model types.

        Returns:
            ListModelTypesResponse: The collection of model types.
        """
        request = ListModelTypesRequest()
        return self._model.list_types(request)

    def list_deploy_platforms(self) -> ListDeployPlatformsResponse:
        """List the available deployment platforms.

        Returns:
            ListDeployPlatformsResponse: The collection of deployment platforms.
        """
        request = ListDeployPlatformsRequest()
        return self._model.list_platforms(request)

    def list_quant_levels(self) -> ListQuantLevelsResponse:
        """List the available quantization levels.

        Returns:
            ListQuantLevelsResponse: The collection of quantization levels.
        """
        request = ListQuantLevelsRequest()
        return self._model.list_quant_levels(request)

    @property
    def model(self) -> _Model:
        """Direct access to the underlying ``_Model`` helper."""
        return self._model
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class _Model:
    """Low-level wrapper around the Model-Center REST endpoints under ``_BASE``.

    Every endpoint returns the standard ``APIWrapper`` envelope. A non-zero
    ``code`` is turned into :class:`APIError`; otherwise the envelope's
    ``data`` is returned, or ignored for ``edit`` and ``delete``.
    """

    def __init__(self, http: httpx.Client):
        self._http = http

    @staticmethod
    def _unwrap(resp: httpx.Response, data_type: type):
        """Validate the response envelope and return its ``data`` payload.

        Args:
            resp: Raw HTTP response whose JSON body is an ``APIWrapper``.
            data_type: Type used to parametrize ``APIWrapper`` for validation.

        Raises:
            APIError: If the backend reports a non-zero ``code``.
        """
        wrapper = APIWrapper[data_type].model_validate(resp.json())
        if wrapper.code != 0:
            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
        return wrapper.data

    @staticmethod
    def _query(payload) -> dict:
        """Serialize a request model into query params, dropping ``None`` fields.

        httpx renders a ``None`` query value as an empty string (``?key=``),
        so unset optional fields are omitted rather than sent empty. ``list``
        already did this; the lookup endpoints now behave the same way.
        """
        return payload.model_dump(by_alias=True, exclude_none=True)

    def list(self, payload: ListModelsRequest) -> ListModelsResponse:
        """GET ``/models`` and return one page of models."""
        resp = self._http.get(f"{_BASE}/models", params=self._query(payload))
        return self._unwrap(resp, ListModelsResponse)

    def get(self, model_id: int) -> Model:
        """GET ``/models/{model_id}`` and return the model record."""
        resp = self._http.get(f"{_BASE}/models/{model_id}")
        return self._unwrap(resp, Model)

    def create(self, payload: CreateModelRequest) -> int:
        """POST ``/models`` and return the ID of the created model."""
        resp = self._http.post(f"{_BASE}/models", json=payload.model_dump(by_alias=True, exclude_none=True))
        return self._unwrap(resp, CreateModelResponse).id

    def edit(self, payload: EditModelRequest) -> None:
        """PUT ``/models/{payload.id}`` with the non-``None`` fields of *payload*."""
        resp = self._http.put(
            f"{_BASE}/models/{payload.id}",
            json=payload.model_dump(by_alias=True, exclude_none=True),
        )
        # Only the error check matters; the response data is discarded.
        self._unwrap(resp, dict)

    def delete(self, model_id: int) -> None:
        """DELETE ``/models/{model_id}``."""
        resp = self._http.delete(f"{_BASE}/models/{model_id}")
        # Only the error check matters; the response data is discarded.
        self._unwrap(resp, dict)

    def list_types(self, payload: ListModelTypesRequest) -> ListModelTypesResponse:
        """GET ``/model-types`` and return the available model types."""
        resp = self._http.get(f"{_BASE}/model-types", params=self._query(payload))
        return self._unwrap(resp, ListModelTypesResponse)

    def list_platforms(self, payload: ListDeployPlatformsRequest) -> ListDeployPlatformsResponse:
        """GET ``/deploy-platforms`` and return the available deployment platforms."""
        resp = self._http.get(f"{_BASE}/deploy-platforms", params=self._query(payload))
        return self._unwrap(resp, ListDeployPlatformsResponse)

    def list_quant_levels(self, payload: ListQuantLevelsRequest) -> ListQuantLevelsResponse:
        """GET ``/quant-levels`` and return the available quantization levels."""
        resp = self._http.get(f"{_BASE}/quant-levels", params=self._query(payload))
        return self._unwrap(resp, ListQuantLevelsResponse)
|