bizyengine 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/__init__.py +35 -0
- bizyengine/bizy_server/__init__.py +7 -0
- bizyengine/bizy_server/api_client.py +763 -0
- bizyengine/bizy_server/errno.py +122 -0
- bizyengine/bizy_server/error_handler.py +3 -0
- bizyengine/bizy_server/execution.py +55 -0
- bizyengine/bizy_server/resp.py +24 -0
- bizyengine/bizy_server/server.py +898 -0
- bizyengine/bizy_server/utils.py +93 -0
- bizyengine/bizyair_extras/__init__.py +24 -0
- bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
- bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
- bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
- bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
- bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
- bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
- bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
- bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
- bizyengine/bizyair_extras/nodes_dataset.py +99 -0
- bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
- bizyengine/bizyair_extras/nodes_flux.py +69 -0
- bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
- bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
- bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
- bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
- bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
- bizyengine/bizyair_extras/nodes_sd3.py +52 -0
- bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
- bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
- bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
- bizyengine/bizyair_extras/nodes_trellis.py +199 -0
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
- bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
- bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
- bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
- bizyengine/core/__init__.py +8 -0
- bizyengine/core/commands/__init__.py +1 -0
- bizyengine/core/commands/base.py +27 -0
- bizyengine/core/commands/invoker.py +4 -0
- bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
- bizyengine/core/commands/processors/prompt_processor.py +123 -0
- bizyengine/core/commands/servers/model_server.py +0 -0
- bizyengine/core/commands/servers/prompt_server.py +234 -0
- bizyengine/core/common/__init__.py +8 -0
- bizyengine/core/common/caching.py +198 -0
- bizyengine/core/common/client.py +262 -0
- bizyengine/core/common/env_var.py +101 -0
- bizyengine/core/common/utils.py +93 -0
- bizyengine/core/configs/conf.py +112 -0
- bizyengine/core/configs/models.json +101 -0
- bizyengine/core/configs/models.yaml +329 -0
- bizyengine/core/data_types.py +20 -0
- bizyengine/core/image_utils.py +288 -0
- bizyengine/core/nodes_base.py +159 -0
- bizyengine/core/nodes_io.py +97 -0
- bizyengine/core/path_utils/__init__.py +9 -0
- bizyengine/core/path_utils/path_manager.py +276 -0
- bizyengine/core/path_utils/utils.py +34 -0
- bizyengine/misc/__init__.py +0 -0
- bizyengine/misc/auth.py +83 -0
- bizyengine/misc/llm.py +431 -0
- bizyengine/misc/mzkolors.py +93 -0
- bizyengine/misc/nodes.py +1208 -0
- bizyengine/misc/nodes_controlnet_aux.py +491 -0
- bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
- bizyengine/misc/route_sam.py +60 -0
- bizyengine/misc/segment_anything.py +276 -0
- bizyengine/misc/supernode.py +182 -0
- bizyengine/misc/utils.py +218 -0
- bizyengine/version.txt +1 -0
- bizyengine-0.4.2.dist-info/METADATA +12 -0
- bizyengine-0.4.2.dist-info/RECORD +76 -0
- bizyengine-0.4.2.dist-info/WHEEL +5 -0
- bizyengine-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,898 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
import urllib.parse
|
|
7
|
+
import uuid
|
|
8
|
+
|
|
9
|
+
import aiohttp
|
|
10
|
+
from server import PromptServer
|
|
11
|
+
|
|
12
|
+
from .api_client import APIClient
|
|
13
|
+
from .errno import ErrorNo, errnos
|
|
14
|
+
from .error_handler import ErrorHandler
|
|
15
|
+
from .resp import ErrResponse, OKResponse
|
|
16
|
+
from .utils import base_model_types, check_str_param, check_type, is_string_valid, types
|
|
17
|
+
|
|
18
|
+
# URL prefixes (without a leading "/") for every route this module registers.
API_PREFIX = "bizyair"
COMMUNITY_API, MODEL_HOST_API, USER_API = (
    f"{API_PREFIX}/{section}" for section in ("community", "modelhost", "user")
)

# NOTE(review): configuring the root logger at import time from a plugin module
# affects the whole host process (and is a no-op if the host configured logging
# first); consider moving this to the application entry point.
logging.basicConfig(level=logging.DEBUG)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BizyAirServer:
|
|
27
|
+
def __init__(self):
    """Create the BizyAir server and register its routes on ComfyUI's PromptServer.

    The new object is also stored as ``BizyAirServer.instance`` so other
    modules can reach it through the class.
    """
    # Published on the class before anything else is constructed.
    BizyAirServer.instance = self

    self.api_client = APIClient()
    self.error_handler = ErrorHandler()

    # Routes are attached to the already-created ComfyUI PromptServer.
    self.prompt_server = PromptServer.instance
    # websocket clientId -> open websocket connection
    self.sockets = {}
    # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
    # current; PromptServer may already expose the loop it runs on.
    self.loop = asyncio.get_event_loop()

    self.setup_routes()
