bizyengine 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. bizyengine/__init__.py +35 -0
  2. bizyengine/bizy_server/__init__.py +7 -0
  3. bizyengine/bizy_server/api_client.py +763 -0
  4. bizyengine/bizy_server/errno.py +122 -0
  5. bizyengine/bizy_server/error_handler.py +3 -0
  6. bizyengine/bizy_server/execution.py +55 -0
  7. bizyengine/bizy_server/resp.py +24 -0
  8. bizyengine/bizy_server/server.py +898 -0
  9. bizyengine/bizy_server/utils.py +93 -0
  10. bizyengine/bizyair_extras/__init__.py +24 -0
  11. bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
  12. bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
  13. bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
  14. bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
  15. bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
  16. bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
  17. bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
  18. bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
  19. bizyengine/bizyair_extras/nodes_dataset.py +99 -0
  20. bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
  21. bizyengine/bizyair_extras/nodes_flux.py +69 -0
  22. bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
  23. bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
  24. bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
  25. bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
  26. bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
  27. bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
  28. bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
  29. bizyengine/bizyair_extras/nodes_sd3.py +52 -0
  30. bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
  31. bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
  32. bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
  33. bizyengine/bizyair_extras/nodes_trellis.py +199 -0
  34. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
  35. bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
  36. bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
  37. bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
  38. bizyengine/core/__init__.py +8 -0
  39. bizyengine/core/commands/__init__.py +1 -0
  40. bizyengine/core/commands/base.py +27 -0
  41. bizyengine/core/commands/invoker.py +4 -0
  42. bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
  43. bizyengine/core/commands/processors/prompt_processor.py +123 -0
  44. bizyengine/core/commands/servers/model_server.py +0 -0
  45. bizyengine/core/commands/servers/prompt_server.py +234 -0
  46. bizyengine/core/common/__init__.py +8 -0
  47. bizyengine/core/common/caching.py +198 -0
  48. bizyengine/core/common/client.py +262 -0
  49. bizyengine/core/common/env_var.py +101 -0
  50. bizyengine/core/common/utils.py +93 -0
  51. bizyengine/core/configs/conf.py +112 -0
  52. bizyengine/core/configs/models.json +101 -0
  53. bizyengine/core/configs/models.yaml +329 -0
  54. bizyengine/core/data_types.py +20 -0
  55. bizyengine/core/image_utils.py +288 -0
  56. bizyengine/core/nodes_base.py +159 -0
  57. bizyengine/core/nodes_io.py +97 -0
  58. bizyengine/core/path_utils/__init__.py +9 -0
  59. bizyengine/core/path_utils/path_manager.py +276 -0
  60. bizyengine/core/path_utils/utils.py +34 -0
  61. bizyengine/misc/__init__.py +0 -0
  62. bizyengine/misc/auth.py +83 -0
  63. bizyengine/misc/llm.py +431 -0
  64. bizyengine/misc/mzkolors.py +93 -0
  65. bizyengine/misc/nodes.py +1208 -0
  66. bizyengine/misc/nodes_controlnet_aux.py +491 -0
  67. bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
  68. bizyengine/misc/route_sam.py +60 -0
  69. bizyengine/misc/segment_anything.py +276 -0
  70. bizyengine/misc/supernode.py +182 -0
  71. bizyengine/misc/utils.py +218 -0
  72. bizyengine/version.txt +1 -0
  73. bizyengine-0.4.2.dist-info/METADATA +12 -0
  74. bizyengine-0.4.2.dist-info/RECORD +76 -0
  75. bizyengine-0.4.2.dist-info/WHEEL +5 -0
  76. bizyengine-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,898 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import threading
5
+ import time
6
+ import urllib.parse
7
+ import uuid
8
+
9
+ import aiohttp
10
+ from server import PromptServer
11
+
12
+ from .api_client import APIClient
13
+ from .errno import ErrorNo, errnos
14
+ from .error_handler import ErrorHandler
15
+ from .resp import ErrResponse, OKResponse
16
+ from .utils import base_model_types, check_str_param, check_type, is_string_valid, types
17
+
18
# Route path prefixes, stored without the leading "/" (handlers prepend it).
API_PREFIX = "bizyair"
COMMUNITY_API = "/".join((API_PREFIX, "community"))
MODEL_HOST_API = "/".join((API_PREFIX, "modelhost"))
USER_API = "/".join((API_PREFIX, "user"))

