bizyengine 1.2.7__py3-none-any.whl → 1.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,7 +10,9 @@ import urllib.parse
10
10
  import uuid
11
11
 
12
12
  import aiohttp
13
+ import execution
13
14
  import openai
15
+ from bizyengine.core.common.env_var import BIZYAIR_SERVER_MODE
14
16
  from server import PromptServer
15
17
 
16
18
  from .api_client import APIClient
@@ -30,6 +32,12 @@ MODEL_API = f"{API_PREFIX}/model"
30
32
  logging.basicConfig(level=logging.DEBUG)
31
33
 
32
34
 
35
+ def _get_request_api_key(request_headers):
36
+ if BIZYAIR_SERVER_MODE:
37
+ return request_headers.get("api_key")
38
+ return None
39
+
40
+
33
41
  class BizyAirServer:
34
42
  def __init__(self):
35
43
  BizyAirServer.instance = self
@@ -42,6 +50,8 @@ class BizyAirServer:
42
50
  self.setup_routes()
43
51
 
44
52
  def setup_routes(self):
53
+ # 以下路径不论本地模式还是服务器模式都要注册
54
+
45
55
  @self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_types")
46
56
  async def list_model_types(request):
47
57
  return OKResponse(types())
@@ -50,167 +60,51 @@ class BizyAirServer:
50
60
  async def list_base_model_types(request):
51
61
  return OKResponse(base_model_types())
52
62
 
53
- @self.prompt_server.routes.get(f"/{USER_API}/info")
54
- async def user_info(request):
55
- info, err = await self.api_client.user_info()
56
- if err is not None:
57
- return ErrResponse(err)
58
-
59
- return OKResponse(info)
60
-
61
- @self.prompt_server.routes.get(f"/{API_PREFIX}/ws")
62
- async def websocket_handler(request):
63
- ws = aiohttp.web.WebSocketResponse()
64
- await ws.prepare(request)
65
- sid = request.rel_url.query.get("clientId", "")
66
- if sid:
67
- # Reusing existing session, remove old
68
- self.sockets.pop(sid, None)
69
- else:
70
- sid = uuid.uuid4().hex
71
-
72
- self.sockets[sid] = ws
73
-
74
- try:
75
- # Send initial state to the new client
76
- await self.send_json(
77
- event="status", data={"status": "connected"}, sid=sid
78
- )
79
-
80
- async for msg in ws:
81
- if msg.type == aiohttp.WSMsgType.TEXT:
82
- if msg.data == "ping":
83
- await ws.send_str("pong")
84
- if msg.type == aiohttp.WSMsgType.ERROR:
85
- logging.warning(
86
- "ws connection closed with exception %s" % ws.exception()
87
- )
88
- finally:
89
- self.sockets.pop(sid, None)
90
- return ws
91
-
92
- @self.prompt_server.routes.get(f"/{COMMUNITY_API}/sign")
93
- async def sign(request):
94
- sha256sum = request.rel_url.query.get("sha256sum")
95
- if not is_string_valid(sha256sum):
96
- return ErrResponse(errnos.EMPTY_SHA256SUM)
97
-
98
- type = request.rel_url.query.get("type")
99
-
100
- sign_data, err = await self.api_client.sign(sha256sum, type)
101
- if err is not None:
102
- return ErrResponse(err)
103
-
104
- return OKResponse(sign_data)
105
-
106
- @self.prompt_server.routes.get(f"/{COMMUNITY_API}/upload_token")
107
- async def upload_token(request):
108
- filename = request.rel_url.query.get("filename", "")
109
- # 校验filename
110
- if not is_string_valid(filename):
111
- return ErrResponse(errnos.INVALID_FILENAME)
112
-
113
- filename = urllib.parse.quote(filename)
114
- token, err = await self.api_client.get_upload_token(filename=filename)
115
- if err is not None:
116
- return ErrResponse(err)
117
- return OKResponse(token)
118
-
119
- @self.prompt_server.routes.post(f"/{COMMUNITY_API}/commit_file")
120
- async def commit_file(request):
121
- json_data = await request.json()
122
-
123
- if "sha256sum" not in json_data:
124
- return ErrResponse(errnos.EMPTY_SHA256SUM)
125
- sha256sum = json_data.get("sha256sum")
126
-
127
- if "object_key" not in json_data:
128
- return ErrResponse(errnos.INVALID_OBJECT_KEY)
129
- object_key = json_data.get("object_key")
130
-
131
- if "type" not in json_data:
132
- return ErrResponse(errnos.INVALID_TYPE)
133
- type = json_data.get("type")
134
-
135
- md5_hash = ""
136
- if "md5_hash" in json_data:
137
- md5_hash = json_data.get("md5_hash")
63
+ @self.prompt_server.routes.post(f"/{COMMUNITY_API}/datasets/query")
64
+ async def query_my_datasets(request):
65
+ current = int(request.rel_url.query.get("current", "1"))
66
+ page_size = int(request.rel_url.query.get("page_size", "10"))
67
+ keyword = None
68
+ annotated = None
69
+ if request.body_exists:
70
+ json_data = await request.json()
71
+ keyword = json_data["keyword"]
72
+ annotated = json_data["annotated"]
73
+ resp, err = None, None
138
74
 