|
|
36
|
+
|
|
37
|
+
def setup_routes(self):
|
|
38
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_types")
|
|
39
|
+
async def list_model_types(request):
|
|
40
|
+
return OKResponse(types())
|
|
41
|
+
|
|
42
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/base_model_types")
|
|
43
|
+
async def list_base_model_types(request):
|
|
44
|
+
return OKResponse(base_model_types())
|
|
45
|
+
|
|
46
|
+
@self.prompt_server.routes.get(f"/{USER_API}/info")
|
|
47
|
+
async def user_info(request):
|
|
48
|
+
info, err = await self.api_client.user_info()
|
|
49
|
+
if err is not None:
|
|
50
|
+
return ErrResponse(err)
|
|
51
|
+
|
|
52
|
+
return OKResponse(info)
|
|
53
|
+
|
|
54
|
+
@self.prompt_server.routes.get(f"/{API_PREFIX}/ws")
|
|
55
|
+
async def websocket_handler(request):
|
|
56
|
+
ws = aiohttp.web.WebSocketResponse()
|
|
57
|
+
await ws.prepare(request)
|
|
58
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
59
|
+
if sid:
|
|
60
|
+
# Reusing existing session, remove old
|
|
61
|
+
self.sockets.pop(sid, None)
|
|
62
|
+
else:
|
|
63
|
+
sid = uuid.uuid4().hex
|
|
64
|
+
|
|
65
|
+
self.sockets[sid] = ws
|
|
66
|
+
|
|
67
|
+
try:
|
|
68
|
+
# Send initial state to the new client
|
|
69
|
+
await self.send_json(
|
|
70
|
+
event="status", data={"status": "connected"}, sid=sid
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
async for msg in ws:
|
|
74
|
+
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
75
|
+
if msg.data == "ping":
|
|
76
|
+
await ws.send_str("pong")
|
|
77
|
+
if msg.type == aiohttp.WSMsgType.ERROR:
|
|
78
|
+
logging.warning(
|
|
79
|
+
"ws connection closed with exception %s" % ws.exception()
|
|
80
|
+
)
|
|
81
|
+
finally:
|
|
82
|
+
self.sockets.pop(sid, None)
|
|
83
|
+
return ws
|
|
84
|
+
|
|
85
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/sign")
|
|
86
|
+
async def sign(request):
|
|
87
|
+
sha256sum = request.rel_url.query.get("sha256sum")
|
|
88
|
+
if not is_string_valid(sha256sum):
|
|
89
|
+
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
90
|
+
|
|
91
|
+
type = request.rel_url.query.get("type")
|
|
92
|
+
|
|
93
|
+
sign_data, err = await self.api_client.sign(sha256sum, type)
|
|
94
|
+
if err is not None:
|
|
95
|
+
return ErrResponse(err)
|
|
96
|
+
|
|
97
|
+
return OKResponse(sign_data)
|
|
98
|
+
|
|
99
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/upload_token")
|
|
100
|
+
async def upload_token(request):
|
|
101
|
+
filename = request.rel_url.query.get("filename", "")
|
|
102
|
+
# 校验filename
|
|
103
|
+
if not is_string_valid(filename):
|
|
104
|
+
return ErrResponse(errnos.INVALID_FILENAME)
|
|
105
|
+
|
|
106
|
+
filename = urllib.parse.quote(filename)
|
|
107
|
+
token, err = await self.api_client.get_upload_token(filename=filename)
|
|
108
|
+
if err is not None:
|
|
109
|
+
return ErrResponse(err)
|
|
110
|
+
return OKResponse(token)
|
|
111
|
+
|
|
112
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/commit_file")
|
|
113
|
+
async def commit_file(request):
|
|
114
|
+
json_data = await request.json()
|
|
115
|
+
|
|
116
|
+
if "sha256sum" not in json_data:
|
|
117
|
+
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
118
|
+
sha256sum = json_data.get("sha256sum")
|
|
119
|
+
|
|
120
|
+
if "object_key" not in json_data:
|
|
121
|
+
return ErrResponse(errnos.INVALID_OBJECT_KEY)
|
|
122
|
+
object_key = json_data.get("object_key")
|
|
123
|
+
|
|
124
|
+
if "type" not in json_data:
|
|
125
|
+
return ErrResponse(errnos.INVALID_TYPE)
|
|
126
|
+
type = json_data.get("type")
|
|
127
|
+
|
|
128
|
+
md5_hash = ""
|
|
129
|
+
if "md5_hash" in json_data:
|
|
130
|
+
md5_hash = json_data.get("md5_hash")
|
|
131
|
+
|
|
132
|
+
commit_data, err = await self.api_client.commit_file(
|
|
133
|
+
signature=sha256sum, object_key=object_key, md5_hash=md5_hash, type=type
|
|
134
|
+
)
|
|
135
|
+
# print("commit_data", commit_data)
|
|
136
|
+
if err is not None:
|
|
137
|
+
return ErrResponse(err)
|
|
138
|
+
|
|
139
|
+
return OKResponse(None)
|
|
140
|
+
|
|
141
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/models")
|
|
142
|
+
async def commit_bizy_model(request):
|
|
143
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
144
|
+
if not is_string_valid(sid):
|
|
145
|
+
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
146
|
+
|
|
147
|
+
json_data = await request.json()
|
|
148
|
+
|
|
149
|
+
# 校验name和type
|
|
150
|
+
err = check_str_param(json_data, "name", errnos.INVALID_NAME)
|
|
151
|
+
if err is not None:
|
|
152
|
+
return err
|
|
153
|
+
|
|
154
|
+
if "/" in json_data["name"]:
|
|
155
|
+
return ErrResponse(errnos.INVALID_NAME)
|
|
156
|
+
|
|
157
|
+
err = check_type(json_data)
|
|
158
|
+
if err is not None:
|
|
159
|
+
return err
|
|
160
|
+
|
|
161
|
+
# 校验versions
|
|
162
|
+
if "versions" not in json_data or not isinstance(
|
|
163
|
+
json_data["versions"], list
|
|
164
|
+
):
|
|
165
|
+
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
166
|
+
|
|
167
|
+
versions = json_data["versions"]
|
|
168
|
+
version_names = set()
|
|
169
|
+
|
|
170
|
+
for version in versions:
|
|
171
|
+
# 检查version是否重复
|
|
172
|
+
if version.get("version") in version_names:
|
|
173
|
+
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
174
|
+
|
|
175
|
+
# 检查version字段是否合法
|
|
176
|
+
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
177
|
+
"version"
|
|
178
|
+
):
|
|
179
|
+
return ErrResponse(errnos.INVALID_VERSION_NAME)
|
|
180
|
+
|
|
181
|
+
version_names.add(version.get("version"))
|
|
182
|
+
|
|
183
|
+
# 检查base_model, path和sign是否有值
|
|
184
|
+
for field in ["base_model", "path", "sign"]:
|
|
185
|
+
if not is_string_valid(version.get(field)):
|
|
186
|
+
err = errnos.INVALID_VERSION_FIELD.copy()
|
|
187
|
+
err.message = "Invalid version field: " + field
|
|
188
|
+
return ErrResponse(err)
|
|
189
|
+
|
|
190
|
+
# 调用API提交模型
|
|
191
|
+
resp, err = await self.api_client.commit_bizy_model(payload=json_data)
|
|
192
|
+
if err:
|
|
193
|
+
return ErrResponse(err)
|
|
194
|
+
|
|
195
|
+
# print("resp------------------------------->", json_data, resp)
|
|
196
|
+
# 开启线程检查同步状态
|
|
197
|
+
threading.Thread(
|
|
198
|
+
target=self.check_sync_status,
|
|
199
|
+
args=(resp["id"], resp["version_ids"], sid),
|
|
200
|
+
daemon=True,
|
|
201
|
+
).start()
|
|
202
|
+
|
|
203
|
+
# enable refresh for lora
|
|
204
|
+
# TODO: enable refresh for other types
|
|
205
|
+
# bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
|
|
206
|
+
|
|
207
|
+
return OKResponse(resp)
|
|
208
|
+
|
|
209
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/models/query")
|
|
210
|
+
async def query_my_models(request):
|
|
211
|
+
# 获取查询参数
|
|
212
|
+
mode = request.rel_url.query.get("mode", "")
|
|
213
|
+
if not mode or mode not in ["my", "my_fork", "publicity", "official"]:
|
|
214
|
+
return ErrResponse(errnos.INVALID_QUERY_MODE)
|
|
215
|
+
|
|
216
|
+
current = int(request.rel_url.query.get("current", "1"))
|
|
217
|
+
page_size = int(request.rel_url.query.get("page_size", "10"))
|
|
218
|
+
json_data = await request.json()
|
|
219
|
+
keyword = json_data["keyword"]
|
|
220
|
+
model_types = json_data.get("model_types", [])
|
|
221
|
+
base_models = json_data.get("base_models", [])
|
|
222
|
+
sort = json_data.get("sort", "")
|
|
223
|
+
resp, err = None, None
|
|
224
|
+
|
|
225
|
+
if mode in ["my", "my_fork"]:
|
|
226
|
+
# 调用API查询模型
|
|
227
|
+
resp, err = await self.api_client.query_models(
|
|
228
|
+
mode,
|
|
229
|
+
current,
|
|
230
|
+
page_size,
|
|
231
|
+
keyword=keyword,
|
|
232
|
+
model_types=model_types,
|
|
233
|
+
base_models=base_models,
|
|
234
|
+
sort=sort,
|
|
235
|
+
)
|
|
236
|
+
elif mode == "publicity":
|
|
237
|
+
# 调用API查询社区模型
|
|
238
|
+
resp, err = await self.api_client.query_community_models(
|
|
239
|
+
current,
|
|
240
|
+
page_size,
|
|
241
|
+
keyword=keyword,
|
|
242
|
+
model_types=model_types,
|
|
243
|
+
base_models=base_models,
|
|
244
|
+
sort=sort,
|
|
245
|
+
)
|
|
246
|
+
elif mode == "official":
|
|
247
|
+
resp, err = await self.api_client.query_official_models(
|
|
248
|
+
current,
|
|
249
|
+
page_size,
|
|
250
|
+
keyword=keyword,
|
|
251
|
+
model_types=model_types,
|
|
252
|
+
base_models=base_models,
|
|
253
|
+
sort=sort,
|
|
254
|
+
)
|
|
255
|
+
if err:
|
|
256
|
+
return ErrResponse(err)
|
|
257
|
+
|
|
258
|
+
return OKResponse(resp)
|
|
259
|
+
|
|
260
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/models/{{model_id}}/detail")
|
|
261
|
+
async def get_model_detail(request):
|
|
262
|
+
# 获取路径参数中的模型ID
|
|
263
|
+
model_id = int(request.match_info["model_id"])
|
|
264
|
+
|
|
265
|
+
# 检查model_id是否合法
|
|
266
|
+
if not model_id or model_id <= 0:
|
|
267
|
+
return ErrResponse(errnos.INVALID_MODEL_ID)
|
|
268
|
+
|
|
269
|
+
source = request.rel_url.query.get("source", "")
|
|
270
|
+
|
|
271
|
+
# 调用API获取模型详情
|
|
272
|
+
resp, err = await self.api_client.get_model_detail(model_id, source)
|
|
273
|
+
if err:
|
|
274
|
+
return ErrResponse(err)
|
|
275
|
+
|
|
276
|
+
return OKResponse(resp)
|
|
277
|
+
|
|
278
|
+
@self.prompt_server.routes.delete(f"/{COMMUNITY_API}/models/{{model_id}}")
|
|
279
|
+
async def delete_model(request):
|
|
280
|
+
# 获取路径参数中的模型ID
|
|
281
|
+
model_id = int(request.match_info["model_id"])
|
|
282
|
+
|
|
283
|
+
# 检查model_id是否合法
|
|
284
|
+
if not model_id or model_id <= 0:
|
|
285
|
+
return ErrResponse(errnos.INVALID_MODEL_ID)
|
|
286
|
+
|
|
287
|
+
# 调用API删除模型
|
|
288
|
+
resp, err = await self.api_client.delete_bizy_model(model_id)
|
|
289
|
+
if err:
|
|
290
|
+
return ErrResponse(err)
|
|
291
|
+
|
|
292
|
+
return OKResponse(resp)
|
|
293
|
+
|
|
294
|
+
@self.prompt_server.routes.put(f"/{COMMUNITY_API}/models/{{model_id}}")
|
|
295
|
+
async def update_model(request):
|
|
296
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
297
|
+
if not is_string_valid(sid):
|
|
298
|
+
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
299
|
+
# 获取路径参数中的模型ID
|
|
300
|
+
model_id = int(request.match_info["model_id"])
|
|
301
|
+
|
|
302
|
+
# 检查model_id是否合法
|
|
303
|
+
if not model_id or model_id <= 0:
|
|
304
|
+
return ErrResponse(errnos.INVALID_MODEL_ID)
|
|
305
|
+
|
|
306
|
+
# 获取请求体数据
|
|
307
|
+
json_data = await request.json()
|
|
308
|
+
|
|
309
|
+
# 校验name和type
|
|
310
|
+
err = check_str_param(json_data, "name", errnos.INVALID_NAME)
|
|
311
|
+
if err is not None:
|
|
312
|
+
return err
|
|
313
|
+
|
|
314
|
+
if "/" in json_data["name"]:
|
|
315
|
+
return ErrResponse(errnos.INVALID_NAME)
|
|
316
|
+
|
|
317
|
+
err = check_type(json_data)
|
|
318
|
+
if err is not None:
|
|
319
|
+
return err
|
|
320
|
+
|
|
321
|
+
# 校验versions
|
|
322
|
+
if "versions" not in json_data or not isinstance(
|
|
323
|
+
json_data["versions"], list
|
|
324
|
+
):
|
|
325
|
+
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
326
|
+
|
|
327
|
+
versions = json_data["versions"]
|
|
328
|
+
version_names = set()
|
|
329
|
+
|
|
330
|
+
for version in versions:
|
|
331
|
+
# 检查version是否重复
|
|
332
|
+
if version.get("version") in version_names:
|
|
333
|
+
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
334
|
+
|
|
335
|
+
# 检查version字段是否合法
|
|
336
|
+
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
337
|
+
"version"
|
|
338
|
+
):
|
|
339
|
+
return ErrResponse(errnos.INVALID_VERSION_NAME)
|
|
340
|
+
|
|
341
|
+
version_names.add(version.get("version"))
|
|
342
|
+
|
|
343
|
+
# 检查base_model, path和sign是否有值
|
|
344
|
+
for field in ["base_model", "path", "sign"]:
|
|
345
|
+
if not is_string_valid(version.get(field)):
|
|
346
|
+
return ErrResponse(errnos.INVALID_VERSION_FIELD(field))
|
|
347
|
+
|
|
348
|
+
# 调用API更新模型
|
|
349
|
+
resp, err = await self.api_client.update_model(
|
|
350
|
+
model_id, json_data["name"], json_data["type"], versions
|
|
351
|
+
)
|
|
352
|
+
if err:
|
|
353
|
+
return ErrResponse(err)
|
|
354
|
+
|
|
355
|
+
# 开启线程检查同步状态
|
|
356
|
+
threading.Thread(
|
|
357
|
+
target=self.check_sync_status,
|
|
358
|
+
args=(resp["id"], resp["version_ids"]),
|
|
359
|
+
daemon=True,
|
|
360
|
+
).start()
|
|
361
|
+
|
|
362
|
+
return OKResponse(None)
|
|
363
|
+
|
|
364
|
+
@self.prompt_server.routes.post(
|
|
365
|
+
f"/{COMMUNITY_API}/models/fork/{{model_version_id}}"
|
|
366
|
+
)
|
|
367
|
+
async def fork_model_version(request):
|
|
368
|
+
try:
|
|
369
|
+
# 获取version_id参数
|
|
370
|
+
version_id = request.match_info["model_version_id"]
|
|
371
|
+
if not version_id:
|
|
372
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
373
|
+
|
|
374
|
+
# 调用API fork模型版本
|
|
375
|
+
_, err = await self.api_client.fork_model_version(version_id)
|
|
376
|
+
if err:
|
|
377
|
+
return ErrResponse(err)
|
|
378
|
+
|
|
379
|
+
return OKResponse(None)
|
|
380
|
+
|
|
381
|
+
except Exception as e:
|
|
382
|
+
print(f"\033[31m[BizyAir]\033[0m Fail to fork model version: {str(e)}")
|
|
383
|
+
return ErrResponse(errnos.FORK_MODEL_VERSION)
|
|
384
|
+
|
|
385
|
+
@self.prompt_server.routes.delete(
|
|
386
|
+
f"/{COMMUNITY_API}/models/fork/{{model_version_id}}"
|
|
387
|
+
)
|
|
388
|
+
async def unfork_model_version(request):
|
|
389
|
+
try:
|
|
390
|
+
# 获取version_id参数
|
|
391
|
+
version_id = request.match_info["model_version_id"]
|
|
392
|
+
if not version_id:
|
|
393
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
394
|
+
|
|
395
|
+
# 调用API fork模型版本
|
|
396
|
+
_, err = await self.api_client.unfork_model_version(version_id)
|
|
397
|
+
if err:
|
|
398
|
+
return ErrResponse(err)
|
|
399
|
+
|
|
400
|
+
return OKResponse(None)
|
|
401
|
+
|
|
402
|
+
except Exception as e:
|
|
403
|
+
print(
|
|
404
|
+
f"\033[31m[BizyAir]\033[0m Fail to unfork model version: {str(e)}"
|
|
405
|
+
)
|
|
406
|
+
return ErrResponse(errnos.FORK_MODEL_VERSION)
|
|
407
|
+
|
|
408
|
+
@self.prompt_server.routes.post(
|
|
409
|
+
f"/{COMMUNITY_API}/models/like/{{model_version_id}}"
|
|
410
|
+
)
|
|
411
|
+
async def like_model_version(request):
|
|
412
|
+
try:
|
|
413
|
+
# 获取version_id参数
|
|
414
|
+
version_id = request.match_info["model_version_id"]
|
|
415
|
+
if not version_id:
|
|
416
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
417
|
+
|
|
418
|
+
# 调用API like模型版本
|
|
419
|
+
_, err = await self.api_client.toggle_user_like(
|
|
420
|
+
"model_version", version_id
|
|
421
|
+
)
|
|
422
|
+
if err:
|
|
423
|
+
return ErrResponse(err)
|
|
424
|
+
|
|
425
|
+
return OKResponse(None)
|
|
426
|
+
|
|
427
|
+
except Exception as e:
|
|
428
|
+
print(
|
|
429
|
+
f"\033[31m[BizyAir]\033[0m Fail to toggle like model version: {str(e)}"
|
|
430
|
+
)
|
|
431
|
+
return ErrResponse(errnos.TOGGLE_USER_LIKE)
|
|
432
|
+
|
|
433
|
+
@self.prompt_server.routes.get(
|
|
434
|
+
f"/{COMMUNITY_API}/models/versions/{{model_version_id}}/workflow_json/{{sign}}"
|
|
435
|
+
)
|
|
436
|
+
async def get_workflow_json(request):
|
|
437
|
+
model_version_id = int(request.match_info["model_version_id"])
|
|
438
|
+
# 检查model_version_id是否合法
|
|
439
|
+
if not model_version_id or model_version_id <= 0:
|
|
440
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
441
|
+
|
|
442
|
+
sign = str(request.match_info["sign"])
|
|
443
|
+
if not sign:
|
|
444
|
+
return ErrResponse(errnos.INVALID_SIGN)
|
|
445
|
+
|
|
446
|
+
# 获取上传凭证
|
|
447
|
+
url, err = await self.api_client.get_download_url(
|
|
448
|
+
sign=sign, model_version_id=model_version_id
|
|
449
|
+
)
|
|
450
|
+
if err:
|
|
451
|
+
return ErrResponse(err)
|
|
452
|
+
|
|
453
|
+
# 请求该url,获取文件内容
|
|
454
|
+
async with aiohttp.ClientSession() as session:
|
|
455
|
+
async with session.get(url) as response:
|
|
456
|
+
if response.status != 200:
|
|
457
|
+
return ErrResponse(errnos.FAILED_TO_FETCH_WORKFLOW_JSON)
|
|
458
|
+
json_content = await response.json()
|
|
459
|
+
|
|
460
|
+
return OKResponse(json_content)
|
|
461
|
+
|
|
462
|
+
@self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
|
|
463
|
+
async def list_share_model_files(request):
|
|
464
|
+
shareId = request.match_info["shareId"]
|
|
465
|
+
if not is_string_valid(shareId):
|
|
466
|
+
return ErrResponse("INVALID_SHARE_ID")
|
|
467
|
+
payload = {}
|
|
468
|
+
query_params = ["type", "name", "ext_name"]
|
|
469
|
+
for param in query_params:
|
|
470
|
+
if param in request.rel_url.query and request.rel_url.query[param]:
|
|
471
|
+
payload[param] = request.rel_url.query[param]
|
|
472
|
+
model_files, err = await self.api_client.get_share_model_files(
|
|
473
|
+
shareId=shareId, payload=payload
|
|
474
|
+
)
|
|
475
|
+
if err is not None:
|
|
476
|
+
return ErrResponse(err)
|
|
477
|
+
return OKResponse(model_files)
|
|
478
|
+
|
|
479
|
+
@self.prompt_server.routes.get(f"/{API_PREFIX}/dict")
|
|
480
|
+
async def get_data_dict(request):
|
|
481
|
+
data_dict, err = await self.api_client.get_data_dict()
|
|
482
|
+
if err is not None:
|
|
483
|
+
return ErrResponse(err)
|
|
484
|
+
|
|
485
|
+
return OKResponse(data_dict)
|
|
486
|
+
|
|
487
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/datasets")
|
|
488
|
+
async def commit_dataset(request):
|
|
489
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
490
|
+
if not is_string_valid(sid):
|
|
491
|
+
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
492
|
+
|
|
493
|
+
json_data = await request.json()
|
|
494
|
+
|
|
495
|
+
# 校验name和type
|
|
496
|
+
err = check_str_param(json_data, "name", errnos.INVALID_DATASET_NAME)
|
|
497
|
+
if err is not None:
|
|
498
|
+
return err
|
|
499
|
+
|
|
500
|
+
if "/" in json_data["name"]:
|
|
501
|
+
return ErrResponse(errnos.INVALID_DATASET_NAME)
|
|
502
|
+
|
|
503
|
+
# 校验versions
|
|
504
|
+
if "versions" not in json_data or not isinstance(
|
|
505
|
+
json_data["versions"], list
|
|
506
|
+
):
|
|
507
|
+
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
508
|
+
|
|
509
|
+
versions = json_data["versions"]
|
|
510
|
+
version_names = set()
|
|
511
|
+
|
|
512
|
+
for version in versions:
|
|
513
|
+
# 检查version是否重复
|
|
514
|
+
if version.get("version") in version_names:
|
|
515
|
+
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
516
|
+
|
|
517
|
+
# 检查version字段是否合法
|
|
518
|
+
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
519
|
+
"version"
|
|
520
|
+
):
|
|
521
|
+
return ErrResponse(errnos.INVALID_DATASET_VERSION)
|
|
522
|
+
|
|
523
|
+
version_names.add(version.get("version"))
|
|
524
|
+
|
|
525
|
+
# 调用API提交数据集
|
|
526
|
+
resp, err = await self.api_client.commit_dataset(payload=json_data)
|
|
527
|
+
if err:
|
|
528
|
+
return ErrResponse(err)
|
|
529
|
+
|
|
530
|
+
# print("resp------------------------------->", json_data, resp)
|
|
531
|
+
# 开启线程检查同步状态
|
|
532
|
+
threading.Thread(
|
|
533
|
+
target=self.check_dataset_sync_status,
|
|
534
|
+
args=(resp["id"], resp["version_ids"], sid),
|
|
535
|
+
daemon=True,
|
|
536
|
+
).start()
|
|
537
|
+
|
|
538
|
+
# enable refresh for lora
|
|
539
|
+
# TODO: enable refresh for other types
|
|
540
|
+
# bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
|
|
541
|
+
|
|
542
|
+
return OKResponse(resp)
|
|
543
|
+
|
|
544
|
+
@self.prompt_server.routes.put(f"/{COMMUNITY_API}/datasets/{{dataset_id}}")
|
|
545
|
+
async def update_dataset(request):
|
|
546
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
547
|
+
if not is_string_valid(sid):
|
|
548
|
+
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
549
|
+
# 获取路径参数中的数据集ID
|
|
550
|
+
dataset_id = int(request.match_info["dataset_id"])
|
|
551
|
+
|
|
552
|
+
# 检查model_id是否合法
|
|
553
|
+
if not dataset_id or dataset_id <= 0:
|
|
554
|
+
return ErrResponse(errnos.INVALID_DATASET_ID)
|
|
555
|
+
|
|
556
|
+
# 获取请求体数据
|
|
557
|
+
json_data = await request.json()
|
|
558
|
+
|
|
559
|
+
# 校验name和type
|
|
560
|
+
err = check_str_param(json_data, "name", errnos.INVALID_DATASET_NAME)
|
|
561
|
+
if err is not None:
|
|
562
|
+
return err
|
|
563
|
+
|
|
564
|
+
if "/" in json_data["name"]:
|
|
565
|
+
return ErrResponse(errnos.INVALID_DATASET_NAME)
|
|
566
|
+
|
|
567
|
+
# 校验versions
|
|
568
|
+
if "versions" not in json_data or not isinstance(
|
|
569
|
+
json_data["versions"], list
|
|
570
|
+
):
|
|
571
|
+
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
572
|
+
|
|
573
|
+
versions = json_data["versions"]
|
|
574
|
+
version_names = set()
|
|
575
|
+
|
|
576
|
+
for version in versions:
|
|
577
|
+
# 检查version是否重复
|
|
578
|
+
if version.get("version") in version_names:
|
|
579
|
+
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
580
|
+
|
|
581
|
+
# 检查version字段是否合法
|
|
582
|
+
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
583
|
+
"version"
|
|
584
|
+
):
|
|
585
|
+
return ErrResponse(errnos.INVALID_VERSION_NAME)
|
|
586
|
+
|
|
587
|
+
version_names.add(version.get("version"))
|
|
588
|
+
|
|
589
|
+
# 调用API更新数据集
|
|
590
|
+
resp, err = await self.api_client.update_dataset(
|
|
591
|
+
dataset_id, json_data["name"], versions
|
|
592
|
+
)
|
|
593
|
+
if err:
|
|
594
|
+
return ErrResponse(err)
|
|
595
|
+
|
|
596
|
+
# 开启线程检查同步状态
|
|
597
|
+
threading.Thread(
|
|
598
|
+
target=self.check_dataset_sync_status,
|
|
599
|
+
args=(resp["id"], resp["version_ids"]),
|
|
600
|
+
daemon=True,
|
|
601
|
+
).start()
|
|
602
|
+
|
|
603
|
+
return OKResponse(None)
|
|
604
|
+
|
|
605
|
+
@self.prompt_server.routes.delete(f"/{COMMUNITY_API}/datasets/{{dataset_id}}")
|
|
606
|
+
async def delete_dataset(request):
|
|
607
|
+
# 获取路径参数中的数据集ID
|
|
608
|
+
dataset_id = int(request.match_info["dataset_id"])
|
|
609
|
+
|
|
610
|
+
# 检查model_id是否合法
|
|
611
|
+
if not dataset_id or dataset_id <= 0:
|
|
612
|
+
return ErrResponse(errnos.INVALID_DATASET_ID)
|
|
613
|
+
|
|
614
|
+
# 调用API删除数据集
|
|
615
|
+
resp, err = await self.api_client.delete_dataset(dataset_id)
|
|
616
|
+
if err:
|
|
617
|
+
return ErrResponse(err)
|
|
618
|
+
|
|
619
|
+
return OKResponse(resp)
|
|
620
|
+
|
|
621
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/datasets/query")
|
|
622
|
+
async def query_my_datasets(request):
|
|
623
|
+
current = int(request.rel_url.query.get("current", "1"))
|
|
624
|
+
page_size = int(request.rel_url.query.get("page_size", "10"))
|
|
625
|
+
keyword = None
|
|
626
|
+
annotated = None
|
|
627
|
+
if request.body_exists:
|
|
628
|
+
json_data = await request.json()
|
|
629
|
+
keyword = json_data["keyword"]
|
|
630
|
+
annotated = json_data["annotated"]
|
|
631
|
+
resp, err = None, None
|
|
632
|
+
|
|
633
|
+
# 调用API查询数据集
|
|
634
|
+
resp, err = await self.api_client.query_datasets(
|
|
635
|
+
current,
|
|
636
|
+
page_size,
|
|
637
|
+
keyword=keyword,
|
|
638
|
+
annotated=annotated,
|
|
639
|
+
)
|
|
640
|
+
if err:
|
|
641
|
+
return ErrResponse(err)
|
|
642
|
+
|
|
643
|
+
return OKResponse(resp)
|
|
644
|
+
|
|
645
|
+
@self.prompt_server.routes.get(
|
|
646
|
+
f"/{COMMUNITY_API}/datasets/{{dataset_id}}/detail"
|
|
647
|
+
)
|
|
648
|
+
async def get_dataset_detail(request):
|
|
649
|
+
# 获取路径参数中的数据集ID
|
|
650
|
+
dataset_id = int(request.match_info["dataset_id"])
|
|
651
|
+
|
|
652
|
+
# 检查dataset_id是否合法
|
|
653
|
+
if not dataset_id or dataset_id <= 0:
|
|
654
|
+
return ErrResponse(errnos.INVALID_DATASET_ID)
|
|
655
|
+
|
|
656
|
+
# 调用API获取数据集详情
|
|
657
|
+
resp, err = await self.api_client.get_dataset_detail(dataset_id)
|
|
658
|
+
if err:
|
|
659
|
+
return ErrResponse(err)
|
|
660
|
+
|
|
661
|
+
return OKResponse(resp)
|
|
662
|
+
|
|
663
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/share")
|
|
664
|
+
async def create_share(request):
|
|
665
|
+
json_data = await request.json()
|
|
666
|
+
if "biz_id" not in json_data:
|
|
667
|
+
return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
|
|
668
|
+
|
|
669
|
+
biz_id = int(json_data["biz_id"])
|
|
670
|
+
if not biz_id or biz_id <= 0:
|
|
671
|
+
return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
|
|
672
|
+
|
|
673
|
+
if "type" not in json_data:
|
|
674
|
+
return ErrResponse(errnos.INVALID_SHARE_TYPE)
|
|
675
|
+
if not is_string_valid(json_data["type"]) or (
|
|
676
|
+
json_data["type"] != "bizy_model_version"
|
|
677
|
+
):
|
|
678
|
+
return ErrResponse(errnos.INVALID_SHARE_TYPE)
|
|
679
|
+
|
|
680
|
+
# 调用API提交数据集
|
|
681
|
+
resp, err = await self.api_client.create_share(payload=json_data)
|
|
682
|
+
if err:
|
|
683
|
+
return ErrResponse(err)
|
|
684
|
+
|
|
685
|
+
return OKResponse(resp)
|
|
686
|
+
|
|
687
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/share/{{code}}")
|
|
688
|
+
async def get_share_detail(request):
|
|
689
|
+
# 获取路径参数中的数据集ID
|
|
690
|
+
code = str(request.match_info["code"])
|
|
691
|
+
|
|
692
|
+
# 检查code是否合法
|
|
693
|
+
if not is_string_valid(code):
|
|
694
|
+
return ErrResponse(errnos.INVALID_SHARE_CODE)
|
|
695
|
+
|
|
696
|
+
# 调用API获取数据集详情
|
|
697
|
+
resp, err = await self.api_client.get_share_detail(code)
|
|
698
|
+
if err:
|
|
699
|
+
return ErrResponse(err)
|
|
700
|
+
|
|
701
|
+
return OKResponse(resp)
|
|
702
|
+
|
|
703
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
|
|
704
|
+
async def get_model_version_detail(request):
|
|
705
|
+
# 获取路径参数中的数据集ID
|
|
706
|
+
version_id = int(request.match_info["version_id"])
|
|
707
|
+
|
|
708
|
+
# 检查version_id是否合法
|
|
709
|
+
if not version_id or version_id <= 0:
|
|
710
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
711
|
+
|
|
712
|
+
# 调用API获取数据集详情
|
|
713
|
+
resp, err = await self.api_client.get_model_version_detail(version_id)
|
|
714
|
+
if err:
|
|
715
|
+
return ErrResponse(err)
|
|
716
|
+
|
|
717
|
+
return OKResponse(resp)
|
|
718
|
+
|
|
719
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/notifications/unread_count")
|
|
720
|
+
async def get_notification_unread_count(request):
|
|
721
|
+
# 获取当前用户的未读消息数量
|
|
722
|
+
resp, err = await self.api_client.get_notification_unread_count()
|
|
723
|
+
if err:
|
|
724
|
+
return ErrResponse(err)
|
|
725
|
+
return OKResponse(resp)
|
|
726
|
+
|
|
727
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/notifications")
|
|
728
|
+
async def fetch_notifications(request):
|
|
729
|
+
# 获取当前用户的消息列表
|
|
730
|
+
typesStr = request.rel_url.query.get("types", None)
|
|
731
|
+
types = None
|
|
732
|
+
if typesStr:
|
|
733
|
+
types = [int(x) for x in typesStr.split(",")]
|
|
734
|
+
read_status = request.rel_url.query.get("read_status", None)
|
|
735
|
+
page_size = int(request.rel_url.query.get("page_size", "10"))
|
|
736
|
+
last_pm_id = int(request.rel_url.query.get("last_pm_id", "0"))
|
|
737
|
+
last_broadcast_id = int(request.rel_url.query.get("last_broadcast_id", "0"))
|
|
738
|
+
|
|
739
|
+
resp, err = await self.api_client.fetch_notifications(
|
|
740
|
+
page_size, last_pm_id, last_broadcast_id, types, read_status
|
|
741
|
+
)
|
|
742
|
+
if err:
|
|
743
|
+
return ErrResponse(err)
|
|
744
|
+
return OKResponse(resp)
|
|
745
|
+
|
|
746
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/notifications/read_all")
async def read_all_notifications(request):
    """Mark all of the current user's unread notifications as read.

    The optional ``type`` query parameter (integer, default 0) is forwarded
    to the API client. The meaning of 0 is defined upstream (presumably
    "all types"; confirm against the API).
    """
    # Named notif_type rather than `type` so the builtin is not shadowed.
    notif_type = int(request.rel_url.query.get("type", "0"))
    resp, err = await self.api_client.read_all_notifications(notif_type)
    if err:
        return ErrResponse(err)
    return OKResponse(resp)