# NOTE(review): this runs at import time and, if the root logger has no
# handlers yet, configures the host process's root logger at DEBUG level.
# Consider a module-level logger instead of global configuration.
logging.basicConfig(level=logging.DEBUG)
24
+
25
+
26
class BizyAirServer:
    """BizyAir request handlers mounted on ComfyUI's ``PromptServer``.

    On construction the instance registers itself as ``BizyAirServer.instance``
    and adds ``/bizyair/...`` HTTP routes that validate requests and forward
    them to ``APIClient``. It also adds a ``/bizyair/ws`` websocket used to push
    asynchronous events (``status``, ``synced``, ``error``) to browser clients.
    """

    def __init__(self):
        BizyAirServer.instance = self
        self.api_client = APIClient()
        self.error_handler = ErrorHandler()
        self.prompt_server = PromptServer.instance
        # clientId -> open websocket for that client.
        self.sockets = dict()
        # Captured so worker threads can schedule coroutines on this loop
        # (see send_sync and _poll_sync_status).
        self.loop = asyncio.get_event_loop()

        self.setup_routes()

    # ------------------------------------------------------------------
    # Request-validation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _positive_int(raw):
        """Return ``int(raw)`` if it is a strictly positive integer, else None.

        Malformed ids then map to the endpoint's INVALID_* error instead of an
        unhandled ``ValueError``/``TypeError`` (HTTP 500).
        """
        try:
            value = int(raw)
        except (TypeError, ValueError):
            return None
        return value if value > 0 else None

    @staticmethod
    def _query_int(request, name, default):
        """Read integer query parameter ``name``; use ``default`` if absent or malformed."""
        try:
            return int(request.rel_url.query.get(name, default))
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _validate_name(json_data, invalid_err):
        """Return an error response if ``name`` is missing/invalid or contains "/".

        Returns None when the name is acceptable.
        """
        err = check_str_param(json_data, "name", invalid_err)
        if err is not None:
            return err
        if "/" in json_data["name"]:
            return ErrResponse(invalid_err)
        return None

    @staticmethod
    def _version_field_error(field):
        """Build an INVALID_VERSION_FIELD error that names the offending ``field``."""
        # Copy first so the shared errnos constant is not mutated per request.
        err = errnos.INVALID_VERSION_FIELD.copy()
        err.message = "Invalid version field: " + field
        return err

    @classmethod
    def _validate_versions(cls, json_data, invalid_name_err, required_fields=()):
        """Validate the ``versions`` list of a model/dataset payload.

        Each entry must be a dict whose ``version`` is a valid string, contains
        no "/", and is unique within the list. Every key in ``required_fields``
        must also hold a valid string.

        Returns:
            None if valid, otherwise the error response the handler should return.
        """
        versions = json_data.get("versions")
        if not isinstance(versions, list):
            return ErrResponse(errnos.INVALID_VERSIONS)

        seen = set()
        for version in versions:
            if not isinstance(version, dict):
                return ErrResponse(errnos.INVALID_VERSIONS)
            name = version.get("version")
            if name in seen:
                return ErrResponse(errnos.DUPLICATE_VERSION)
            if not is_string_valid(name) or "/" in name:
                return ErrResponse(invalid_name_err)
            seen.add(name)

            for field in required_fields:
                if not is_string_valid(version.get(field)):
                    return ErrResponse(cls._version_field_error(field))
        return None

    def _start_sync_watcher(self, target, resp, sid):
        """Run ``target(resp["id"], resp["version_ids"], sid)`` on a daemon thread."""
        threading.Thread(
            target=target,
            args=(resp["id"], resp["version_ids"], sid),
            daemon=True,
        ).start()

    # ------------------------------------------------------------------
    # Routes
    # ------------------------------------------------------------------

    def setup_routes(self):
        """Register all BizyAir HTTP and websocket handlers on the PromptServer."""
        routes = self.prompt_server.routes

        @routes.get(f"/{COMMUNITY_API}/model_types")
        async def list_model_types(request):
            return OKResponse(types())

        @routes.get(f"/{COMMUNITY_API}/base_model_types")
        async def list_base_model_types(request):
            return OKResponse(base_model_types())

        @routes.get(f"/{USER_API}/info")
        async def user_info(request):
            info, err = await self.api_client.user_info()
            if err is not None:
                return ErrResponse(err)
            return OKResponse(info)

        @routes.get(f"/{API_PREFIX}/ws")
        async def websocket_handler(request):
            ws = aiohttp.web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get("clientId", "")
            if sid:
                # Reusing an existing session: drop the old registry entry.
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client.
                await self.send_json(
                    event="status", data={"status": "connected"}, sid=sid
                )

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        if msg.data == "ping":
                            await ws.send_str("pong")
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning(
                            "ws connection closed with exception %s", ws.exception()
                        )
            finally:
                # A reconnect with the same clientId may already have replaced
                # this socket; only unregister the entry if it is still ours.
                if self.sockets.get(sid) is ws:
                    self.sockets.pop(sid, None)
            return ws

        @routes.get(f"/{COMMUNITY_API}/sign")
        async def sign(request):
            sha256sum = request.rel_url.query.get("sha256sum")
            if not is_string_valid(sha256sum):
                return ErrResponse(errnos.EMPTY_SHA256SUM)

            model_type = request.rel_url.query.get("type")

            sign_data, err = await self.api_client.sign(sha256sum, model_type)
            if err is not None:
                return ErrResponse(err)
            return OKResponse(sign_data)

        @routes.get(f"/{COMMUNITY_API}/upload_token")
        async def upload_token(request):
            filename = request.rel_url.query.get("filename", "")
            if not is_string_valid(filename):
                return ErrResponse(errnos.INVALID_FILENAME)

            filename = urllib.parse.quote(filename)
            token, err = await self.api_client.get_upload_token(filename=filename)
            if err is not None:
                return ErrResponse(err)
            return OKResponse(token)

        @routes.post(f"/{COMMUNITY_API}/commit_file")
        async def commit_file(request):
            json_data = await request.json()

            if "sha256sum" not in json_data:
                return ErrResponse(errnos.EMPTY_SHA256SUM)
            if "object_key" not in json_data:
                return ErrResponse(errnos.INVALID_OBJECT_KEY)
            if "type" not in json_data:
                return ErrResponse(errnos.INVALID_TYPE)

            _, err = await self.api_client.commit_file(
                signature=json_data["sha256sum"],
                object_key=json_data["object_key"],
                md5_hash=json_data.get("md5_hash", ""),
                type=json_data["type"],
            )
            if err is not None:
                return ErrResponse(err)
            return OKResponse(None)

        @routes.post(f"/{COMMUNITY_API}/models")
        async def commit_bizy_model(request):
            sid = request.rel_url.query.get("clientId", "")
            if not is_string_valid(sid):
                return ErrResponse(errnos.INVALID_CLIENT_ID)

            json_data = await request.json()

            err = self._validate_name(json_data, errnos.INVALID_NAME)
            if err is not None:
                return err
            err = check_type(json_data)
            if err is not None:
                return err
            err = self._validate_versions(
                json_data,
                errnos.INVALID_VERSION_NAME,
                required_fields=("base_model", "path", "sign"),
            )
            if err is not None:
                return err

            resp, err = await self.api_client.commit_bizy_model(payload=json_data)
            if err:
                return ErrResponse(err)

            # Notify this client over the websocket once each version has synced.
            self._start_sync_watcher(self.check_sync_status, resp, sid)

            # TODO: enable model-list refresh (e.g. path_manager
            # enable_refresh_options) for loras and other types.
            return OKResponse(resp)

        @routes.post(f"/{COMMUNITY_API}/models/query")
        async def query_my_models(request):
            mode = request.rel_url.query.get("mode", "")
            if mode not in ("my", "my_fork", "publicity", "official"):
                return ErrResponse(errnos.INVALID_QUERY_MODE)

            current = self._query_int(request, "current", 1)
            page_size = self._query_int(request, "page_size", 10)
            json_data = await request.json()
            filters = dict(
                keyword=json_data.get("keyword", ""),
                model_types=json_data.get("model_types", []),
                base_models=json_data.get("base_models", []),
                sort=json_data.get("sort", ""),
            )

            if mode in ("my", "my_fork"):
                resp, err = await self.api_client.query_models(
                    mode, current, page_size, **filters
                )
            elif mode == "publicity":
                resp, err = await self.api_client.query_community_models(
                    current, page_size, **filters
                )
            else:  # "official"
                resp, err = await self.api_client.query_official_models(
                    current, page_size, **filters
                )
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/models/{{model_id}}/detail")
        async def get_model_detail(request):
            model_id = self._positive_int(request.match_info["model_id"])
            if model_id is None:
                return ErrResponse(errnos.INVALID_MODEL_ID)

            source = request.rel_url.query.get("source", "")

            resp, err = await self.api_client.get_model_detail(model_id, source)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.delete(f"/{COMMUNITY_API}/models/{{model_id}}")
        async def delete_model(request):
            model_id = self._positive_int(request.match_info["model_id"])
            if model_id is None:
                return ErrResponse(errnos.INVALID_MODEL_ID)

            resp, err = await self.api_client.delete_bizy_model(model_id)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.put(f"/{COMMUNITY_API}/models/{{model_id}}")
        async def update_model(request):
            sid = request.rel_url.query.get("clientId", "")
            if not is_string_valid(sid):
                return ErrResponse(errnos.INVALID_CLIENT_ID)

            model_id = self._positive_int(request.match_info["model_id"])
            if model_id is None:
                return ErrResponse(errnos.INVALID_MODEL_ID)

            json_data = await request.json()

            err = self._validate_name(json_data, errnos.INVALID_NAME)
            if err is not None:
                return err
            err = check_type(json_data)
            if err is not None:
                return err
            err = self._validate_versions(
                json_data,
                errnos.INVALID_VERSION_NAME,
                required_fields=("base_model", "path", "sign"),
            )
            if err is not None:
                return err

            resp, err = await self.api_client.update_model(
                model_id, json_data["name"], json_data["type"], json_data["versions"]
            )
            if err:
                return ErrResponse(err)

            # Pass sid so sync events go to the requesting client, not everyone.
            self._start_sync_watcher(self.check_sync_status, resp, sid)
            return OKResponse(None)

        @routes.post(f"/{COMMUNITY_API}/models/fork/{{model_version_id}}")
        async def fork_model_version(request):
            try:
                version_id = request.match_info["model_version_id"]
                if not version_id:
                    return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)

                _, err = await self.api_client.fork_model_version(version_id)
                if err:
                    return ErrResponse(err)
                return OKResponse(None)
            except Exception as e:
                print(f"\033[31m[BizyAir]\033[0m Fail to fork model version: {str(e)}")
                return ErrResponse(errnos.FORK_MODEL_VERSION)

        @routes.delete(f"/{COMMUNITY_API}/models/fork/{{model_version_id}}")
        async def unfork_model_version(request):
            try:
                version_id = request.match_info["model_version_id"]
                if not version_id:
                    return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)

                _, err = await self.api_client.unfork_model_version(version_id)
                if err:
                    return ErrResponse(err)
                return OKResponse(None)
            except Exception as e:
                print(
                    f"\033[31m[BizyAir]\033[0m Fail to unfork model version: {str(e)}"
                )
                return ErrResponse(errnos.FORK_MODEL_VERSION)

        @routes.post(f"/{COMMUNITY_API}/models/like/{{model_version_id}}")
        async def like_model_version(request):
            try:
                version_id = request.match_info["model_version_id"]
                if not version_id:
                    return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)

                _, err = await self.api_client.toggle_user_like(
                    "model_version", version_id
                )
                if err:
                    return ErrResponse(err)
                return OKResponse(None)
            except Exception as e:
                print(
                    f"\033[31m[BizyAir]\033[0m Fail to toggle like model version: {str(e)}"
                )
                return ErrResponse(errnos.TOGGLE_USER_LIKE)

        @routes.get(
            f"/{COMMUNITY_API}/models/versions/{{model_version_id}}/workflow_json/{{sign}}"
        )
        async def get_workflow_json(request):
            model_version_id = self._positive_int(
                request.match_info["model_version_id"]
            )
            if model_version_id is None:
                return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)

            sign = str(request.match_info["sign"])
            if not sign:
                return ErrResponse(errnos.INVALID_SIGN)

            # Resolve a download URL for the workflow JSON file.
            url, err = await self.api_client.get_download_url(
                sign=sign, model_version_id=model_version_id
            )
            if err:
                return ErrResponse(err)

            # Fetch and parse the file contents.
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url) as response:
                        if response.status != 200:
                            return ErrResponse(errnos.FAILED_TO_FETCH_WORKFLOW_JSON)
                        # content_type=None: the storage backend may not label
                        # the file as application/json.
                        json_content = await response.json(content_type=None)
            except (aiohttp.ClientError, ValueError):
                return ErrResponse(errnos.FAILED_TO_FETCH_WORKFLOW_JSON)

            return OKResponse(json_content)

        @routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
        async def list_share_model_files(request):
            share_id = request.match_info["shareId"]
            if not is_string_valid(share_id):
                # ErrResponse needs an errnos entry, not a bare string.
                # NOTE(review): no INVALID_SHARE_ID errno is visible here;
                # INVALID_SHARE_CODE is the closest existing one — confirm.
                return ErrResponse(errnos.INVALID_SHARE_CODE)

            query = request.rel_url.query
            payload = {
                param: query[param]
                for param in ("type", "name", "ext_name")
                if query.get(param)
            }
            model_files, err = await self.api_client.get_share_model_files(
                shareId=share_id, payload=payload
            )
            if err is not None:
                return ErrResponse(err)
            return OKResponse(model_files)

        @routes.get(f"/{API_PREFIX}/dict")
        async def get_data_dict(request):
            data_dict, err = await self.api_client.get_data_dict()
            if err is not None:
                return ErrResponse(err)
            return OKResponse(data_dict)

        @routes.post(f"/{COMMUNITY_API}/datasets")
        async def commit_dataset(request):
            sid = request.rel_url.query.get("clientId", "")
            if not is_string_valid(sid):
                return ErrResponse(errnos.INVALID_CLIENT_ID)

            json_data = await request.json()

            err = self._validate_name(json_data, errnos.INVALID_DATASET_NAME)
            if err is not None:
                return err
            err = self._validate_versions(json_data, errnos.INVALID_DATASET_VERSION)
            if err is not None:
                return err

            resp, err = await self.api_client.commit_dataset(payload=json_data)
            if err:
                return ErrResponse(err)

            # Notify this client over the websocket once each version has synced.
            self._start_sync_watcher(self.check_dataset_sync_status, resp, sid)
            return OKResponse(resp)

        @routes.put(f"/{COMMUNITY_API}/datasets/{{dataset_id}}")
        async def update_dataset(request):
            sid = request.rel_url.query.get("clientId", "")
            if not is_string_valid(sid):
                return ErrResponse(errnos.INVALID_CLIENT_ID)

            dataset_id = self._positive_int(request.match_info["dataset_id"])
            if dataset_id is None:
                return ErrResponse(errnos.INVALID_DATASET_ID)

            json_data = await request.json()

            err = self._validate_name(json_data, errnos.INVALID_DATASET_NAME)
            if err is not None:
                return err
            # Same version-name error as commit_dataset for the same check.
            err = self._validate_versions(json_data, errnos.INVALID_DATASET_VERSION)
            if err is not None:
                return err

            resp, err = await self.api_client.update_dataset(
                dataset_id, json_data["name"], json_data["versions"]
            )
            if err:
                return ErrResponse(err)

            # Pass sid so sync events go to the requesting client, not everyone.
            self._start_sync_watcher(self.check_dataset_sync_status, resp, sid)
            return OKResponse(None)

        @routes.delete(f"/{COMMUNITY_API}/datasets/{{dataset_id}}")
        async def delete_dataset(request):
            dataset_id = self._positive_int(request.match_info["dataset_id"])
            if dataset_id is None:
                return ErrResponse(errnos.INVALID_DATASET_ID)

            resp, err = await self.api_client.delete_dataset(dataset_id)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.post(f"/{COMMUNITY_API}/datasets/query")
        async def query_my_datasets(request):
            current = self._query_int(request, "current", 1)
            page_size = self._query_int(request, "page_size", 10)
            # Filters are optional: a missing body or missing key means "no filter".
            keyword = None
            annotated = None
            if request.body_exists:
                json_data = await request.json()
                keyword = json_data.get("keyword")
                annotated = json_data.get("annotated")

            resp, err = await self.api_client.query_datasets(
                current,
                page_size,
                keyword=keyword,
                annotated=annotated,
            )
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/datasets/{{dataset_id}}/detail")
        async def get_dataset_detail(request):
            dataset_id = self._positive_int(request.match_info["dataset_id"])
            if dataset_id is None:
                return ErrResponse(errnos.INVALID_DATASET_ID)

            resp, err = await self.api_client.get_dataset_detail(dataset_id)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.post(f"/{COMMUNITY_API}/share")
        async def create_share(request):
            json_data = await request.json()

            if self._positive_int(json_data.get("biz_id")) is None:
                return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
            # Only model-version shares are supported.
            if json_data.get("type") != "bizy_model_version":
                return ErrResponse(errnos.INVALID_SHARE_TYPE)

            resp, err = await self.api_client.create_share(payload=json_data)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/share/{{code}}")
        async def get_share_detail(request):
            code = str(request.match_info["code"])
            if not is_string_valid(code):
                return ErrResponse(errnos.INVALID_SHARE_CODE)

            resp, err = await self.api_client.get_share_detail(code)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
        async def get_model_version_detail(request):
            version_id = self._positive_int(request.match_info["version_id"])
            if version_id is None:
                return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)

            resp, err = await self.api_client.get_model_version_detail(version_id)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/notifications/unread_count")
        async def get_notification_unread_count(request):
            resp, err = await self.api_client.get_notification_unread_count()
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.get(f"/{COMMUNITY_API}/notifications")
        async def fetch_notifications(request):
            query = request.rel_url.query
            # "types" is a comma-separated list of integer notification types.
            types_param = query.get("types", None)
            notif_types = (
                [int(x) for x in types_param.split(",")] if types_param else None
            )
            read_status = query.get("read_status", None)
            page_size = self._query_int(request, "page_size", 10)
            last_pm_id = self._query_int(request, "last_pm_id", 0)
            last_broadcast_id = self._query_int(request, "last_broadcast_id", 0)

            resp, err = await self.api_client.fetch_notifications(
                page_size, last_pm_id, last_broadcast_id, notif_types, read_status
            )
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.post(f"/{COMMUNITY_API}/notifications/read_all")
        async def read_all_notifications(request):
            notif_type = self._query_int(request, "type", 0)
            resp, err = await self.api_client.read_all_notifications(notif_type)
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

        @routes.post(f"/{COMMUNITY_API}/notifications/read")
        async def read_notifications(request):
            json_data = await request.json()
            if "ids" not in json_data:
                return ErrResponse(errnos.INVALID_NOTIF_ID)

            resp, err = await self.api_client.read_notifications(json_data["ids"])
            if err:
                return ErrResponse(err)
            return OKResponse(resp)

    # ------------------------------------------------------------------
    # Websocket messaging
    # ------------------------------------------------------------------

    @staticmethod
    def _error_payload(err):
        """Serialize an ErrorNo-like object for an ``error`` websocket event."""
        return {"message": err.message, "code": err.code, "data": err.data}

    async def send_json(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` to client ``sid``.

        Broadcasts to every connected client when ``sid`` is None. An unknown
        ``sid`` is ignored.
        """
        message = {"type": event, "data": data}

        if sid is None:
            # Snapshot: the dict may change while we await each send.
            for ws in list(self.sockets.values()):
                await self.send_socket_catch_exception(ws.send_json, message)
        elif sid in self.sockets:
            await self.send_socket_catch_exception(self.sockets[sid].send_json, message)

    async def send_error(self, err: ErrorNo, sid=None):
        """Send ``err`` as an ``error`` event (broadcast if ``sid`` is None)."""
        await self.send_json(event="error", data=self._error_payload(err), sid=sid)

    async def send_socket_catch_exception(self, function, message):
        """Await ``function(message)``, logging (not raising) connection errors."""
        try:
            await function(message)
        except (
            aiohttp.ClientError,
            aiohttp.ClientPayloadError,
            ConnectionResetError,
        ) as err:
            logging.warning("send error: %s", err)

    def send_sync(self, event, data, sid=None):
        """Thread-safe fire-and-forget variant of ``send_json``."""
        asyncio.run_coroutine_threadsafe(self.send_json(event, data, sid), self.loop)

    def send_sync_error(self, err: ErrorNo, sid=None):
        """Thread-safe fire-and-forget variant of ``send_error``."""
        self.send_sync(event="error", data=self._error_payload(err), sid=sid)

    # ------------------------------------------------------------------
    # Background sync watchers (run on daemon threads)
    # ------------------------------------------------------------------

    def _poll_sync_status(
        self,
        fetch_version,
        owner_key,
        owner_id,
        synced_payload,
        version_ids,
        sid=None,
        interval=5,
        request_timeout=2,
    ):
        """Poll version details until each version is available or fails.

        Intended to run on a worker thread; each lookup is scheduled on
        ``self.loop``. At most one event is sent per version to ``sid``
        (broadcast when ``sid`` is None):
        - ``synced``, once the detail reports a truthy ``available``;
        - ``error``, if the lookup returns an error.
        A lookup that does not finish within ``request_timeout`` seconds is
        retried on the next round instead of killing the thread.

        Args:
            fetch_version: coroutine function called as
                ``fetch_version(version_id=...)``, returning ``(detail, err)``.
            owner_key: key naming the parent id in ``error`` event data.
            owner_id: parent model/dataset id placed under ``owner_key``.
            synced_payload: maps an available ``detail`` dict to the
                ``synced`` event data.
            version_ids: versions to watch.
            sid: websocket client to notify; None broadcasts.
            interval: seconds to sleep between polling rounds.
            request_timeout: seconds to wait for a single lookup.
        """
        # run_coroutine_threadsafe returns a concurrent.futures.Future.
        import concurrent.futures

        pending = list(version_ids)
        while pending:
            still_pending = []
            for version_id in pending:
                future = asyncio.run_coroutine_threadsafe(
                    fetch_version(version_id=version_id), self.loop
                )
                try:
                    detail, err = future.result(timeout=request_timeout)
                except concurrent.futures.TimeoutError:
                    future.cancel()
                    still_pending.append(version_id)
                    continue

                if err is not None:
                    self.send_sync(
                        event="error",
                        data={
                            "message": err.message,
                            "code": err.code,
                            "data": {owner_key: owner_id, "version_id": version_id},
                        },
                        sid=sid,
                    )
                elif detail.get("available"):
                    self.send_sync(
                        event="synced", data=synced_payload(detail), sid=sid
                    )
                else:
                    still_pending.append(version_id)

            pending = still_pending
            if pending:
                time.sleep(interval)

    def check_sync_status(self, bizy_model_id: str, version_ids: list, sid=None):
        """Watch model versions and notify ``sid`` as each becomes available."""
        self._poll_sync_status(
            fetch_version=self.api_client.get_model_version_detail,
            owner_key="bizy_model_id",
            owner_id=bizy_model_id,
            synced_payload=lambda detail: {
                "version_id": detail["id"],
                "version": detail["version"],
                "model_id": bizy_model_id,
                "model_name": detail["bizy_model_name"],
            },
            version_ids=version_ids,
            sid=sid,
        )

    def check_dataset_sync_status(self, dataset_id: str, version_ids: list, sid=None):
        """Watch dataset versions and notify ``sid`` as each becomes available."""
        self._poll_sync_status(
            fetch_version=self.api_client.get_dataset_version_detail,
            owner_key="dataset_id",
            owner_id=dataset_id,
            synced_payload=lambda detail: {
                "version_id": detail["id"],
                "version": detail["version"],
                "dataset_id": dataset_id,
                "dataset_name": detail["dataset_name"],
            },
            version_ids=version_ids,
            sid=sid,
        )