139
- commit_data, err = await self.api_client.commit_file(
140
- signature=sha256sum, object_key=object_key, md5_hash=md5_hash, type=type
75
+ # 调用API查询数据集
76
+ request_api_key = _get_request_api_key(request.headers)
77
+ resp, err = await self.api_client.query_datasets(
78
+ current,
79
+ page_size,
80
+ keyword=keyword,
81
+ annotated=annotated,
82
+ request_api_key=request_api_key,
141
83
  )
142
- # print("commit_data", commit_data)
143
- if err is not None:
84
+ if err:
144
85
  return ErrResponse(err)
145
86
 
146
- return OKResponse(None)
147
-
148
- @self.prompt_server.routes.post(f"/{COMMUNITY_API}/models")
149
- async def commit_bizy_model(request):
150
- sid = request.rel_url.query.get("clientId", "")
151
- if not is_string_valid(sid):
152
- return ErrResponse(errnos.INVALID_CLIENT_ID)
153
-
154
- json_data = await request.json()
155
-
156
- # 校验name和type
157
- err = check_str_param(json_data, "name", errnos.INVALID_NAME)
158
- if err is not None:
159
- return err
160
-
161
- if "/" in json_data["name"]:
162
- return ErrResponse(errnos.INVALID_NAME)
163
-
164
- err = check_type(json_data)
165
- if err is not None:
166
- return err
167
-
168
- # 校验versions
169
- if "versions" not in json_data or not isinstance(
170
- json_data["versions"], list
171
- ):
172
- return ErrResponse(errnos.INVALID_VERSIONS)
173
-
174
- versions = json_data["versions"]
175
- version_names = set()
176
-
177
- for version in versions:
178
- # 检查version是否重复
179
- if version.get("version") in version_names:
180
- return ErrResponse(errnos.DUPLICATE_VERSION)
181
-
182
- # 检查version字段是否合法
183
- if not is_string_valid(version.get("version")) or "/" in version.get(
184
- "version"
185
- ):
186
- return ErrResponse(errnos.INVALID_VERSION_NAME)
87
+ return OKResponse(resp)
187
88
 
188
- version_names.add(version.get("version"))
89
+ @self.prompt_server.routes.get(
90
+ f"/{COMMUNITY_API}/datasets/{{dataset_id}}/detail"
91
+ )
92
+ async def get_dataset_detail(request):
93
+ # 获取路径参数中的数据集ID
94
+ dataset_id = int(request.match_info["dataset_id"])
189
95
 
190
- # 检查base_model, path和sign是否有值
191
- for field in ["base_model", "path", "sign"]:
192
- if not is_string_valid(version.get(field)):
193
- err = errnos.INVALID_VERSION_FIELD.copy()
194
- err.message = "Invalid version field: " + field
195
- return ErrResponse(err)
96
+ # 检查dataset_id是否合法
97
+ if not dataset_id or dataset_id <= 0:
98
+ return ErrResponse(errnos.INVALID_DATASET_ID)
196
99
 
197
- # 调用API提交模型
198
- resp, err = await self.api_client.commit_bizy_model(payload=json_data)
100
+ # 调用API获取数据集详情
101
+ request_api_key = _get_request_api_key(request.headers)
102
+ resp, err = await self.api_client.get_dataset_detail(
103
+ dataset_id, request_api_key=request_api_key
104
+ )
199
105
  if err:
200
106
  return ErrResponse(err)
201
107
 
202
- # print("resp------------------------------->", json_data, resp)
203
- # 开启线程检查同步状态
204
- threading.Thread(
205
- target=self.check_sync_status,
206
- args=(resp["id"], resp["version_ids"], sid),
207
- daemon=True,
208
- ).start()
209
-
210
- # enable refresh for lora
211
- # TODO: enable refresh for other types
212
- # bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
213
-
214
108
  return OKResponse(resp)
215
109
 
216
110
  @self.prompt_server.routes.post(f"/{COMMUNITY_API}/models/query")
@@ -229,6 +123,7 @@ class BizyAirServer:
229
123
  sort = json_data.get("sort", "")
230
124
  resp, err = None, None
231
125
 
126
+ request_api_key = _get_request_api_key(request.headers)
232
127
  if mode in ["my", "my_fork"]:
233
128
  # 调用API查询模型
234
129
  resp, err = await self.api_client.query_models(
@@ -239,6 +134,7 @@ class BizyAirServer:
239
134
  model_types=model_types,
240
135
  base_models=base_models,
241
136
  sort=sort,
137
+ request_api_key=request_api_key,
242
138
  )
243
139
  elif mode == "publicity":
244
140
  # 调用API查询社区模型
@@ -249,6 +145,7 @@ class BizyAirServer:
249
145
  model_types=model_types,
250
146
  base_models=base_models,
251
147
  sort=sort,
148
+ request_api_key=request_api_key,
252
149
  )