|
|
754
|
+
|
|
755
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/notifications/read")
async def read_notifications(request):
    """Mark the given notifications of the current user as read.

    Expects a JSON object body with a non-null ``ids`` field, which is
    forwarded to the API client. A body that is not valid JSON, is not a
    JSON object, or lacks a non-null ``ids`` yields
    ``ErrResponse(errnos.INVALID_NOTIF_ID)`` instead of an unhandled
    exception.
    """
    try:
        json_data = await request.json()
    except ValueError:
        # Malformed or empty body: json.JSONDecodeError subclasses ValueError.
        return ErrResponse(errnos.INVALID_NOTIF_ID)

    # Reject lists, strings, null, etc. before using dict access on them.
    if not isinstance(json_data, dict) or json_data.get("ids") is None:
        return ErrResponse(errnos.INVALID_NOTIF_ID)

    resp, err = await self.api_client.read_notifications(json_data["ids"])
    if err:
        return ErrResponse(err)
    return OKResponse(resp)
|
|
766
|
+
|
|
767
|
+
async def send_json(self, event, data, sid=None):
    """Send a ``{"type": event, "data": data}`` message over websocket(s).

    With ``sid`` None the message goes to every socket in ``self.sockets``,
    one after another. Otherwise it goes only to the socket registered under
    ``sid``; an unknown ``sid`` is silently ignored. Delivery goes through
    ``send_socket_catch_exception``.
    """
    payload = {"type": event, "data": data}

    if sid is not None:
        if sid in self.sockets:
            target = self.sockets[sid]
            await self.send_socket_catch_exception(target.send_json, payload)
        return

    # Iterate over a snapshot so changes to self.sockets during the awaits
    # below cannot break the iteration.
    for socket in list(self.sockets.values()):
        await self.send_socket_catch_exception(socket.send_json, payload)
