bizyengine 1.2.7__py3-none-any.whl → 1.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/bizy_server/api_client.py +125 -57
- bizyengine/bizy_server/errno.py +9 -0
- bizyengine/bizy_server/server.py +353 -239
- bizyengine/bizyair_extras/nodes_flux.py +1 -1
- bizyengine/bizyair_extras/nodes_image_utils.py +2 -2
- bizyengine/bizyair_extras/nodes_nunchaku.py +1 -5
- bizyengine/bizyair_extras/nodes_segment_anything.py +1 -0
- bizyengine/bizyair_extras/nodes_trellis.py +1 -1
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +1 -1
- bizyengine/bizyair_extras/nodes_wan_i2v.py +1 -1
- bizyengine/core/__init__.py +2 -0
- bizyengine/core/commands/processors/prompt_processor.py +21 -18
- bizyengine/core/commands/servers/prompt_server.py +28 -13
- bizyengine/core/common/client.py +14 -2
- bizyengine/core/common/env_var.py +2 -0
- bizyengine/core/nodes_base.py +85 -7
- bizyengine/core/nodes_io.py +2 -2
- bizyengine/misc/llm.py +48 -85
- bizyengine/misc/mzkolors.py +27 -19
- bizyengine/misc/nodes.py +41 -21
- bizyengine/misc/nodes_controlnet_aux.py +18 -18
- bizyengine/misc/nodes_controlnet_union_sdxl.py +5 -12
- bizyengine/misc/segment_anything.py +29 -25
- bizyengine/misc/supernode.py +36 -30
- bizyengine/misc/utils.py +33 -21
- bizyengine/version.txt +1 -1
- bizyengine-1.2.9.dist-info/METADATA +211 -0
- {bizyengine-1.2.7.dist-info → bizyengine-1.2.9.dist-info}/RECORD +30 -30
- bizyengine-1.2.7.dist-info/METADATA +0 -19
- {bizyengine-1.2.7.dist-info → bizyengine-1.2.9.dist-info}/WHEEL +0 -0
- {bizyengine-1.2.7.dist-info → bizyengine-1.2.9.dist-info}/top_level.txt +0 -0
bizyengine/bizy_server/server.py
CHANGED
|
@@ -10,7 +10,9 @@ import urllib.parse
|
|
|
10
10
|
import uuid
|
|
11
11
|
|
|
12
12
|
import aiohttp
|
|
13
|
+
import execution
|
|
13
14
|
import openai
|
|
15
|
+
from bizyengine.core.common.env_var import BIZYAIR_SERVER_MODE
|
|
14
16
|
from server import PromptServer
|
|
15
17
|
|
|
16
18
|
from .api_client import APIClient
|
|
@@ -30,6 +32,12 @@ MODEL_API = f"{API_PREFIX}/model"
|
|
|
30
32
|
logging.basicConfig(level=logging.DEBUG)
|
|
31
33
|
|
|
32
34
|
|
|
35
|
+
def _get_request_api_key(request_headers):
|
|
36
|
+
if BIZYAIR_SERVER_MODE:
|
|
37
|
+
return request_headers.get("api_key")
|
|
38
|
+
return None
|
|
39
|
+
|
|
40
|
+
|
|
33
41
|
class BizyAirServer:
|
|
34
42
|
def __init__(self):
|
|
35
43
|
BizyAirServer.instance = self
|
|
@@ -42,6 +50,8 @@ class BizyAirServer:
|
|
|
42
50
|
self.setup_routes()
|
|
43
51
|
|
|
44
52
|
def setup_routes(self):
|
|
53
|
+
# 以下路径不论本地模式还是服务器模式都要注册
|
|
54
|
+
|
|
45
55
|
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_types")
|
|
46
56
|
async def list_model_types(request):
|
|
47
57
|
return OKResponse(types())
|
|
@@ -50,167 +60,51 @@ class BizyAirServer:
|
|
|
50
60
|
async def list_base_model_types(request):
|
|
51
61
|
return OKResponse(base_model_types())
|
|
52
62
|
|
|
53
|
-
@self.prompt_server.routes.
|
|
54
|
-
async def
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
await ws.prepare(request)
|
|
65
|
-
sid = request.rel_url.query.get("clientId", "")
|
|
66
|
-
if sid:
|
|
67
|
-
# Reusing existing session, remove old
|
|
68
|
-
self.sockets.pop(sid, None)
|
|
69
|
-
else:
|
|
70
|
-
sid = uuid.uuid4().hex
|
|
71
|
-
|
|
72
|
-
self.sockets[sid] = ws
|
|
73
|
-
|
|
74
|
-
try:
|
|
75
|
-
# Send initial state to the new client
|
|
76
|
-
await self.send_json(
|
|
77
|
-
event="status", data={"status": "connected"}, sid=sid
|
|
78
|
-
)
|
|
79
|
-
|
|
80
|
-
async for msg in ws:
|
|
81
|
-
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
82
|
-
if msg.data == "ping":
|
|
83
|
-
await ws.send_str("pong")
|
|
84
|
-
if msg.type == aiohttp.WSMsgType.ERROR:
|
|
85
|
-
logging.warning(
|
|
86
|
-
"ws connection closed with exception %s" % ws.exception()
|
|
87
|
-
)
|
|
88
|
-
finally:
|
|
89
|
-
self.sockets.pop(sid, None)
|
|
90
|
-
return ws
|
|
91
|
-
|
|
92
|
-
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/sign")
|
|
93
|
-
async def sign(request):
|
|
94
|
-
sha256sum = request.rel_url.query.get("sha256sum")
|
|
95
|
-
if not is_string_valid(sha256sum):
|
|
96
|
-
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
97
|
-
|
|
98
|
-
type = request.rel_url.query.get("type")
|
|
99
|
-
|
|
100
|
-
sign_data, err = await self.api_client.sign(sha256sum, type)
|
|
101
|
-
if err is not None:
|
|
102
|
-
return ErrResponse(err)
|
|
103
|
-
|
|
104
|
-
return OKResponse(sign_data)
|
|
105
|
-
|
|
106
|
-
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/upload_token")
|
|
107
|
-
async def upload_token(request):
|
|
108
|
-
filename = request.rel_url.query.get("filename", "")
|
|
109
|
-
# 校验filename
|
|
110
|
-
if not is_string_valid(filename):
|
|
111
|
-
return ErrResponse(errnos.INVALID_FILENAME)
|
|
112
|
-
|
|
113
|
-
filename = urllib.parse.quote(filename)
|
|
114
|
-
token, err = await self.api_client.get_upload_token(filename=filename)
|
|
115
|
-
if err is not None:
|
|
116
|
-
return ErrResponse(err)
|
|
117
|
-
return OKResponse(token)
|
|
118
|
-
|
|
119
|
-
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/commit_file")
|
|
120
|
-
async def commit_file(request):
|
|
121
|
-
json_data = await request.json()
|
|
122
|
-
|
|
123
|
-
if "sha256sum" not in json_data:
|
|
124
|
-
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
125
|
-
sha256sum = json_data.get("sha256sum")
|
|
126
|
-
|
|
127
|
-
if "object_key" not in json_data:
|
|
128
|
-
return ErrResponse(errnos.INVALID_OBJECT_KEY)
|
|
129
|
-
object_key = json_data.get("object_key")
|
|
130
|
-
|
|
131
|
-
if "type" not in json_data:
|
|
132
|
-
return ErrResponse(errnos.INVALID_TYPE)
|
|
133
|
-
type = json_data.get("type")
|
|
134
|
-
|
|
135
|
-
md5_hash = ""
|
|
136
|
-
if "md5_hash" in json_data:
|
|
137
|
-
md5_hash = json_data.get("md5_hash")
|
|
63
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/datasets/query")
|
|
64
|
+
async def query_my_datasets(request):
|
|
65
|
+
current = int(request.rel_url.query.get("current", "1"))
|
|
66
|
+
page_size = int(request.rel_url.query.get("page_size", "10"))
|
|
67
|
+
keyword = None
|
|
68
|
+
annotated = None
|
|
69
|
+
if request.body_exists:
|
|
70
|
+
json_data = await request.json()
|
|
71
|
+
keyword = json_data["keyword"]
|
|
72
|
+
annotated = json_data["annotated"]
|
|
73
|
+
resp, err = None, None
|
|
138
74
|
|
|
139
|
-
|
|
140
|
-
|
|
75
|
+
# 调用API查询数据集
|
|
76
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
77
|
+
resp, err = await self.api_client.query_datasets(
|
|
78
|
+
current,
|
|
79
|
+
page_size,
|
|
80
|
+
keyword=keyword,
|
|
81
|
+
annotated=annotated,
|
|
82
|
+
request_api_key=request_api_key,
|
|
141
83
|
)
|
|
142
|
-
|
|
143
|
-
if err is not None:
|
|
84
|
+
if err:
|
|
144
85
|
return ErrResponse(err)
|
|
145
86
|
|
|
146
|
-
return OKResponse(
|
|
147
|
-
|
|
148
|
-
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/models")
|
|
149
|
-
async def commit_bizy_model(request):
|
|
150
|
-
sid = request.rel_url.query.get("clientId", "")
|
|
151
|
-
if not is_string_valid(sid):
|
|
152
|
-
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
153
|
-
|
|
154
|
-
json_data = await request.json()
|
|
155
|
-
|
|
156
|
-
# 校验name和type
|
|
157
|
-
err = check_str_param(json_data, "name", errnos.INVALID_NAME)
|
|
158
|
-
if err is not None:
|
|
159
|
-
return err
|
|
160
|
-
|
|
161
|
-
if "/" in json_data["name"]:
|
|
162
|
-
return ErrResponse(errnos.INVALID_NAME)
|
|
163
|
-
|
|
164
|
-
err = check_type(json_data)
|
|
165
|
-
if err is not None:
|
|
166
|
-
return err
|
|
167
|
-
|
|
168
|
-
# 校验versions
|
|
169
|
-
if "versions" not in json_data or not isinstance(
|
|
170
|
-
json_data["versions"], list
|
|
171
|
-
):
|
|
172
|
-
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
173
|
-
|
|
174
|
-
versions = json_data["versions"]
|
|
175
|
-
version_names = set()
|
|
176
|
-
|
|
177
|
-
for version in versions:
|
|
178
|
-
# 检查version是否重复
|
|
179
|
-
if version.get("version") in version_names:
|
|
180
|
-
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
181
|
-
|
|
182
|
-
# 检查version字段是否合法
|
|
183
|
-
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
184
|
-
"version"
|
|
185
|
-
):
|
|
186
|
-
return ErrResponse(errnos.INVALID_VERSION_NAME)
|
|
87
|
+
return OKResponse(resp)
|
|
187
88
|
|
|
188
|
-
|
|
89
|
+
@self.prompt_server.routes.get(
|
|
90
|
+
f"/{COMMUNITY_API}/datasets/{{dataset_id}}/detail"
|
|
91
|
+
)
|
|
92
|
+
async def get_dataset_detail(request):
|
|
93
|
+
# 获取路径参数中的数据集ID
|
|
94
|
+
dataset_id = int(request.match_info["dataset_id"])
|
|
189
95
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
err = errnos.INVALID_VERSION_FIELD.copy()
|
|
194
|
-
err.message = "Invalid version field: " + field
|
|
195
|
-
return ErrResponse(err)
|
|
96
|
+
# 检查dataset_id是否合法
|
|
97
|
+
if not dataset_id or dataset_id <= 0:
|
|
98
|
+
return ErrResponse(errnos.INVALID_DATASET_ID)
|
|
196
99
|
|
|
197
|
-
# 调用API
|
|
198
|
-
|
|
100
|
+
# 调用API获取数据集详情
|
|
101
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
102
|
+
resp, err = await self.api_client.get_dataset_detail(
|
|
103
|
+
dataset_id, request_api_key=request_api_key
|
|
104
|
+
)
|
|
199
105
|
if err:
|
|
200
106
|
return ErrResponse(err)
|
|
201
107
|
|
|
202
|
-
# print("resp------------------------------->", json_data, resp)
|
|
203
|
-
# 开启线程检查同步状态
|
|
204
|
-
threading.Thread(
|
|
205
|
-
target=self.check_sync_status,
|
|
206
|
-
args=(resp["id"], resp["version_ids"], sid),
|
|
207
|
-
daemon=True,
|
|
208
|
-
).start()
|
|
209
|
-
|
|
210
|
-
# enable refresh for lora
|
|
211
|
-
# TODO: enable refresh for other types
|
|
212
|
-
# bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
|
|
213
|
-
|
|
214
108
|
return OKResponse(resp)
|
|
215
109
|
|
|
216
110
|
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/models/query")
|
|
@@ -229,6 +123,7 @@ class BizyAirServer:
|
|
|
229
123
|
sort = json_data.get("sort", "")
|
|
230
124
|
resp, err = None, None
|
|
231
125
|
|
|
126
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
232
127
|
if mode in ["my", "my_fork"]:
|
|
233
128
|
# 调用API查询模型
|
|
234
129
|
resp, err = await self.api_client.query_models(
|
|
@@ -239,6 +134,7 @@ class BizyAirServer:
|
|
|
239
134
|
model_types=model_types,
|
|
240
135
|
base_models=base_models,
|
|
241
136
|
sort=sort,
|
|
137
|
+
request_api_key=request_api_key,
|
|
242
138
|
)
|
|
243
139
|
elif mode == "publicity":
|
|
244
140
|
# 调用API查询社区模型
|
|
@@ -249,6 +145,7 @@ class BizyAirServer:
|
|
|
249
145
|
model_types=model_types,
|
|
250
146
|
base_models=base_models,
|
|
251
147
|
sort=sort,
|
|
148
|
+
request_api_key=request_api_key,
|
|
252
149
|
)
|
|
253
150
|
elif mode == "official":
|
|
254
151
|
resp, err = await self.api_client.query_official_models(
|
|
@@ -258,6 +155,7 @@ class BizyAirServer:
|
|
|
258
155
|
model_types=model_types,
|
|
259
156
|
base_models=base_models,
|
|
260
157
|
sort=sort,
|
|
158
|
+
request_api_key=request_api_key,
|
|
261
159
|
)
|
|
262
160
|
if err:
|
|
263
161
|
return ErrResponse(err)
|
|
@@ -276,7 +174,10 @@ class BizyAirServer:
|
|
|
276
174
|
source = request.rel_url.query.get("source", "")
|
|
277
175
|
|
|
278
176
|
# 调用API获取模型详情
|
|
279
|
-
|
|
177
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
178
|
+
resp, err = await self.api_client.get_model_detail(
|
|
179
|
+
model_id, source, request_api_key=request_api_key
|
|
180
|
+
)
|
|
280
181
|
if err:
|
|
281
182
|
return ErrResponse(err)
|
|
282
183
|
|
|
@@ -292,7 +193,10 @@ class BizyAirServer:
|
|
|
292
193
|
return ErrResponse(errnos.INVALID_MODEL_ID)
|
|
293
194
|
|
|
294
195
|
# 调用API删除模型
|
|
295
|
-
|
|
196
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
197
|
+
resp, err = await self.api_client.delete_bizy_model(
|
|
198
|
+
model_id, request_api_key=request_api_key
|
|
199
|
+
)
|
|
296
200
|
if err:
|
|
297
201
|
return ErrResponse(err)
|
|
298
202
|
|
|
@@ -353,8 +257,13 @@ class BizyAirServer:
|
|
|
353
257
|
return ErrResponse(errnos.INVALID_VERSION_FIELD(field))
|
|
354
258
|
|
|
355
259
|
# 调用API更新模型
|
|
260
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
356
261
|
resp, err = await self.api_client.update_model(
|
|
357
|
-
model_id,
|
|
262
|
+
model_id,
|
|
263
|
+
json_data["name"],
|
|
264
|
+
json_data["type"],
|
|
265
|
+
versions,
|
|
266
|
+
request_api_key=request_api_key,
|
|
358
267
|
)
|
|
359
268
|
if err:
|
|
360
269
|
return ErrResponse(err)
|
|
@@ -379,7 +288,10 @@ class BizyAirServer:
|
|
|
379
288
|
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
380
289
|
|
|
381
290
|
# 调用API fork模型版本
|
|
382
|
-
|
|
291
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
292
|
+
_, err = await self.api_client.fork_model_version(
|
|
293
|
+
version_id, request_api_key=request_api_key
|
|
294
|
+
)
|
|
383
295
|
if err:
|
|
384
296
|
return ErrResponse(err)
|
|
385
297
|
|
|
@@ -400,7 +312,10 @@ class BizyAirServer:
|
|
|
400
312
|
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
401
313
|
|
|
402
314
|
# 调用API fork模型版本
|
|
403
|
-
|
|
315
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
316
|
+
_, err = await self.api_client.unfork_model_version(
|
|
317
|
+
version_id, request_api_key=request_api_key
|
|
318
|
+
)
|
|
404
319
|
if err:
|
|
405
320
|
return ErrResponse(err)
|
|
406
321
|
|
|
@@ -423,8 +338,9 @@ class BizyAirServer:
|
|
|
423
338
|
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
424
339
|
|
|
425
340
|
# 调用API like模型版本
|
|
341
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
426
342
|
_, err = await self.api_client.toggle_user_like(
|
|
427
|
-
"model_version", version_id
|
|
343
|
+
"model_version", version_id, request_api_key=request_api_key
|
|
428
344
|
)
|
|
429
345
|
if err:
|
|
430
346
|
return ErrResponse(err)
|
|
@@ -451,8 +367,11 @@ class BizyAirServer:
|
|
|
451
367
|
return ErrResponse(errnos.INVALID_SIGN)
|
|
452
368
|
|
|
453
369
|
# 获取上传凭证
|
|
370
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
454
371
|
url, err = await self.api_client.get_download_url(
|
|
455
|
-
sign=sign,
|
|
372
|
+
sign=sign,
|
|
373
|
+
model_version_id=model_version_id,
|
|
374
|
+
request_api_key=request_api_key,
|
|
456
375
|
)
|
|
457
376
|
if err:
|
|
458
377
|
return ErrResponse(err)
|
|
@@ -466,26 +385,12 @@ class BizyAirServer:
|
|
|
466
385
|
|
|
467
386
|
return OKResponse(json_content)
|
|
468
387
|
|
|
469
|
-
@self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
|
|
470
|
-
async def list_share_model_files(request):
|
|
471
|
-
shareId = request.match_info["shareId"]
|
|
472
|
-
if not is_string_valid(shareId):
|
|
473
|
-
return ErrResponse("INVALID_SHARE_ID")
|
|
474
|
-
payload = {}
|
|
475
|
-
query_params = ["type", "name", "ext_name"]
|
|
476
|
-
for param in query_params:
|
|
477
|
-
if param in request.rel_url.query and request.rel_url.query[param]:
|
|
478
|
-
payload[param] = request.rel_url.query[param]
|
|
479
|
-
model_files, err = await self.api_client.get_share_model_files(
|
|
480
|
-
shareId=shareId, payload=payload
|
|
481
|
-
)
|
|
482
|
-
if err is not None:
|
|
483
|
-
return ErrResponse(err)
|
|
484
|
-
return OKResponse(model_files)
|
|
485
|
-
|
|
486
388
|
@self.prompt_server.routes.get(f"/{API_PREFIX}/dict")
|
|
487
389
|
async def get_data_dict(request):
|
|
488
|
-
|
|
390
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
391
|
+
data_dict, err = await self.api_client.get_data_dict(
|
|
392
|
+
request_api_key=request_api_key
|
|
393
|
+
)
|
|
489
394
|
if err is not None:
|
|
490
395
|
return ErrResponse(err)
|
|
491
396
|
|
|
@@ -530,7 +435,10 @@ class BizyAirServer:
|
|
|
530
435
|
version_names.add(version.get("version"))
|
|
531
436
|
|
|
532
437
|
# 调用API提交数据集
|
|
533
|
-
|
|
438
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
439
|
+
resp, err = await self.api_client.commit_dataset(
|
|
440
|
+
payload=json_data, request_api_key=request_api_key
|
|
441
|
+
)
|
|
534
442
|
if err:
|
|
535
443
|
return ErrResponse(err)
|
|
536
444
|
|
|
@@ -594,8 +502,9 @@ class BizyAirServer:
|
|
|
594
502
|
version_names.add(version.get("version"))
|
|
595
503
|
|
|
596
504
|
# 调用API更新数据集
|
|
505
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
597
506
|
resp, err = await self.api_client.update_dataset(
|
|
598
|
-
dataset_id, json_data["name"], versions
|
|
507
|
+
dataset_id, json_data["name"], versions, request_api_key=request_api_key
|
|
599
508
|
)
|
|
600
509
|
if err:
|
|
601
510
|
return ErrResponse(err)
|
|
@@ -619,76 +528,289 @@ class BizyAirServer:
|
|
|
619
528
|
return ErrResponse(errnos.INVALID_DATASET_ID)
|
|
620
529
|
|
|
621
530
|
# 调用API删除数据集
|
|
622
|
-
|
|
531
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
532
|
+
resp, err = await self.api_client.delete_dataset(
|
|
533
|
+
dataset_id, request_api_key=request_api_key
|
|
534
|
+
)
|
|
623
535
|
if err:
|
|
624
536
|
return ErrResponse(err)
|
|
625
537
|
|
|
626
538
|
return OKResponse(resp)
|
|
627
539
|
|
|
628
|
-
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/
|
|
629
|
-
async def
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
annotated = None
|
|
634
|
-
if request.body_exists:
|
|
635
|
-
json_data = await request.json()
|
|
636
|
-
keyword = json_data["keyword"]
|
|
637
|
-
annotated = json_data["annotated"]
|
|
638
|
-
resp, err = None, None
|
|
540
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/share")
|
|
541
|
+
async def create_share(request):
|
|
542
|
+
json_data = await request.json()
|
|
543
|
+
if "biz_id" not in json_data:
|
|
544
|
+
return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
|
|
639
545
|
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
546
|
+
biz_id = int(json_data["biz_id"])
|
|
547
|
+
if not biz_id or biz_id <= 0:
|
|
548
|
+
return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
|
|
549
|
+
|
|
550
|
+
if "type" not in json_data:
|
|
551
|
+
return ErrResponse(errnos.INVALID_SHARE_TYPE)
|
|
552
|
+
if not is_string_valid(json_data["type"]) or (
|
|
553
|
+
json_data["type"] != "bizy_model_version"
|
|
554
|
+
):
|
|
555
|
+
return ErrResponse(errnos.INVALID_SHARE_TYPE)
|
|
556
|
+
|
|
557
|
+
# 调用API提交数据集
|
|
558
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
559
|
+
resp, err = await self.api_client.create_share(
|
|
560
|
+
payload=json_data, request_api_key=request_api_key
|
|
646
561
|
)
|
|
647
562
|
if err:
|
|
648
563
|
return ErrResponse(err)
|
|
649
564
|
|
|
650
565
|
return OKResponse(resp)
|
|
651
566
|
|
|
652
|
-
@self.prompt_server.routes.get(
|
|
653
|
-
|
|
654
|
-
)
|
|
655
|
-
async def get_dataset_detail(request):
|
|
567
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
|
|
568
|
+
async def get_model_version_detail(request):
|
|
656
569
|
# 获取路径参数中的数据集ID
|
|
657
|
-
|
|
570
|
+
version_id = int(request.match_info["version_id"])
|
|
658
571
|
|
|
659
|
-
# 检查
|
|
660
|
-
if not
|
|
661
|
-
return ErrResponse(errnos.
|
|
572
|
+
# 检查version_id是否合法
|
|
573
|
+
if not version_id or version_id <= 0:
|
|
574
|
+
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
662
575
|
|
|
663
576
|
# 调用API获取数据集详情
|
|
664
|
-
|
|
577
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
578
|
+
resp, err = await self.api_client.get_model_version_detail(
|
|
579
|
+
version_id, request_api_key=request_api_key
|
|
580
|
+
)
|
|
665
581
|
if err:
|
|
666
582
|
return ErrResponse(err)
|
|
667
583
|
|
|
668
584
|
return OKResponse(resp)
|
|
669
585
|
|
|
670
|
-
@self.prompt_server.routes.
|
|
671
|
-
async def
|
|
586
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/sign")
|
|
587
|
+
async def sign(request):
|
|
588
|
+
sha256sum = request.rel_url.query.get("sha256sum")
|
|
589
|
+
if not is_string_valid(sha256sum):
|
|
590
|
+
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
591
|
+
|
|
592
|
+
type = request.rel_url.query.get("type")
|
|
593
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
594
|
+
sign_data, err = await self.api_client.sign(
|
|
595
|
+
sha256sum, type, request_api_key=request_api_key
|
|
596
|
+
)
|
|
597
|
+
if err is not None:
|
|
598
|
+
return ErrResponse(err)
|
|
599
|
+
|
|
600
|
+
return OKResponse(sign_data)
|
|
601
|
+
|
|
602
|
+
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/upload_token")
|
|
603
|
+
async def upload_token(request):
|
|
604
|
+
filename = request.rel_url.query.get("filename", "")
|
|
605
|
+
# 校验filename
|
|
606
|
+
if not is_string_valid(filename):
|
|
607
|
+
return ErrResponse(errnos.INVALID_FILENAME)
|
|
608
|
+
|
|
609
|
+
filename = urllib.parse.quote(filename)
|
|
610
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
611
|
+
token, err = await self.api_client.get_upload_token(
|
|
612
|
+
filename=filename, request_api_key=request_api_key
|
|
613
|
+
)
|
|
614
|
+
if err is not None:
|
|
615
|
+
return ErrResponse(err)
|
|
616
|
+
return OKResponse(token)
|
|
617
|
+
|
|
618
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/commit_file")
|
|
619
|
+
async def commit_file(request):
|
|
672
620
|
json_data = await request.json()
|
|
673
|
-
if "biz_id" not in json_data:
|
|
674
|
-
return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
|
|
675
621
|
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
622
|
+
if "sha256sum" not in json_data:
|
|
623
|
+
return ErrResponse(errnos.EMPTY_SHA256SUM)
|
|
624
|
+
sha256sum = json_data.get("sha256sum")
|
|
625
|
+
|
|
626
|
+
if "object_key" not in json_data:
|
|
627
|
+
return ErrResponse(errnos.INVALID_OBJECT_KEY)
|
|
628
|
+
object_key = json_data.get("object_key")
|
|
679
629
|
|
|
680
630
|
if "type" not in json_data:
|
|
681
|
-
return ErrResponse(errnos.
|
|
682
|
-
|
|
683
|
-
|
|
631
|
+
return ErrResponse(errnos.INVALID_TYPE)
|
|
632
|
+
type = json_data.get("type")
|
|
633
|
+
|
|
634
|
+
md5_hash = ""
|
|
635
|
+
if "md5_hash" in json_data:
|
|
636
|
+
md5_hash = json_data.get("md5_hash")
|
|
637
|
+
|
|
638
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
639
|
+
commit_data, err = await self.api_client.commit_file(
|
|
640
|
+
signature=sha256sum,
|
|
641
|
+
object_key=object_key,
|
|
642
|
+
md5_hash=md5_hash,
|
|
643
|
+
type=type,
|
|
644
|
+
request_api_key=request_api_key,
|
|
645
|
+
)
|
|
646
|
+
# print("commit_data", commit_data)
|
|
647
|
+
if err is not None:
|
|
648
|
+
return ErrResponse(err)
|
|
649
|
+
|
|
650
|
+
return OKResponse(None)
|
|
651
|
+
|
|
652
|
+
# 由于历史原因,前端请求body里有apikey所以是post
|
|
653
|
+
@self.prompt_server.routes.post(f"/{API_PREFIX}/get_silicon_cloud_llm_models")
|
|
654
|
+
async def get_silicon_cloud_llm_models_endpoint(request):
|
|
655
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
656
|
+
all_models = await self.api_client.fetch_all_llm_models(
|
|
657
|
+
request_api_key=request_api_key
|
|
658
|
+
)
|
|
659
|
+
llm_models = [model for model in all_models if "vl" not in model.lower()]
|
|
660
|
+
llm_models.append("No LLM Enhancement")
|
|
661
|
+
return aiohttp.web.json_response(llm_models)
|
|
662
|
+
|
|
663
|
+
# 由于历史原因,前端请求body里有apikey所以是post
|
|
664
|
+
@self.prompt_server.routes.post(f"/{API_PREFIX}/get_silicon_cloud_vlm_models")
|
|
665
|
+
async def get_silicon_cloud_vlm_models_endpoint(request):
|
|
666
|
+
request_api_key = _get_request_api_key(request.headers)
|
|
667
|
+
all_models = await self.api_client.fetch_all_llm_models(
|
|
668
|
+
request_api_key=request_api_key
|
|
669
|
+
)
|
|
670
|
+
vlm_models = [model for model in all_models if "vl" in model.lower()]
|
|
671
|
+
vlm_models.append("No VLM Enhancement")
|
|
672
|
+
return aiohttp.web.json_response(vlm_models)
|
|
673
|
+
|
|
674
|
+
@self.prompt_server.routes.post(f"/{API_PREFIX}/validate_prompt")
|
|
675
|
+
async def validate_prompt(request):
|
|
676
|
+
json_data = await request.json()
|
|
677
|
+
if not "prompt" in json_data:
|
|
678
|
+
return ErrResponse(errnos.MISSING_PROMPT)
|
|
679
|
+
|
|
680
|
+
valid = execution.validate_prompt(json_data["prompt"])
|
|
681
|
+
if valid[0]:
|
|
682
|
+
return OKResponse(None)
|
|
683
|
+
else:
|
|
684
|
+
err = errnos.INVALID_PROMPT.copy()
|
|
685
|
+
err.data = {"error": valid[1], "node_errors": valid[3]}
|
|
686
|
+
return ErrResponse(err)
|
|
687
|
+
|
|
688
|
+
# 服务器模式下以下路径不会注册
|
|
689
|
+
if BIZYAIR_SERVER_MODE:
|
|
690
|
+
return
|
|
691
|
+
|
|
692
|
+
@self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
|
|
693
|
+
async def list_share_model_files(request):
|
|
694
|
+
shareId = request.match_info["shareId"]
|
|
695
|
+
if not is_string_valid(shareId):
|
|
696
|
+
return ErrResponse("INVALID_SHARE_ID")
|
|
697
|
+
payload = {}
|
|
698
|
+
query_params = ["type", "name", "ext_name"]
|
|
699
|
+
for param in query_params:
|
|
700
|
+
if param in request.rel_url.query and request.rel_url.query[param]:
|
|
701
|
+
payload[param] = request.rel_url.query[param]
|
|
702
|
+
model_files, err = await self.api_client.get_share_model_files(
|
|
703
|
+
shareId=shareId, payload=payload
|
|
704
|
+
)
|
|
705
|
+
if err is not None:
|
|
706
|
+
return ErrResponse(err)
|
|
707
|
+
return OKResponse(model_files)
|
|
708
|
+
|
|
709
|
+
@self.prompt_server.routes.get(f"/{USER_API}/info")
|
|
710
|
+
async def user_info(request):
|
|
711
|
+
info, err = await self.api_client.user_info()
|
|
712
|
+
if err is not None:
|
|
713
|
+
return ErrResponse(err)
|
|
714
|
+
|
|
715
|
+
return OKResponse(info)
|
|
716
|
+
|
|
717
|
+
@self.prompt_server.routes.get(f"/{API_PREFIX}/ws")
|
|
718
|
+
async def websocket_handler(request):
|
|
719
|
+
ws = aiohttp.web.WebSocketResponse()
|
|
720
|
+
await ws.prepare(request)
|
|
721
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
722
|
+
if sid:
|
|
723
|
+
# Reusing existing session, remove old
|
|
724
|
+
self.sockets.pop(sid, None)
|
|
725
|
+
else:
|
|
726
|
+
sid = uuid.uuid4().hex
|
|
727
|
+
|
|
728
|
+
self.sockets[sid] = ws
|
|
729
|
+
|
|
730
|
+
try:
|
|
731
|
+
# Send initial state to the new client
|
|
732
|
+
await self.send_json(
|
|
733
|
+
event="status", data={"status": "connected"}, sid=sid
|
|
734
|
+
)
|
|
735
|
+
|
|
736
|
+
async for msg in ws:
|
|
737
|
+
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
738
|
+
if msg.data == "ping":
|
|
739
|
+
await ws.send_str("pong")
|
|
740
|
+
if msg.type == aiohttp.WSMsgType.ERROR:
|
|
741
|
+
logging.warning(
|
|
742
|
+
"ws connection closed with exception %s" % ws.exception()
|
|
743
|
+
)
|
|
744
|
+
finally:
|
|
745
|
+
self.sockets.pop(sid, None)
|
|
746
|
+
return ws
|
|
747
|
+
|
|
748
|
+
@self.prompt_server.routes.post(f"/{COMMUNITY_API}/models")
|
|
749
|
+
async def commit_bizy_model(request):
|
|
750
|
+
sid = request.rel_url.query.get("clientId", "")
|
|
751
|
+
if not is_string_valid(sid):
|
|
752
|
+
return ErrResponse(errnos.INVALID_CLIENT_ID)
|
|
753
|
+
|
|
754
|
+
json_data = await request.json()
|
|
755
|
+
|
|
756
|
+
# 校验name和type
|
|
757
|
+
err = check_str_param(json_data, "name", errnos.INVALID_NAME)
|
|
758
|
+
if err is not None:
|
|
759
|
+
return err
|
|
760
|
+
|
|
761
|
+
if "/" in json_data["name"]:
|
|
762
|
+
return ErrResponse(errnos.INVALID_NAME)
|
|
763
|
+
|
|
764
|
+
err = check_type(json_data)
|
|
765
|
+
if err is not None:
|
|
766
|
+
return err
|
|
767
|
+
|
|
768
|
+
# 校验versions
|
|
769
|
+
if "versions" not in json_data or not isinstance(
|
|
770
|
+
json_data["versions"], list
|
|
684
771
|
):
|
|
685
|
-
return ErrResponse(errnos.
|
|
772
|
+
return ErrResponse(errnos.INVALID_VERSIONS)
|
|
686
773
|
|
|
687
|
-
|
|
688
|
-
|
|
774
|
+
versions = json_data["versions"]
|
|
775
|
+
version_names = set()
|
|
776
|
+
|
|
777
|
+
for version in versions:
|
|
778
|
+
# 检查version是否重复
|
|
779
|
+
if version.get("version") in version_names:
|
|
780
|
+
return ErrResponse(errnos.DUPLICATE_VERSION)
|
|
781
|
+
|
|
782
|
+
# 检查version字段是否合法
|
|
783
|
+
if not is_string_valid(version.get("version")) or "/" in version.get(
|
|
784
|
+
"version"
|
|
785
|
+
):
|
|
786
|
+
return ErrResponse(errnos.INVALID_VERSION_NAME)
|
|
787
|
+
|
|
788
|
+
version_names.add(version.get("version"))
|
|
789
|
+
|
|
790
|
+
# 检查base_model, path和sign是否有值
|
|
791
|
+
for field in ["base_model", "path", "sign"]:
|
|
792
|
+
if not is_string_valid(version.get(field)):
|
|
793
|
+
err = errnos.INVALID_VERSION_FIELD.copy()
|
|
794
|
+
err.message = "Invalid version field: " + field
|
|
795
|
+
return ErrResponse(err)
|
|
796
|
+
|
|
797
|
+
# 调用API提交模型
|
|
798
|
+
resp, err = await self.api_client.commit_bizy_model(payload=json_data)
|
|
689
799
|
if err:
|
|
690
800
|
return ErrResponse(err)
|
|
691
801
|
|
|
802
|
+
# print("resp------------------------------->", json_data, resp)
|
|
803
|
+
# 开启线程检查同步状态
|
|
804
|
+
threading.Thread(
|
|
805
|
+
target=self.check_sync_status,
|
|
806
|
+
args=(resp["id"], resp["version_ids"], sid),
|
|
807
|
+
daemon=True,
|
|
808
|
+
).start()
|
|
809
|
+
|
|
810
|
+
# enable refresh for lora
|
|
811
|
+
# TODO: enable refresh for other types
|
|
812
|
+
# bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
|
|
813
|
+
|
|
692
814
|
return OKResponse(resp)
|
|
693
815
|
|
|
694
816
|
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/share/{{code}}")
|
|
@@ -707,22 +829,6 @@ class BizyAirServer:
|
|
|
707
829
|
|
|
708
830
|
return OKResponse(resp)
|
|
709
831
|
|
|
710
|
-
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
|
|
711
|
-
async def get_model_version_detail(request):
|
|
712
|
-
# 获取路径参数中的数据集ID
|
|
713
|
-
version_id = int(request.match_info["version_id"])
|
|
714
|
-
|
|
715
|
-
# 检查version_id是否合法
|
|
716
|
-
if not version_id or version_id <= 0:
|
|
717
|
-
return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
|
|
718
|
-
|
|
719
|
-
# 调用API获取数据集详情
|
|
720
|
-
resp, err = await self.api_client.get_model_version_detail(version_id)
|
|
721
|
-
if err:
|
|
722
|
-
return ErrResponse(err)
|
|
723
|
-
|
|
724
|
-
return OKResponse(resp)
|
|
725
|
-
|
|
726
832
|
@self.prompt_server.routes.get(f"/{COMMUNITY_API}/notifications/unread_count")
|
|
727
833
|
async def get_notification_unread_count(request):
|
|
728
834
|
# 获取当前用户的未读消息数量
|
|
@@ -744,7 +850,11 @@ class BizyAirServer:
|
|
|
744
850
|
last_broadcast_id = int(request.rel_url.query.get("last_broadcast_id", "0"))
|
|
745
851
|
|
|
746
852
|
resp, err = await self.api_client.fetch_notifications(
|
|
747
|
-
page_size,
|
|
853
|
+
page_size,
|
|
854
|
+
last_pm_id,
|
|
855
|
+
last_broadcast_id,
|
|
856
|
+
types,
|
|
857
|
+
read_status,
|
|
748
858
|
)
|
|
749
859
|
if err:
|
|
750
860
|
return ErrResponse(err)
|
|
@@ -913,7 +1023,9 @@ class BizyAirServer:
|
|
|
913
1023
|
if not year:
|
|
914
1024
|
return ErrResponse(errnos.INVALID_YEAR_PARAM)
|
|
915
1025
|
|
|
916
|
-
resp, err = await self.api_client.get_year_cost(
|
|
1026
|
+
resp, err = await self.api_client.get_year_cost(
|
|
1027
|
+
year=year, query_api_key=api_key
|
|
1028
|
+
)
|
|
917
1029
|
if err is not None:
|
|
918
1030
|
return ErrResponse(err)
|
|
919
1031
|
|
|
@@ -928,7 +1040,7 @@ class BizyAirServer:
|
|
|
928
1040
|
return ErrResponse(errnos.INVALID_MONTH_PARAM)
|
|
929
1041
|
|
|
930
1042
|
resp, err = await self.api_client.get_month_cost(
|
|
931
|
-
month=month,
|
|
1043
|
+
month=month, query_api_key=api_key
|
|
932
1044
|
)
|
|
933
1045
|
if err is not None:
|
|
934
1046
|
return ErrResponse(err)
|
|
@@ -943,7 +1055,9 @@ class BizyAirServer:
|
|
|
943
1055
|
if not day:
|
|
944
1056
|
return ErrResponse(errnos.INVALID_DAY_PARAM)
|
|
945
1057
|
|
|
946
|
-
resp, err = await self.api_client.get_day_cost(
|
|
1058
|
+
resp, err = await self.api_client.get_day_cost(
|
|
1059
|
+
day=day, query_api_key=api_key
|
|
1060
|
+
)
|
|
947
1061
|
if err is not None:
|
|
948
1062
|
return ErrResponse(err)
|
|
949
1063
|
|
|
@@ -954,7 +1068,7 @@ class BizyAirServer:
|
|
|
954
1068
|
|
|
955
1069
|
api_key = request.rel_url.query.get("api_key", "")
|
|
956
1070
|
|
|
957
|
-
resp, err = await self.api_client.get_recent_cost(
|
|
1071
|
+
resp, err = await self.api_client.get_recent_cost(query_api_key=api_key)
|
|
958
1072
|
if err is not None:
|
|
959
1073
|
return ErrResponse(err)
|
|
960
1074
|
|