253
150
  elif mode == "official":
254
151
  resp, err = await self.api_client.query_official_models(
@@ -258,6 +155,7 @@ class BizyAirServer:
258
155
  model_types=model_types,
259
156
  base_models=base_models,
260
157
  sort=sort,
158
+ request_api_key=request_api_key,
261
159
  )
262
160
  if err:
263
161
  return ErrResponse(err)
@@ -276,7 +174,10 @@ class BizyAirServer:
276
174
  source = request.rel_url.query.get("source", "")
277
175
 
278
176
  # 调用API获取模型详情
279
- resp, err = await self.api_client.get_model_detail(model_id, source)
177
+ request_api_key = _get_request_api_key(request.headers)
178
+ resp, err = await self.api_client.get_model_detail(
179
+ model_id, source, request_api_key=request_api_key
180
+ )
280
181
  if err:
281
182
  return ErrResponse(err)
282
183
 
@@ -292,7 +193,10 @@ class BizyAirServer:
292
193
  return ErrResponse(errnos.INVALID_MODEL_ID)
293
194
 
294
195
  # 调用API删除模型
295
- resp, err = await self.api_client.delete_bizy_model(model_id)
196
+ request_api_key = _get_request_api_key(request.headers)
197
+ resp, err = await self.api_client.delete_bizy_model(
198
+ model_id, request_api_key=request_api_key
199
+ )
296
200
  if err:
297
201
  return ErrResponse(err)
298
202
 
@@ -353,8 +257,13 @@ class BizyAirServer:
353
257
  return ErrResponse(errnos.INVALID_VERSION_FIELD(field))
354
258
 
355
259
  # 调用API更新模型
260
+ request_api_key = _get_request_api_key(request.headers)
356
261
  resp, err = await self.api_client.update_model(
357
- model_id, json_data["name"], json_data["type"], versions
262
+ model_id,
263
+ json_data["name"],
264
+ json_data["type"],
265
+ versions,
266
+ request_api_key=request_api_key,
358
267
  )
359
268
  if err:
360
269
  return ErrResponse(err)
@@ -379,7 +288,10 @@ class BizyAirServer:
379
288
  return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
380
289
 
381
290
  # 调用API fork模型版本
382
- _, err = await self.api_client.fork_model_version(version_id)
291
+ request_api_key = _get_request_api_key(request.headers)
292
+ _, err = await self.api_client.fork_model_version(
293
+ version_id, request_api_key=request_api_key
294
+ )
383
295
  if err:
384
296
  return ErrResponse(err)
385
297
 
@@ -400,7 +312,10 @@ class BizyAirServer:
400
312
  return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