|
|
776
|
+
|
|
777
|
+
async def send_error(self, err: ErrorNo, sid=None):
    """Send ``err`` as an ``"error"`` event through ``self.send_json``.

    The event data carries the error's ``message``, ``code`` and ``data``
    attributes. ``sid`` is passed through to ``send_json`` to choose the
    recipient.
    """
    error_payload = {"message": err.message, "code": err.code, "data": err.data}
    await self.send_json(event="error", data=error_payload, sid=sid)
|
|
783
|
+
|
|
784
|
+
async def send_socket_catch_exception(self, function, message):
    """Await ``function(message)``, logging transport errors instead of raising.

    Only aiohttp client errors and ``ConnectionResetError`` are caught and
    logged as warnings. Any other exception propagates to the caller.
    """
    try:
        await function(message)
    except (
        ConnectionResetError,
        aiohttp.ClientPayloadError,
        aiohttp.ClientError,
    ) as exc:
        logging.warning("send error: {}".format(exc))
|
|
793
|
+
|
|
794
|
+
def send_sync(self, event, data, sid=None):
    """Schedule ``self.send_json(event, data, sid)`` on ``self.loop``.

    Uses ``asyncio.run_coroutine_threadsafe``, so it can be called from a
    thread other than the loop's. This is fire-and-forget: the resulting
    future is discarded, so the caller neither waits for delivery nor sees
    exceptions raised by ``send_json``.
    """
    pending_send = self.send_json(event, data, sid)
    asyncio.run_coroutine_threadsafe(pending_send, self.loop)
