intellif-aihub 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of intellif-aihub might be problematic. Click here for more details.
- aihub/__init__.py +1 -1
- aihub/client.py +19 -11
- aihub/models/eval.py +17 -0
- aihub/models/labelfree.py +35 -13
- aihub/models/model_training_platform.py +234 -0
- aihub/models/quota_schedule_management.py +239 -0
- aihub/models/tag_resource_management.py +50 -0
- aihub/models/task_center.py +10 -10
- aihub/models/user_system.py +262 -0
- aihub/services/artifact.py +29 -8
- aihub/services/dataset_management.py +1 -1
- aihub/services/eval.py +68 -0
- aihub/services/labelfree.py +1 -1
- aihub/services/model_training_platform.py +209 -0
- aihub/services/quota_schedule_management.py +164 -0
- aihub/services/tag_resource_management.py +55 -0
- aihub/services/task_center.py +16 -16
- aihub/services/user_system.py +339 -0
- {intellif_aihub-0.1.2.dist-info → intellif_aihub-0.1.4.dist-info}/METADATA +2 -2
- intellif_aihub-0.1.4.dist-info/RECORD +36 -0
- aihub/models/tag_management.py +0 -21
- aihub/models/user.py +0 -46
- aihub/services/tag_management.py +0 -35
- aihub/services/user.py +0 -47
- intellif_aihub-0.1.2.dist-info/RECORD +0 -32
- {intellif_aihub-0.1.2.dist-info → intellif_aihub-0.1.4.dist-info}/WHEEL +0 -0
- {intellif_aihub-0.1.2.dist-info → intellif_aihub-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {intellif_aihub-0.1.2.dist-info → intellif_aihub-0.1.4.dist-info}/top_level.txt +0 -0
aihub/models/task_center.py
CHANGED
|
@@ -20,18 +20,18 @@ class TaskCenterPriorityEnum(Enum):
|
|
|
20
20
|
class LabelProjectTypeEnum(Enum):
|
|
21
21
|
"""
|
|
22
22
|
任务类型枚举
|
|
23
|
+
1 - 目标检测 2 - 语义分割 3 - 图片分类 4 - 实例分割 5 - 视频标注 6 - 人类偏好文本标注 7- 敏感预料文本标注 8 - 文本标注 9 - 关键点标注
|
|
23
24
|
"""
|
|
24
25
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
IMAGE_RESTORATION = 10
|
|
26
|
+
OBJECT_DETECTION = 1
|
|
27
|
+
SEGMENTATION = 2
|
|
28
|
+
IMAGE_CLASSIFICATION = 3
|
|
29
|
+
INSTANCE_SEGMENTATION = 4
|
|
30
|
+
VIDEO_LABELING = 5
|
|
31
|
+
HUMAN_PREFERENCE_TEXT_LABELING = 6
|
|
32
|
+
SENSITIVE_TEXT_LABELING = 7
|
|
33
|
+
TEXT_LABELING = 8
|
|
34
|
+
KEYPOINT_LABELING = 9
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class CreateTaskOtherInfo(BaseModel):
|
|
aihub/models/user_system.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# !/usr/bin/env python
|
|
2
|
+
# -*-coding:utf-8 -*-
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ======================================================================
|
|
11
|
+
# COMMON
|
|
12
|
+
# ======================================================================
|
|
13
|
+
|
|
14
|
+
class Role(BaseModel):
    """Role record: id, name, role type and the (optional) list of menu ids."""

    id: int
    name: str
    role_type: int
    # May be absent from a response; defaults to None.
    menu_ids: Optional[List[int]] = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Menu(BaseModel):
    """Flat menu record (no children), as used in paginated search results."""

    id: int
    name: str
    parent: int  # id of the parent menu
    auth: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TreeMenu(BaseModel):
    """Menu node that may nest child menus and carry the roles attached to it.

    The self-reference in ``children`` is resolved by the rebuild call at the
    end of this module.
    """

    id: int
    name: str
    parent: int
    auth: str
    children: Optional[List[TreeMenu]] = None
    roles: Optional[List[Role]] = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TagBrief(BaseModel):
    """Minimal tag representation (id and name) embedded in user records."""

    id: int
    name: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ======================================================================
|
|
43
|
+
# ------------------------------- AUTH ---------------------------------
|
|
44
|
+
# ======================================================================
|
|
45
|
+
|
|
46
|
+
class LoginRequest(BaseModel):
    """Credentials sent to log in."""

    username: str
    password: str
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class LoginResponse(BaseModel):
    """Login result: the user id and an access token."""

    id: int
    token: str
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class SignupRequest(BaseModel):
    """Sign-up payload; every field, including ``role_ids``, is required."""

    username: str
    password: str
    nickname: str
    email: str
    role_ids: List[int]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SignupResponse(BaseModel):
    """Sign-up result carrying the new user's id."""

    id: int
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ======================================================================
|
|
69
|
+
# ------------------------------- MENU ---------------------------------
|
|
70
|
+
# ======================================================================
|
|
71
|
+
class ListMenusRequest(BaseModel):
    """Options for listing the menu tree."""

    # Optional flag; left as None when not set by the caller.
    need_roles: Optional[bool] = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ListMenusResponse(BaseModel):
    """Menu tree returned by the list-menus call.

    ``menus`` defaults to None, so the annotation must allow None. Otherwise an
    explicit ``"menus": null`` from the backend fails validation even though a
    missing key is accepted.
    """

    menus: Optional[List[TreeMenu]] = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class CreateMenuRequest(BaseModel):
    """Payload for creating a menu, optionally binding it to roles."""

    name: str
    parent: int
    auth: str
    role_ids: Optional[List[int]] = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class CreateMenuResponse(BaseModel):
    """Result of menu creation: the new menu's id."""

    id: int
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class UpdateMenuRequest(BaseModel):
    """Partial update of a menu; fields left as None are not set by the caller."""

    name: Optional[str] = None
    parent: Optional[int] = None
    # NOTE(review): `auth` is required even though every other field is
    # optional — confirm the backend really needs it on every update.
    auth: str
    role_ids: Optional[List[int]] = None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class GetMenuRolesResponse(BaseModel):
    """Role ids bound to a menu."""

    role_ids: List[int]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class SetMenuRolesRequest(BaseModel):
    """Role ids to bind to a menu."""

    role_ids: List[int]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class SearchMenusRequest(BaseModel):
    """Paginated menu search; every filter is optional."""

    page_size: int = 20
    page_num: int = 1
    name: Optional[str] = None
    parent_ids: Optional[List[int]] = None
    auth: Optional[str] = None
    menu_ids: Optional[List[int]] = None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class SearchMenusResponse(BaseModel):
    """One page of menu search results plus paging info."""

    total: int
    page_size: int
    page_num: int
    data: List[Menu]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ======================================================================
|
|
122
|
+
# ------------------------------- ROLE ---------------------------------
|
|
123
|
+
# ======================================================================
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class CreateRoleRequest(BaseModel):
    """Payload for creating a role; ``id`` and ``menu_ids`` are optional."""

    id: Optional[int] = None
    name: str
    role_type: int
    menu_ids: Optional[List[int]] = None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CreateRoleResponse(BaseModel):
    """Result of role creation: the new role's id."""

    id: int
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class UpdateRoleRequest(BaseModel):
    """Partial update of a role; all fields are optional."""

    name: Optional[str] = None
    role_type: Optional[int] = None
    menu_ids: Optional[List[int]] = None
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class GetRoleMenusResponse(BaseModel):
    """Menu ids granted to a role."""

    menu_ids: List[int]
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class SetRoleMenusRequest(BaseModel):
    """Menu ids to grant to a role."""

    menu_ids: List[int]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class ListRolesRequest(BaseModel):
    """Paginated role listing, optionally filtered by role type."""

    page_size: int = 20
    page_num: int = 1
    role_type: Optional[int] = None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class ListRolesResponse(BaseModel):
    """One page of roles plus paging info."""

    total: int
    page_size: int
    page_num: int
    data: List[Role]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class SearchRolesRequest(BaseModel):
    """Paginated role search; every filter is optional."""

    page_size: int = 20
    page_num: int = 1
    name: Optional[str] = None
    role_ids: Optional[List[int]] = None
    menu_ids: Optional[List[int]] = None
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class SearchRolesResponse(BaseModel):
    """One page of role search results plus paging info."""

    total: int
    page_size: int
    page_num: int
    data: List[Role]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# ======================================================================
|
|
180
|
+
# ------------------------------- USER ---------------------------------
|
|
181
|
+
# ======================================================================
|
|
182
|
+
|
|
183
|
+
class User(BaseModel):
    """User record with optional embedded roles and tags."""

    id: int
    username: str
    nickname: str
    email: str
    roles: Optional[List[Role]] = None
    status: int
    tags: Optional[List[TagBrief]] = None
    # NOTE(review): integer timestamp; unit (seconds vs. ms) not visible here.
    created_at: int
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class ListUsersRequest(BaseModel):
    """Paginated user listing with an optional free-text search key."""

    page_size: int = 20
    page_num: int = 1
    search_key: Optional[str] = None
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class ListUsersResponse(BaseModel):
    """One page of users plus paging info."""

    total: int
    page_size: int
    page_num: int
    data: List[User]
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class CreateUserRequest(BaseModel):
    """Payload for creating a user."""

    # NOTE(review): `id` is required here, whereas CreateRoleRequest.id is
    # optional — confirm the backend expects the caller to choose the id.
    id: int
    username: str
    password: str
    nickname: str
    email: str
    role_ids: Optional[List[int]] = None
    created_at: Optional[int] = None
    updated_at: Optional[int] = None
    status: Optional[int] = None
    tag_ids: Optional[List[int]] = None
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class CreateUserResponse(BaseModel):
    """Result of user creation: the new user's id."""

    id: int
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class UpdateUserRequest(BaseModel):
    """Partial update of a user."""

    username: Optional[str] = None
    nickname: Optional[str] = None
    email: Optional[str] = None
    password: Optional[str] = None
    # NOTE(review): unlike the scalar fields, these default to [] rather than
    # None, so `exclude_none` does not drop them and an update that omits them
    # still sends empty lists — confirm the backend does not treat that as
    # "clear all roles/tags".
    role_ids: Optional[List[int]] = Field(default_factory=list)
    status: Optional[int] = None
    tag_ids: Optional[List[int]] = Field(default_factory=list)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class SetUserRolesRequest(BaseModel):
    """Role ids to assign to a user."""

    role_ids: List[int]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class GetUserMenusResponse(BaseModel):
    """Menu tree visible to a user (required, non-null list)."""

    menus: List[TreeMenu]
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class SearchUsersRequest(BaseModel):
    """Paginated user search; every filter is optional."""

    page_size: int = 20
    page_num: int = 1
    username: Optional[str] = None
    nickname: Optional[str] = None
    email: Optional[str] = None
    user_ids: Optional[List[int]] = None
    role_ids: Optional[List[int]] = None
    role_names: Optional[List[str]] = None
    status: Optional[int] = None
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class SearchUsersResponse(BaseModel):
    """One page of user search results plus paging info."""

    total: int
    page_size: int
    page_num: int
    data: List[User]
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
# Must stay at the end of the module: TreeMenu refers to itself (and to Role)
# through postponed annotations, which can only be resolved once every model
# above is defined; otherwise (de)serialization fails.
# `update_forward_refs()` is the deprecated pydantic v1 name; v2 (used by the
# services via model_dump/model_validate) spells it `model_rebuild()`.
TreeMenu.model_rebuild()
|
aihub/services/artifact.py
CHANGED
|
@@ -52,14 +52,8 @@ class ArtifactService:
|
|
|
52
52
|
"""
|
|
53
53
|
self._http = http
|
|
54
54
|
self._Artifact = _Artifact(http)
|
|
55
|
-
self.sts =
|
|
56
|
-
self.s3_client =
|
|
57
|
-
self.sts.endpoint,
|
|
58
|
-
access_key=self.sts.access_key_id,
|
|
59
|
-
secret_key=self.sts.secret_access_key,
|
|
60
|
-
secure=False,
|
|
61
|
-
session_token=self.sts.session_token,
|
|
62
|
-
)
|
|
55
|
+
self.sts = None
|
|
56
|
+
self.s3_client = None
|
|
63
57
|
|
|
64
58
|
@property
|
|
65
59
|
def _artifact(self) -> _Artifact:
|
|
@@ -142,6 +136,15 @@ class ArtifactService:
|
|
|
142
136
|
APIError: 当API调用失败时抛出
|
|
143
137
|
"""
|
|
144
138
|
logger.info(f"log artifact: {artifact_path},local path: {local_path} ")
|
|
139
|
+
if self.s3_client is None:
|
|
140
|
+
self.sts = self._get_sts()
|
|
141
|
+
self.s3_client = minio.Minio(
|
|
142
|
+
self.sts.endpoint,
|
|
143
|
+
access_key=self.sts.access_key_id,
|
|
144
|
+
secret_key=self.sts.secret_access_key,
|
|
145
|
+
session_token=self.sts.session_token,
|
|
146
|
+
secure=False,
|
|
147
|
+
)
|
|
145
148
|
|
|
146
149
|
# 检查文件是否存在
|
|
147
150
|
if not os.path.exists(local_path):
|
|
@@ -187,6 +190,15 @@ class ArtifactService:
|
|
|
187
190
|
ValueError: 当本地目录不存在时抛出
|
|
188
191
|
APIError: 当API调用失败时抛出
|
|
189
192
|
"""
|
|
193
|
+
if self.s3_client is None:
|
|
194
|
+
self.sts = self._get_sts()
|
|
195
|
+
self.s3_client = minio.Minio(
|
|
196
|
+
self.sts.endpoint,
|
|
197
|
+
access_key=self.sts.access_key_id,
|
|
198
|
+
secret_key=self.sts.secret_access_key,
|
|
199
|
+
session_token=self.sts.session_token,
|
|
200
|
+
secure=False,
|
|
201
|
+
)
|
|
190
202
|
|
|
191
203
|
logger.info(f"log artifact: {artifact_path},local path: {local_dir} ")
|
|
192
204
|
if not os.path.exists(local_dir):
|
|
@@ -217,6 +229,15 @@ class ArtifactService:
|
|
|
217
229
|
Raises:
|
|
218
230
|
APIError: 当API调用失败时抛出
|
|
219
231
|
"""
|
|
232
|
+
if self.s3_client is None:
|
|
233
|
+
self.sts = self._get_sts()
|
|
234
|
+
self.s3_client = minio.Minio(
|
|
235
|
+
self.sts.endpoint,
|
|
236
|
+
access_key=self.sts.access_key_id,
|
|
237
|
+
secret_key=self.sts.secret_access_key,
|
|
238
|
+
session_token=self.sts.session_token,
|
|
239
|
+
secure=False,
|
|
240
|
+
)
|
|
220
241
|
artifacts = self.get_by_run_id(run_id, artifact_path)
|
|
221
242
|
|
|
222
243
|
for artifact_item in artifacts:
|
|
aihub/services/dataset_management.py
CHANGED
|
@@ -131,7 +131,7 @@ class DatasetManagementService:
|
|
|
131
131
|
return dataset_id, version_id, version_tag
|
|
132
132
|
|
|
133
133
|
def run_download(
|
|
134
|
-
self,
|
|
134
|
+
self, dataset_version_name: str, local_dir: str, worker: int = 4
|
|
135
135
|
) -> None:
|
|
136
136
|
"""根据数据集版本名称下载对应的数据集文件。
|
|
137
137
|
|
aihub/services/eval.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# !/usr/bin/env python
|
|
2
|
+
# -*-coding:utf-8 -*-
|
|
3
|
+
import httpx
|
|
4
|
+
|
|
5
|
+
from ..exceptions import APIError
|
|
6
|
+
from ..models.common import APIWrapper
|
|
7
|
+
from ..models.eval import CreatEvalReq, CreatEvalResp
|
|
8
|
+
|
|
9
|
+
_BASE = "/eval-platform/api/v1"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class EvalService:
    """Evaluation service: creates evaluation reports on the eval platform."""

    def __init__(self, http: httpx.Client):
        self._http = http
        self._eval = _Eval(http)

    def create(
        self,
        dataset_version_name: str,
        prediction_artifact_path: str,
        evaled_artifact_path: str,
        report_json: dict,
        run_id: str,
    ) -> int:
        """Create an evaluation report.

        The dataset version is looked up by name first, so that the request
        carries both the dataset id and the dataset version id.

        Args:
            dataset_version_name (str): Name of the evaluated dataset version.
            prediction_artifact_path (str): Artifact path of the inference results.
            evaled_artifact_path (str): Artifact path of the evaluation results.
            report_json (dict): Report content.
            run_id (str): Run ID.

        Returns:
            int: ID of the created evaluation report.

        Raises:
            APIError: If the eval backend responds with a non-zero code. Errors
                raised by the dataset-version lookup are propagated as-is.
        """
        # Imported lazily, presumably to avoid a circular import between
        # service modules.
        from .dataset_management import DatasetManagementService

        dataset_service = DatasetManagementService(self._http)
        dataset_version = dataset_service.get_dataset_version_by_name(
            dataset_version_name
        )
        payload = CreatEvalReq(
            dataset_id=dataset_version.dataset_id,
            dataset_version_id=dataset_version.id,
            evaled_artifact_path=evaled_artifact_path,
            prediction_artifact_path=prediction_artifact_path,
            report=report_json,
            run_id=run_id,
        )

        return self._eval.create(payload)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class _Eval:
    """Low-level HTTP client for the eval-platform endpoints."""

    def __init__(self, http: httpx.Client):
        self._http = http

    def create(self, payload: CreatEvalReq) -> int:
        """POST an evaluation-report creation request and return the new id.

        Raises:
            APIError: If the backend reports a non-zero ``code``.
        """
        # NOTE(review): the upper-case "RUN/" path segment (with a trailing
        # slash) looks unusual — confirm against the eval-platform API.
        url = f"{_BASE}/RUN/"
        body = payload.model_dump()
        resp = self._http.post(url, json=body)
        wrapper = APIWrapper[CreatEvalResp].model_validate(resp.json())
        if wrapper.code == 0:
            return wrapper.data.id
        raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
|
aihub/services/model_training_platform.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import httpx
|
|
4
|
+
|
|
5
|
+
from ..exceptions import APIError
|
|
6
|
+
from ..models.common import APIWrapper
|
|
7
|
+
from ..models.model_training_platform import (
|
|
8
|
+
CreateTrainingRequest,
|
|
9
|
+
ListTrainingsRequest,
|
|
10
|
+
ListTrainingsResponse,
|
|
11
|
+
ListTrainingPodsRequest,
|
|
12
|
+
GetTrainingPodLogsNewResponse,
|
|
13
|
+
GetTrainingPodEventsResponse,
|
|
14
|
+
Training,
|
|
15
|
+
CreateTrainingResponse,
|
|
16
|
+
ListTrainingUsersRequest,
|
|
17
|
+
ListTrainingPodsResponse,
|
|
18
|
+
GetTrainingPodSpecResponse,
|
|
19
|
+
ListTrainingContainersRequest,
|
|
20
|
+
ListTrainingUsersResponse,
|
|
21
|
+
ListTrainingContainersResponse,
|
|
22
|
+
Pod,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
_BASE = "/model-training-platform/api/v1"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelTrainingPlatformService:
    """Facade over the model-training-platform API.

    Every public method forwards to the internal ``_Training`` client, which
    is also exposed through the ``training`` property.
    """

    def __init__(self, http: httpx.Client):
        self._training = _Training(http)

    # ----------------------------------------------------------- trainings
    def create_training(self, payload: CreateTrainingRequest) -> int:
        """Create a training task.

        Args:
            payload (CreateTrainingRequest): Training task parameters.

        Returns:
            int: ID of the new training task.
        """
        return self._training.create(payload)

    def list_trainings(self, payload: ListTrainingsRequest) -> ListTrainingsResponse:
        """List training tasks matching ``payload``."""
        return self._training.list(payload)

    def get_training(self, training_id: int) -> Training:
        """Fetch the details of one training task.

        Args:
            training_id (int): Training task ID.

        Returns:
            Training: The training task.
        """
        return self._training.get(training_id)

    def stop_training(self, training_id: int) -> None:
        """Stop a training task.

        Args:
            training_id (int): Training task ID.
        """
        self._training.stop(training_id)

    # ---------------------------------------------------------------- pods
    def list_training_pods(
        self, training_id: int, payload: ListTrainingPodsRequest
    ) -> ListTrainingPodsResponse:
        """List the pods of a training task."""
        return self._training.list_training_pods(training_id, payload)

    def get_training_pod(self, training_id: int, pod_id: int) -> Pod:
        """Fetch one pod of a training task."""
        return self._training.get_training_pod(training_id, pod_id)

    def get_pod_logs_new(
        self, training_id: int, pod_id: int
    ) -> GetTrainingPodLogsNewResponse:
        """Fetch a pod's logs (``/logs/new`` endpoint)."""
        return self._training.get_training_logs_new(training_id, pod_id)

    def get_pod_spec(self, training_id: int, pod_id: int) -> GetTrainingPodSpecResponse:
        """Fetch a pod's spec."""
        return self._training.get_training_spec(training_id, pod_id)

    def get_pod_events(
        self, training_id: int, pod_id: int
    ) -> GetTrainingPodEventsResponse:
        """Fetch a pod's events."""
        return self._training.get_training_events(training_id, pod_id)

    # ------------------------------------------------------- lookup lists
    def list_training_users(
        self, payload: ListTrainingUsersRequest
    ) -> ListTrainingUsersResponse:
        """List training users."""
        return self._training.list_training_users(payload)

    def list_training_containers(
        self, payload: ListTrainingContainersRequest
    ) -> ListTrainingContainersResponse:
        """List training containers."""
        return self._training.list_training_containers(payload)

    @property
    def training(self) -> _Training:
        """The underlying low-level training client."""
        return self._training
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class _Training:
    """Low-level HTTP client for the model-training-platform endpoints.

    Every response is parsed as ``APIWrapper[...]``; a non-zero ``code`` is
    raised as ``APIError``.
    """

    def __init__(self, http: httpx.Client):
        self._http = http

    @staticmethod
    def _unwrap(resp: httpx.Response, model: type):
        """Validate ``resp`` as ``APIWrapper[model]`` and return its ``data``.

        Raises:
            APIError: If the backend reports a non-zero ``code``.
        """
        wrapper = APIWrapper[model].model_validate(resp.json())
        if wrapper.code != 0:
            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
        return wrapper.data

    def create(self, payload: CreateTrainingRequest) -> int:
        """Create a training task and return its ID."""
        resp = self._http.post(
            f"{_BASE}/trainings",
            json=payload.model_dump(by_alias=True, exclude_none=True),
        )
        return self._unwrap(resp, CreateTrainingResponse).id

    def list(self, payload: ListTrainingsRequest) -> ListTrainingsResponse:
        """List training tasks; ``payload`` is sent as query parameters."""
        resp = self._http.get(
            f"{_BASE}/trainings",
            params=payload.model_dump(by_alias=True, exclude_none=True),
        )
        return self._unwrap(resp, ListTrainingsResponse)

    def get(self, training_id: int) -> Training:
        """Fetch one training task."""
        resp = self._http.get(f"{_BASE}/trainings/{training_id}")
        return self._unwrap(resp, Training)

    def stop(self, training_id: int) -> None:
        """Stop a training task; the response ``data`` is ignored."""
        resp = self._http.post(f"{_BASE}/trainings/{training_id}/stop")
        self._unwrap(resp, dict)

    def list_training_pods(
        self, training_id: int, payload: ListTrainingPodsRequest
    ) -> ListTrainingPodsResponse:
        """List the pods of a training task."""
        resp = self._http.get(
            f"{_BASE}/trainings/{training_id}/pods",
            params=payload.model_dump(by_alias=True, exclude_none=True),
        )
        return self._unwrap(resp, ListTrainingPodsResponse)

    def get_training_pod(self, training_id: int, pod_id: int) -> Pod:
        """Fetch one pod of a training task."""
        resp = self._http.get(f"{_BASE}/trainings/{training_id}/pods/{pod_id}")
        return self._unwrap(resp, Pod)

    def get_training_logs_new(
        self, training_id: int, pod_id: int
    ) -> GetTrainingPodLogsNewResponse:
        """Fetch a pod's logs from the ``/logs/new`` endpoint."""
        resp = self._http.get(f"{_BASE}/trainings/{training_id}/pods/{pod_id}/logs/new")
        return self._unwrap(resp, GetTrainingPodLogsNewResponse)

    def get_training_spec(
        self, training_id: int, pod_id: int
    ) -> GetTrainingPodSpecResponse:
        """Fetch a pod's spec."""
        resp = self._http.get(f"{_BASE}/trainings/{training_id}/pods/{pod_id}/spec")
        return self._unwrap(resp, GetTrainingPodSpecResponse)

    def get_training_events(
        self, training_id: int, pod_id: int
    ) -> GetTrainingPodEventsResponse:
        """Fetch a pod's events."""
        resp = self._http.get(f"{_BASE}/trainings/{training_id}/pods/{pod_id}/events")
        return self._unwrap(resp, GetTrainingPodEventsResponse)

    def list_training_users(
        self, payload: ListTrainingUsersRequest
    ) -> ListTrainingUsersResponse:
        """List training users.

        ``exclude_none=True`` matches the other list endpoints: without it,
        httpx serializes None-valued params as empty strings (``key=``).
        """
        resp = self._http.get(
            f"{_BASE}/training-users",
            params=payload.model_dump(by_alias=True, exclude_none=True),
        )
        return self._unwrap(resp, ListTrainingUsersResponse)

    def list_training_containers(
        self, payload: ListTrainingContainersRequest
    ) -> ListTrainingContainersResponse:
        """List training containers (None-valued params are omitted, as above)."""
        resp = self._http.get(
            f"{_BASE}/training-containers",
            params=payload.model_dump(by_alias=True, exclude_none=True),
        )
        return self._unwrap(resp, ListTrainingContainersResponse)
|