401
313
 
402
314
  # 调用API fork模型版本
403
- _, err = await self.api_client.unfork_model_version(version_id)
315
+ request_api_key = _get_request_api_key(request.headers)
316
+ _, err = await self.api_client.unfork_model_version(
317
+ version_id, request_api_key=request_api_key
318
+ )
404
319
  if err:
405
320
  return ErrResponse(err)
406
321
 
@@ -423,8 +338,9 @@ class BizyAirServer:
423
338
  return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
424
339
 
425
340
  # 调用API like模型版本
341
+ request_api_key = _get_request_api_key(request.headers)
426
342
  _, err = await self.api_client.toggle_user_like(
427
- "model_version", version_id
343
+ "model_version", version_id, request_api_key=request_api_key
428
344
  )
429
345
  if err:
430
346
  return ErrResponse(err)
@@ -451,8 +367,11 @@ class BizyAirServer:
451
367
  return ErrResponse(errnos.INVALID_SIGN)
452
368
 
453
369
  # 获取上传凭证
370
+ request_api_key = _get_request_api_key(request.headers)
454
371
  url, err = await self.api_client.get_download_url(
455
- sign=sign, model_version_id=model_version_id
372
+ sign=sign,
373
+ model_version_id=model_version_id,
374
+ request_api_key=request_api_key,
456
375
  )
457
376
  if err:
458
377
  return ErrResponse(err)
@@ -466,26 +385,12 @@ class BizyAirServer:
466
385
 
467
386
  return OKResponse(json_content)
468
387
 
469
- @self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
470
- async def list_share_model_files(request):
471
- shareId = request.match_info["shareId"]
472
- if not is_string_valid(shareId):
473
- return ErrResponse("INVALID_SHARE_ID")
474
- payload = {}
475
- query_params = ["type", "name", "ext_name"]
476
- for param in query_params:
477
- if param in request.rel_url.query and request.rel_url.query[param]:
478
- payload[param] = request.rel_url.query[param]
479
- model_files, err = await self.api_client.get_share_model_files(
480
- shareId=shareId, payload=payload
481
- )
482
- if err is not None:
483
- return ErrResponse(err)
484
- return OKResponse(model_files)
485
-
486
388
  @self.prompt_server.routes.get(f"/{API_PREFIX}/dict")
487
389
  async def get_data_dict(request):
488
- data_dict, err = await self.api_client.get_data_dict()
390
+ request_api_key = _get_request_api_key(request.headers)
391
+ data_dict, err = await self.api_client.get_data_dict(
392
+ request_api_key=request_api_key
393
+ )
489
394
  if err is not None:
490
395
  return ErrResponse(err)
491
396
 
@@ -530,7 +435,10 @@ class BizyAirServer:
530
435
  version_names.add(version.get("version"))
531
436
 
532
437
  # 调用API提交数据集
533
- resp, err = await self.api_client.commit_dataset(payload=json_data)
438
+ request_api_key = _get_request_api_key(request.headers)
439
+ resp, err = await self.api_client.commit_dataset(
440
+ payload=json_data, request_api_key=request_api_key
441
+ )
534
442
  if err:
535
443
  return ErrResponse(err)
536
444
 
@@ -594,8 +502,9 @@ class BizyAirServer:
594
502
  version_names.add(version.get("version"))
595
503
 
596
504
  # 调用API更新数据集
505
+ request_api_key = _get_request_api_key(request.headers)
597
506
  resp, err = await self.api_client.update_dataset(
598
- dataset_id, json_data["name"], versions
507
+ dataset_id, json_data["name"], versions, request_api_key=request_api_key
599
508
  )
600
509
  if err:
601
510
  return ErrResponse(err)
@@ -619,76 +528,289 @@ class BizyAirServer:
619
528
  return ErrResponse(errnos.INVALID_DATASET_ID)
620
529
 
621
530
  # 调用API删除数据集
622
- resp, err = await self.api_client.delete_dataset(dataset_id)
531
+ request_api_key = _get_request_api_key(request.headers)
532
+ resp, err = await self.api_client.delete_dataset(
533
+ dataset_id, request_api_key=request_api_key
534
+ )
623
535
  if err:
624
536
  return ErrResponse(err)