|
|
796
|
+
|
|
797
|
+
def send_sync_error(self, err: ErrorNo, sid=None):
    """Thread-safe variant of sending ``err`` as an ``"error"`` event.

    Builds the same ``message``/``code``/``data`` payload as ``send_error``
    and hands it to ``self.send_sync``, which schedules it without waiting.
    """
    error_payload = {"message": err.message, "code": err.code, "data": err.data}
    self.send_sync(event="error", data=error_payload, sid=sid)
|
|
803
|
+
|
|
804
|
+
def _poll_version_sync(
    self,
    fetch_detail,
    version_ids,
    error_context: dict,
    make_synced_data,
    sid=None,
    poll_interval: float = 5,
    request_timeout: float = 2,
):
    """Poll version details until every version is available or has failed.

    Shared loop behind ``check_sync_status`` and ``check_dataset_sync_status``.
    It blocks the calling thread and runs each API coroutine on ``self.loop``
    via ``asyncio.run_coroutine_threadsafe``. It must therefore not be called
    from the loop's own thread, where ``future.result`` would deadlock.

    Args:
        fetch_detail: ``version_id -> coroutine`` resolving to ``(detail, err)``.
        version_ids: version IDs to watch.
        error_context: keys placed before ``version_id`` in the ``data`` of
            an ``"error"`` event.
        make_synced_data: builds the ``"synced"`` event payload from a
            detail dict.
        sid: passed through to ``self.send_sync``.
        poll_interval: seconds to sleep between polling rounds. No sleep
            happens once nothing is pending.
        request_timeout: seconds to wait for each detail request. A request
            that times out is cancelled and retried on the next round.

    Versions that never become available and never return an error are
    polled indefinitely.
    """
    import concurrent.futures

    pending = list(version_ids)
    while pending:
        still_pending = []
        for version_id in pending:
            future = asyncio.run_coroutine_threadsafe(
                fetch_detail(version_id), self.loop
            )
            try:
                detail, err = future.result(timeout=request_timeout)
            except concurrent.futures.TimeoutError:
                # Without this, one slow API call would raise out of the
                # polling loop and end all further status notifications.
                future.cancel()
                logging.warning(
                    "timed out fetching sync status of version %s, will retry",
                    version_id,
                )
                still_pending.append(version_id)
                continue

            if err is not None:
                self.send_sync(
                    event="error",
                    data={
                        "message": err.message,
                        "code": err.code,
                        "data": {**error_context, "version_id": version_id},
                    },
                    sid=sid,
                )
                continue  # a failed version is reported once, then dropped

            # Guard against a missing detail (None) as well as "not yet".
            if detail and detail.get("available"):
                self.send_sync(event="synced", data=make_synced_data(detail), sid=sid)
                continue

            still_pending.append(version_id)

        pending = still_pending
        if pending:
            time.sleep(poll_interval)