625
537
 
626
538
  return OKResponse(resp)
627
539
 
628
- @self.prompt_server.routes.post(f"/{COMMUNITY_API}/datasets/query")
629
- async def query_my_datasets(request):
630
- current = int(request.rel_url.query.get("current", "1"))
631
- page_size = int(request.rel_url.query.get("page_size", "10"))
632
- keyword = None
633
- annotated = None
634
- if request.body_exists:
635
- json_data = await request.json()
636
- keyword = json_data["keyword"]
637
- annotated = json_data["annotated"]
638
- resp, err = None, None
540
+ @self.prompt_server.routes.post(f"/{COMMUNITY_API}/share")
541
+ async def create_share(request):
542
+ json_data = await request.json()
543
+ if "biz_id" not in json_data:
544
+ return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
639
545
 
640
- # 调用API查询数据集
641
- resp, err = await self.api_client.query_datasets(
642
- current,
643
- page_size,
644
- keyword=keyword,
645
- annotated=annotated,
546
+ biz_id = int(json_data["biz_id"])
547
+ if not biz_id or biz_id <= 0:
548
+ return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
549
+
550
+ if "type" not in json_data:
551
+ return ErrResponse(errnos.INVALID_SHARE_TYPE)
552
+ if not is_string_valid(json_data["type"]) or (
553
+ json_data["type"] != "bizy_model_version"
554
+ ):
555
+ return ErrResponse(errnos.INVALID_SHARE_TYPE)
556
+
557
+ # 调用API提交数据集
558
+ request_api_key = _get_request_api_key(request.headers)
559
+ resp, err = await self.api_client.create_share(
560
+ payload=json_data, request_api_key=request_api_key
646
561
  )
647
562
  if err:
648
563
  return ErrResponse(err)
649
564
 
650
565
  return OKResponse(resp)
651
566
 
652
- @self.prompt_server.routes.get(
653
- f"/{COMMUNITY_API}/datasets/{{dataset_id}}/detail"
654
- )
655
- async def get_dataset_detail(request):
567
+ @self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
568
+ async def get_model_version_detail(request):
656
569
  # 获取路径参数中的数据集ID
657
- dataset_id = int(request.match_info["dataset_id"])
570
+ version_id = int(request.match_info["version_id"])
658
571
 
659
- # 检查dataset_id是否合法
660
- if not dataset_id or dataset_id <= 0:
661
- return ErrResponse(errnos.INVALID_DATASET_ID)
572
+ # 检查version_id是否合法
573
+ if not version_id or version_id <= 0:
574
+ return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
662
575
 
663
576
  # 调用API获取数据集详情
664
- resp, err = await self.api_client.get_dataset_detail(dataset_id)
577
+ request_api_key = _get_request_api_key(request.headers)
578
+ resp, err = await self.api_client.get_model_version_detail(
579
+ version_id, request_api_key=request_api_key
580
+ )
665
581
  if err:
666
582
  return ErrResponse(err)
667
583
 
668
584
  return OKResponse(resp)
669
585
 
670
- @self.prompt_server.routes.post(f"/{COMMUNITY_API}/share")
671
- async def create_share(request):
586
+ @self.prompt_server.routes.get(f"/{COMMUNITY_API}/sign")
587
+ async def sign(request):
588
+ sha256sum = request.rel_url.query.get("sha256sum")
589
+ if not is_string_valid(sha256sum):
590
+ return ErrResponse(errnos.EMPTY_SHA256SUM)
591
+
592
+ type = request.rel_url.query.get("type")
593
+ request_api_key = _get_request_api_key(request.headers)
594
+ sign_data, err = await self.api_client.sign(
595
+ sha256sum, type, request_api_key=request_api_key
596
+ )
597
+ if err is not None:
598
+ return ErrResponse(err)
599
+
600
+ return OKResponse(sign_data)
601
+
602
+ @self.prompt_server.routes.get(f"/{COMMUNITY_API}/upload_token")
603
+ async def upload_token(request):
604
+ filename = request.rel_url.query.get("filename", "")
605
+ # 校验filename
606
+ if not is_string_valid(filename):
607
+ return ErrResponse(errnos.INVALID_FILENAME)
608
+
609
+ filename = urllib.parse.quote(filename)
610
+ request_api_key = _get_request_api_key(request.headers)
611
+ token, err = await self.api_client.get_upload_token(
612
+ filename=filename, request_api_key=request_api_key
613
+ )
614
+ if err is not None:
615
+ return ErrResponse(err)
616
+ return OKResponse(token)
617
+
618
+ @self.prompt_server.routes.post(f"/{COMMUNITY_API}/commit_file")
619
+ async def commit_file(request):
672
620
  json_data = await request.json()
673
- if "biz_id" not in json_data:
674
- return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
675
621
 
676
- biz_id = int(json_data["biz_id"])
677
- if not biz_id or biz_id <= 0:
678
- return ErrResponse(errnos.INVALID_SHARE_BIZ_ID)
622
+ if "sha256sum" not in json_data:
623
+ return ErrResponse(errnos.EMPTY_SHA256SUM)
624
+ sha256sum = json_data.get("sha256sum")
625
+
626
+ if "object_key" not in json_data:
627
+ return ErrResponse(errnos.INVALID_OBJECT_KEY)
628
+ object_key = json_data.get("object_key")
679
629
 
680
630
  if "type" not in json_data:
681
- return ErrResponse(errnos.INVALID_SHARE_TYPE)
682
- if not is_string_valid(json_data["type"]) or (
683
- json_data["type"] != "bizy_model_version"
631
+ return ErrResponse(errnos.INVALID_TYPE)
632
+ type = json_data.get("type")
633
+
634
+ md5_hash = ""
635
+ if "md5_hash" in json_data:
636
+ md5_hash = json_data.get("md5_hash")
637
+
638
+ request_api_key = _get_request_api_key(request.headers)
639
+ commit_data, err = await self.api_client.commit_file(
640
+ signature=sha256sum,
641
+ object_key=object_key,
642
+ md5_hash=md5_hash,
643
+ type=type,
644
+ request_api_key=request_api_key,
645
+ )
646
+ # print("commit_data", commit_data)
647
+ if err is not None:
648
+ return ErrResponse(err)
649
+
650
+ return OKResponse(None)
651
+
652
+ # 由于历史原因,前端请求body里有apikey所以是post
653
+ @self.prompt_server.routes.post(f"/{API_PREFIX}/get_silicon_cloud_llm_models")
654
+ async def get_silicon_cloud_llm_models_endpoint(request):
655
+ request_api_key = _get_request_api_key(request.headers)
656
+ all_models = await self.api_client.fetch_all_llm_models(
657
+ request_api_key=request_api_key
658
+ )
659
+ llm_models = [model for model in all_models if "vl" not in model.lower()]
660
+ llm_models.append("No LLM Enhancement")
661
+ return aiohttp.web.json_response(llm_models)
662
+
663
+ # 由于历史原因,前端请求body里有apikey所以是post
664
+ @self.prompt_server.routes.post(f"/{API_PREFIX}/get_silicon_cloud_vlm_models")
665
+ async def get_silicon_cloud_vlm_models_endpoint(request):
666
+ request_api_key = _get_request_api_key(request.headers)
667
+ all_models = await self.api_client.fetch_all_llm_models(
668
+ request_api_key=request_api_key
669
+ )
670
+ vlm_models = [model for model in all_models if "vl" in model.lower()]
671
+ vlm_models.append("No VLM Enhancement")
672
+ return aiohttp.web.json_response(vlm_models)
673
+
674
+ @self.prompt_server.routes.post(f"/{API_PREFIX}/validate_prompt")
675
+ async def validate_prompt(request):
676
+ json_data = await request.json()
677
+ if not "prompt" in json_data:
678
+ return ErrResponse(errnos.MISSING_PROMPT)
679
+
680
+ valid = execution.validate_prompt(json_data["prompt"])
681
+ if valid[0]:
682
+ return OKResponse(None)
683
+ else:
684
+ err = errnos.INVALID_PROMPT.copy()
685
+ err.data = {"error": valid[1], "node_errors": valid[3]}
686
+ return ErrResponse(err)
687
+
688
+ # 服务器模式下以下路径不会注册
689
+ if BIZYAIR_SERVER_MODE:
690
+ return
691
+
692
+ @self.prompt_server.routes.get(f"/{MODEL_HOST_API}" + "/{shareId}/models/files")
693
+ async def list_share_model_files(request):
694
+ shareId = request.match_info["shareId"]
695
+ if not is_string_valid(shareId):
696
+ return ErrResponse("INVALID_SHARE_ID")
697
+ payload = {}
698
+ query_params = ["type", "name", "ext_name"]
699
+ for param in query_params:
700
+ if param in request.rel_url.query and request.rel_url.query[param]:
701
+ payload[param] = request.rel_url.query[param]
702
+ model_files, err = await self.api_client.get_share_model_files(
703
+ shareId=shareId, payload=payload
704
+ )
705
+ if err is not None:
706
+ return ErrResponse(err)
707
+ return OKResponse(model_files)
708
+
709
+ @self.prompt_server.routes.get(f"/{USER_API}/info")
710
+ async def user_info(request):
711
+ info, err = await self.api_client.user_info()
712
+ if err is not None:
713
+ return ErrResponse(err)
714
+
715
+ return OKResponse(info)
716
+
717
+ @self.prompt_server.routes.get(f"/{API_PREFIX}/ws")
718
+ async def websocket_handler(request):
719
+ ws = aiohttp.web.WebSocketResponse()
720
+ await ws.prepare(request)
721
+ sid = request.rel_url.query.get("clientId", "")
722
+ if sid:
723
+ # Reusing existing session, remove old
724
+ self.sockets.pop(sid, None)
725
+ else:
726
+ sid = uuid.uuid4().hex
727
+
728
+ self.sockets[sid] = ws
729
+
730
+ try:
731
+ # Send initial state to the new client
732
+ await self.send_json(
733
+ event="status", data={"status": "connected"}, sid=sid
734
+ )
735
+
736
+ async for msg in ws:
737
+ if msg.type == aiohttp.WSMsgType.TEXT:
738
+ if msg.data == "ping":
739
+ await ws.send_str("pong")
740
+ if msg.type == aiohttp.WSMsgType.ERROR:
741
+ logging.warning(
742
+ "ws connection closed with exception %s" % ws.exception()
743
+ )
744
+ finally:
745
+ self.sockets.pop(sid, None)
746
+ return ws
747
+
748
+ @self.prompt_server.routes.post(f"/{COMMUNITY_API}/models")
749
+ async def commit_bizy_model(request):
750
+ sid = request.rel_url.query.get("clientId", "")
751
+ if not is_string_valid(sid):
752
+ return ErrResponse(errnos.INVALID_CLIENT_ID)
753
+
754
+ json_data = await request.json()
755
+
756
+ # 校验name和type
757
+ err = check_str_param(json_data, "name", errnos.INVALID_NAME)
758
+ if err is not None:
759
+ return err
760
+
761
+ if "/" in json_data["name"]:
762
+ return ErrResponse(errnos.INVALID_NAME)
763
+
764
+ err = check_type(json_data)
765
+ if err is not None:
766
+ return err
767
+
768
+ # 校验versions
769
+ if "versions" not in json_data or not isinstance(
770
+ json_data["versions"], list
684
771
  ):
685
- return ErrResponse(errnos.INVALID_SHARE_TYPE)
772
+ return ErrResponse(errnos.INVALID_VERSIONS)
686
773
 
687
- # 调用API提交数据集
688
- resp, err = await self.api_client.create_share(payload=json_data)
774
+ versions = json_data["versions"]
775
+ version_names = set()
776
+
777
+ for version in versions:
778
+ # 检查version是否重复
779
+ if version.get("version") in version_names:
780
+ return ErrResponse(errnos.DUPLICATE_VERSION)
781
+
782
+ # 检查version字段是否合法
783
+ if not is_string_valid(version.get("version")) or "/" in version.get(
784
+ "version"
785
+ ):
786
+ return ErrResponse(errnos.INVALID_VERSION_NAME)
787
+
788
+ version_names.add(version.get("version"))
789
+
790
+ # 检查base_model, path和sign是否有值
791
+ for field in ["base_model", "path", "sign"]:
792
+ if not is_string_valid(version.get(field)):
793
+ err = errnos.INVALID_VERSION_FIELD.copy()
794
+ err.message = "Invalid version field: " + field
795
+ return ErrResponse(err)
796
+
797
+ # 调用API提交模型
798
+ resp, err = await self.api_client.commit_bizy_model(payload=json_data)
689
799
  if err:
690
800
  return ErrResponse(err)
691
801
 
802
+ # print("resp------------------------------->", json_data, resp)
803
+ # 开启线程检查同步状态
804
+ threading.Thread(
805
+ target=self.check_sync_status,
806
+ args=(resp["id"], resp["version_ids"], sid),
807
+ daemon=True,
808
+ ).start()
809
+
810
+ # enable refresh for lora
811
+ # TODO: enable refresh for other types
812
+ # bizyengine.core.path_utils.path_manager.enable_refresh_options("loras")
813
+
692
814
  return OKResponse(resp)
693
815
 
694
816
  @self.prompt_server.routes.get(f"/{COMMUNITY_API}/share/{{code}}")
@@ -707,22 +829,6 @@ class BizyAirServer:
707
829
 
708
830
  return OKResponse(resp)
709
831
 
710
- @self.prompt_server.routes.get(f"/{COMMUNITY_API}/model_version/{{version_id}}")
711
- async def get_model_version_detail(request):
712
- # 获取路径参数中的数据集ID
713
- version_id = int(request.match_info["version_id"])
714
-
715
- # 检查version_id是否合法
716
- if not version_id or version_id <= 0:
717
- return ErrResponse(errnos.INVALID_MODEL_VERSION_ID)
718
-
719
- # 调用API获取数据集详情
720
- resp, err = await self.api_client.get_model_version_detail(version_id)
721
- if err:
722
- return ErrResponse(err)
723
-
724
- return OKResponse(resp)
725
-
726
832
  @self.prompt_server.routes.get(f"/{COMMUNITY_API}/notifications/unread_count")
727
833
  async def get_notification_unread_count(request):
728
834
  # 获取当前用户的未读消息数量
@@ -744,7 +850,11 @@ class BizyAirServer:
744
850
  last_broadcast_id = int(request.rel_url.query.get("last_broadcast_id", "0"))
745
851
 
746
852
  resp, err = await self.api_client.fetch_notifications(
747
- page_size, last_pm_id, last_broadcast_id, types, read_status
853
+ page_size,
854
+ last_pm_id,
855
+ last_broadcast_id,
856
+ types,
857
+ read_status,
748
858
  )
749
859
  if err:
750
860
  return ErrResponse(err)
@@ -913,7 +1023,9 @@ class BizyAirServer:
913
1023
  if not year:
914
1024
  return ErrResponse(errnos.INVALID_YEAR_PARAM)
915
1025
 
916
- resp, err = await self.api_client.get_year_cost(year=year, api_key=api_key)
1026
+ resp, err = await self.api_client.get_year_cost(
1027
+ year=year, query_api_key=api_key
1028
+ )
917
1029
  if err is not None:
918
1030
  return ErrResponse(err)
919
1031
 
@@ -928,7 +1040,7 @@ class BizyAirServer:
928
1040
  return ErrResponse(errnos.INVALID_MONTH_PARAM)
929
1041
 
930
1042
  resp, err = await self.api_client.get_month_cost(
931
- month=month, api_key=api_key
1043
+ month=month, query_api_key=api_key
932
1044
  )
933
1045
  if err is not None:
934
1046
  return ErrResponse(err)
@@ -943,7 +1055,9 @@ class BizyAirServer:
943
1055
  if not day:
944
1056
  return ErrResponse(errnos.INVALID_DAY_PARAM)
945
1057
 
946
- resp, err = await self.api_client.get_day_cost(day=day, api_key=api_key)
1058
+ resp, err = await self.api_client.get_day_cost(
1059
+ day=day, query_api_key=api_key
1060
+ )
947
1061
  if err is not None:
948
1062
  return ErrResponse(err)
949
1063
 
@@ -954,7 +1068,7 @@ class BizyAirServer:
954
1068
 
955
1069
  api_key = request.rel_url.query.get("api_key", "")
956
1070
 
957
- resp, err = await self.api_client.get_recent_cost(api_key=api_key)
1071
+ resp, err = await self.api_client.get_recent_cost(query_api_key=api_key)
958
1072
  if err is not None:
959
1073
  return ErrResponse(err)
960
1074