def check_sync_status(
    self,
    bizy_model_id: str,
    version_ids: list,
    sid=None,
    poll_interval: float = 5,
    request_timeout: float = 2,
):
    """Notify the client as each model version becomes available.

    For every ID in ``version_ids``, polls
    ``self.api_client.get_model_version_detail``. It sends a ``"synced"``
    event (``version_id``, ``version``, ``model_id``, ``model_name``) once
    the version reports ``available``. If the API returns an error for that
    version, it sends an ``"error"`` event whose ``data`` holds
    ``bizy_model_id`` and ``version_id`` instead. Returns when no version is
    left to watch.
    """
    self._poll_version_sync(
        lambda version_id: self.api_client.get_model_version_detail(
            version_id=version_id
        ),
        version_ids,
        {"bizy_model_id": bizy_model_id},
        lambda detail: {
            "version_id": detail["id"],
            "version": detail["version"],
            "model_id": bizy_model_id,
            "model_name": detail["bizy_model_name"],
        },
        sid=sid,
        poll_interval=poll_interval,
        request_timeout=request_timeout,
    )


def check_dataset_sync_status(
    self,
    dataset_id: str,
    version_ids: list,
    sid=None,
    poll_interval: float = 5,
    request_timeout: float = 2,
):
    """Notify the client as each dataset version becomes available.

    For every ID in ``version_ids``, polls
    ``self.api_client.get_dataset_version_detail``. It sends a ``"synced"``
    event (``version_id``, ``version``, ``dataset_id``, ``dataset_name``)
    once the version reports ``available``. If the API returns an error for
    that version, it sends an ``"error"`` event whose ``data`` holds
    ``dataset_id`` and ``version_id`` instead. Returns when no version is
    left to watch.
    """
    self._poll_version_sync(
        lambda version_id: self.api_client.get_dataset_version_detail(
            version_id=version_id
        ),
        version_ids,
        {"dataset_id": dataset_id},
        lambda detail: {
            "version_id": detail["id"],
            "version": detail["version"],
            "dataset_id": dataset_id,
            "dataset_name": detail["dataset_name"],
        },
        sid=sid,
        poll_interval=poll_interval,
        request_timeout=request_timeout,
    )
|