bizyengine 1.2.7__py3-none-any.whl → 1.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bizyengine/misc/llm.py CHANGED
@@ -3,13 +3,14 @@ import json
3
3
 
4
4
  import aiohttp
5
5
  from aiohttp import web
6
+ from bizyengine.core import BizyAirMiscBaseNode, pop_api_key_and_prompt_id
7
+ from bizyengine.core.common import client
6
8
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
7
9
  from bizyengine.core.image_utils import decode_data, encode_comfy_image, encode_data
8
10
  from server import PromptServer
9
11
 
10
12
  from .utils import (
11
13
  decode_and_deserialize,
12
- get_api_key,
13
14
  get_llm_response,
14
15
  get_vlm_response,
15
16
  send_post_request,
@@ -17,52 +18,7 @@ from .utils import (
17
18
  )
18
19
 
19
20
 
20
- async def fetch_all_models(api_key):
21
- url = f"{BIZYAIR_SERVER_ADDRESS}/llm/models"
22
- headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"}
23
- params = {"type": "text", "sub_type": "chat"}
24
-
25
- try:
26
- async with aiohttp.ClientSession() as session:
27
- async with session.get(
28
- url, headers=headers, params=params, timeout=10
29
- ) as response:
30
- if response.status == 200:
31
- data = await response.json()
32
- all_models = [model["id"] for model in data["data"]]
33
- return all_models
34
- else:
35
- print(f"Error fetching models: HTTP Status {response.status}")
36
- return []
37
- except aiohttp.ClientError as e:
38
- print(f"Error fetching models: {e}")
39
- return []
40
- except asyncio.exceptions.TimeoutError:
41
- print("Request to fetch models timed out")
42
- return []
43
-
44
-
45
- @PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_llm_models")
46
- async def get_silicon_cloud_llm_models_endpoint(request):
47
- data = await request.json()
48
- api_key = data.get("api_key", get_api_key())
49
- all_models = await fetch_all_models(api_key)
50
- llm_models = [model for model in all_models if "vl" not in model.lower()]
51
- llm_models.append("No LLM Enhancement")
52
- return web.json_response(llm_models)
53
-
54
-
55
- @PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_vlm_models")
56
- async def get_silicon_cloud_vlm_models_endpoint(request):
57
- data = await request.json()
58
- api_key = data.get("api_key", get_api_key())
59
- all_models = await fetch_all_models(api_key)
60
- vlm_models = [model for model in all_models if "vl" in model.lower()]
61
- vlm_models.append("No VLM Enhancement")
62
- return web.json_response(vlm_models)
63
-
64
-
65
- class SiliconCloudLLMAPI:
21
+ class SiliconCloudLLMAPI(BizyAirMiscBaseNode):
66
22
  def __init__(self):
67
23
  pass
68
24
 
@@ -93,7 +49,7 @@ class SiliconCloudLLMAPI:
93
49
  "FLOAT",
94
50
  {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.01},
95
51
  ),
96
- }
52
+ },
97
53
  }
98
54
 
99
55
  RETURN_TYPES = ("STRING",)
@@ -102,23 +58,18 @@ class SiliconCloudLLMAPI:
102
58
  CATEGORY = "☁️BizyAir/AI Assistants"
103
59
 
104
60
  def get_llm_model_response(
105
- self, model, system_prompt, user_prompt, max_tokens, temperature
61
+ self, model, system_prompt, user_prompt, max_tokens, temperature, **kwargs
106
62
  ):
107
63
  if model == "No LLM Enhancement":
108
64
  return {"ui": {"text": (user_prompt,)}, "result": (user_prompt,)}
109
- response = get_llm_response(
110
- model,
111
- system_prompt,
112
- user_prompt,
113
- max_tokens,
114
- temperature,
65
+ ret = get_llm_response(
66
+ model, system_prompt, user_prompt, max_tokens, temperature, **kwargs
115
67
  )
116
- ret = json.loads(response)
117
68
  text = ret["choices"][0]["message"]["content"]
118
69
  return (text,) # if update ui: {"ui": {"text": (text,)}, "result": (text,)}
119
70
 
120
71
 
121
- class SiliconCloudVLMAPI:
72
+ class SiliconCloudVLMAPI(BizyAirMiscBaseNode):
122
73
  def __init__(self):
123
74
  pass
124
75
 
@@ -148,7 +99,7 @@ class SiliconCloudVLMAPI:
148
99
  {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.01},
149
100
  ),
150
101
  "detail": (["auto", "low", "high"], {"default": "auto"}),
151
- }
102
+ },
152
103
  }
153
104
 
154
105
  RETURN_TYPES = ("STRING",)
@@ -157,7 +108,15 @@ class SiliconCloudVLMAPI:
157
108
  CATEGORY = "☁️BizyAir/AI Assistants"
158
109
 
159
110
  def get_vlm_model_response(
160
- self, model, system_prompt, user_prompt, images, max_tokens, temperature, detail
111
+ self,
112
+ model,
113
+ system_prompt,
114
+ user_prompt,
115
+ images,
116
+ max_tokens,
117
+ temperature,
118
+ detail,
119
+ **kwargs,
161
120
  ):
162
121
  if model == "No VLM Enhancement":
163
122
  return (user_prompt,)
@@ -171,7 +130,7 @@ class SiliconCloudVLMAPI:
171
130
  # 提取所有编码后的图像
172
131
  base64_images = list(encoded_images_dict.values())
173
132
 
174
- response = get_vlm_response(
133
+ ret = get_vlm_response(
175
134
  model,
176
135
  system_prompt,
177
136
  user_prompt,
@@ -179,13 +138,13 @@ class SiliconCloudVLMAPI:
179
138
  max_tokens,
180
139
  temperature,
181
140
  detail,
141
+ **kwargs,
182
142
  )
183
- ret = json.loads(response)
184
143
  text = ret["choices"][0]["message"]["content"]
185
144
  return (text,)
186
145
 
187
146
 
188
- class BizyAirJoyCaption:
147
+ class BizyAirJoyCaption(BizyAirMiscBaseNode):
189
148
  # refer to: https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha
190
149
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/joycaption2"
191
150
 
@@ -216,7 +175,7 @@ class BizyAirJoyCaption:
216
175
  "display": "number",
217
176
  },
218
177
  ),
219
- }
178
+ },
220
179
  }
221
180
 
222
181
  RETURN_TYPES = ("STRING",)
@@ -224,8 +183,10 @@ class BizyAirJoyCaption:
224
183
 
225
184
  CATEGORY = "☁️BizyAir/AI Assistants"
226
185
 
227
- def joycaption(self, image, do_sample, temperature, max_tokens):
228
- API_KEY = get_api_key()
186
+ def joycaption(self, image, do_sample, temperature, max_tokens, **kwargs):
187
+ extra_data = pop_api_key_and_prompt_id(kwargs)
188
+ headers = client.headers(api_key=extra_data["api_key"])
189
+
229
190
  SIZE_LIMIT = 1536
230
191
  # device = image.device
231
192
  _, w, h, c = image.shape
@@ -244,17 +205,18 @@ class BizyAirJoyCaption:
244
205
  "name_input": "",
245
206
  "custom_prompt": "A descriptive caption for this image:\n",
246
207
  }
247
- auth = f"Bearer {API_KEY}"
248
- headers = {
249
- "accept": "application/json",
250
- "content-type": "application/json",
251
- "authorization": auth,
252
- }
253
208
  input_image = encode_data(image, disable_image_marker=True)
254
209
  payload["image"] = input_image
255
-
256
- ret: str = send_post_request(self.API_URL, payload=payload, headers=headers)
257
- ret = json.loads(ret)
210
+ if "prompt_id" in extra_data:
211
+ payload["prompt_id"] = extra_data["prompt_id"]
212
+ data = json.dumps(payload).encode("utf-8")
213
+
214
+ ret = client.send_request(
215
+ url=self.API_URL,
216
+ data=data,
217
+ headers=headers,
218
+ callback=None,
219
+ )
258
220
 
259
221
  try:
260
222
  if "result" in ret:
@@ -275,7 +237,7 @@ class BizyAirJoyCaption:
275
237
  return (caption,)
276
238
 
277
239
 
278
- class BizyAirJoyCaption2:
240
+ class BizyAirJoyCaption2(BizyAirMiscBaseNode):
279
241
  def __init__(self):
280
242
  pass
281
243
 
@@ -348,7 +310,7 @@ class BizyAirJoyCaption2:
348
310
  "multiline": True,
349
311
  },
350
312
  ),
351
- }
313
+ },
352
314
  }
353
315
 
354
316
  RETURN_TYPES = ("STRING",)
@@ -367,8 +329,11 @@ class BizyAirJoyCaption2:
367
329
  extra_options,
368
330
  name_input,
369
331
  custom_prompt,
332
+ **kwargs,
370
333
  ):
371
- API_KEY = get_api_key()
334
+ extra_data = pop_api_key_and_prompt_id(kwargs)
335
+ headers = client.headers(api_key=extra_data["api_key"])
336
+
372
337
  SIZE_LIMIT = 1536
373
338
  _, w, h, c = image.shape
374
339
  assert (
@@ -386,17 +351,15 @@ class BizyAirJoyCaption2:
386
351
  "name_input": name_input,
387
352
  "custom_prompt": custom_prompt,
388
353
  }
389
- auth = f"Bearer {API_KEY}"
390
- headers = {
391
- "accept": "application/json",
392
- "content-type": "application/json",
393
- "authorization": auth,
394
- }
395
354
  input_image = encode_data(image, disable_image_marker=True)
396
355
  payload["image"] = input_image
356
+ if "prompt_id" in extra_data:
357
+ payload["prompt_id"] = extra_data["prompt_id"]
358
+ data = json.dumps(payload).encode("utf-8")
397
359
 
398
- ret: str = send_post_request(self.API_URL, payload=payload, headers=headers)
399
- ret = json.loads(ret)
360
+ ret: str = client.send_request(
361
+ url=self.API_URL, data=data, headers=headers, callback=None
362
+ )
400
363
 
401
364
  try:
402
365
  if "result" in ret:
@@ -1,23 +1,29 @@
1
+ import json
1
2
  import os
2
3
  import uuid
3
4
 
4
5
  import torch
5
- from bizyengine.core import BizyAirBaseNode, BizyAirNodeIO, create_node_data
6
+ from bizyengine.core import (
7
+ BizyAirBaseNode,
8
+ BizyAirMiscBaseNode,
9
+ BizyAirNodeIO,
10
+ create_node_data,
11
+ )
12
+ from bizyengine.core.common import client
6
13
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
7
14
  from bizyengine.core.data_types import CONDITIONING
8
15
  from bizyengine.core.image_utils import encode_data
9
16
 
10
17
  from .utils import (
11
18
  decode_and_deserialize,
12
- get_api_key,
13
- send_post_request,
19
+ pop_api_key_and_prompt_id,
14
20
  serialize_and_encode,
15
21
  )
16
22
 
17
23
  CATEGORY_NAME = "☁️BizyAir/Kolors"
18
24
 
19
25
 
20
- class BizyAirMZChatGLM3TextEncode:
26
+ class BizyAirMZChatGLM3TextEncode(BizyAirMiscBaseNode):
21
27
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/mzkolorschatglm3"
22
28
 
23
29
  @classmethod
@@ -25,7 +31,7 @@ class BizyAirMZChatGLM3TextEncode:
25
31
  return {
26
32
  "required": {
27
33
  "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
28
- }
34
+ },
29
35
  }
30
36
 
31
37
  RETURN_TYPES = ("CONDITIONING",)
@@ -33,24 +39,26 @@ class BizyAirMZChatGLM3TextEncode:
33
39
  FUNCTION = "encode"
34
40
  CATEGORY = CATEGORY_NAME
35
41
 
36
- def encode(self, text):
37
- API_KEY = get_api_key()
42
+ def encode(self, text, **kwargs):
43
+ extra_data = pop_api_key_and_prompt_id(kwargs)
44
+ headers = client.headers(api_key=extra_data["api_key"])
45
+
38
46
  assert len(text) <= 4096, f"the prompt is too long, length: {len(text)}"
39
47
 
40
48
  payload = {
41
49
  "text": text,
42
50
  }
43
- auth = f"Bearer {API_KEY}"
44
- headers = {
45
- "accept": "application/json",
46
- "content-type": "application/json",
47
- "authorization": auth,
48
- }
49
-
50
- response: str = send_post_request(
51
- self.API_URL, payload=payload, headers=headers
51
+ if "prompt_id" in extra_data:
52
+ payload["prompt_id"] = extra_data["prompt_id"]
53
+ data = json.dumps(payload).encode("utf-8")
54
+
55
+ tensors_np = client.send_request(
56
+ url=self.API_URL,
57
+ data=data,
58
+ headers=headers,
59
+ callback=None,
60
+ response_handler=decode_and_deserialize,
52
61
  )
53
- tensors_np = decode_and_deserialize(response)
54
62
 
55
63
  ret_conditioning = []
56
64
  for item in tensors_np:
@@ -69,8 +77,8 @@ class BizyAir_MinusZoneChatGLM3TextEncode(BizyAirMZChatGLM3TextEncode, BizyAirBa
69
77
 
70
78
  FUNCTION = "mz_encode"
71
79
 
72
- def mz_encode(self, text):
73
- out = self.encode(text)[0]
80
+ def mz_encode(self, text, **kwargs):
81
+ out = self.encode(text=text, **kwargs)[0]
74
82
  node_data = create_node_data(
75
83
  class_type="ComfyAirLoadData",
76
84
  inputs={"conditioning": {"relay": out}},
bizyengine/misc/nodes.py CHANGED
@@ -55,10 +55,11 @@ class BizyAir_KSampler(BizyAirBaseNode):
55
55
  }
56
56
 
57
57
  RETURN_TYPES = ("LATENT",)
58
- FUNCTION = "sample"
58
+ # FUNCTION = "sample"
59
59
  RETURN_NAMES = (f"LATENT",)
60
60
  CATEGORY = f"{PREFIX}/sampling"
61
61
 
62
+ # deprecated
62
63
  def sample(
63
64
  self,
64
65
  model,
@@ -127,7 +128,7 @@ class KSamplerAdvanced(BizyAirBaseNode):
127
128
  }
128
129
 
129
130
  RETURN_TYPES = ("LATENT",)
130
- FUNCTION = "sample"
131
+ # FUNCTION = "sample"
131
132
 
132
133
  CATEGORY = "sampling"
133
134
 
@@ -177,7 +178,7 @@ class BizyAir_CheckpointLoaderSimple(BizyAirBaseNode):
177
178
  return False
178
179
  return True
179
180
 
180
- def load_checkpoint(self, ckpt_name, model_version_id=""):
181
+ def load_checkpoint(self, ckpt_name, model_version_id="", **kwargs):
181
182
  if model_version_id != "":
182
183
  # use model version id as lora name
183
184
  ckpt_name = (
@@ -224,7 +225,7 @@ class BizyAir_CLIPTextEncode(BizyAirBaseNode):
224
225
 
225
226
  RETURN_TYPES = (data_types.CONDITIONING,)
226
227
  RETURN_NAMES = ("CONDITIONING",)
227
- FUNCTION = "encode"
228
+ # FUNCTION = "encode"
228
229
 
229
230
  CATEGORY = f"{PREFIX}/conditioning"
230
231
 
@@ -249,7 +250,7 @@ class BizyAir_VAEDecode(BizyAirBaseNode):
249
250
 
250
251
  RETURN_TYPES = ("IMAGE",)
251
252
  RETURN_NAMES = (f"IMAGE",)
252
- FUNCTION = "decode"
253
+ # FUNCTION = "decode"
253
254
 
254
255
  CATEGORY = f"{PREFIX}/latent"
255
256
 
@@ -298,6 +299,7 @@ class BizyAir_LoraLoader(BizyAirBaseNode):
298
299
  RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
299
300
  RETURN_NAMES = ("MODEL", "CLIP")
300
301
 
302
+ # 不能使用default_function
301
303
  FUNCTION = "load_lora"
302
304
  CATEGORY = f"{PREFIX}/loaders"
303
305
 
@@ -309,6 +311,7 @@ class BizyAir_LoraLoader(BizyAirBaseNode):
309
311
  strength_model,
310
312
  strength_clip,
311
313
  model_version_id: str = None,
314
+ **kwargs,
312
315
  ):
313
316
  assigned_id = self.assigned_id
314
317
  new_model: BizyAirNodeIO = model.copy(assigned_id)
@@ -371,7 +374,9 @@ class BizyAir_LoraLoader_Legacy(BizyAirBaseNode):
371
374
 
372
375
  CATEGORY = f"{PREFIX}/loaders"
373
376
 
374
- def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
377
+ def load_lora(
378
+ self, model, clip, lora_name, strength_model, strength_clip, **kwargs
379
+ ):
375
380
  assigned_id = self.assigned_id
376
381
  new_model: BizyAirNodeIO = model.copy(assigned_id)
377
382
  new_clip: BizyAirNodeIO = clip.copy(assigned_id)
@@ -401,9 +406,10 @@ class BizyAir_VAEEncode(BizyAirBaseNode):
401
406
 
402
407
  RETURN_TYPES = ("LATENT",)
403
408
  RETURN_NAMES = (f"LATENT",)
404
- FUNCTION = "encode"
409
+ # FUNCTION = "encode"
405
410
  CATEGORY = f"{PREFIX}/latent"
406
411
 
412
+ # deprecated
407
413
  def encode(self, vae, pixels):
408
414
  new_vae: BizyAirNodeIO = vae.copy(self.assigned_id)
409
415
  new_vae.add_node_data(
@@ -431,9 +437,10 @@ class BizyAir_VAEEncodeForInpaint(BizyAirBaseNode):
431
437
 
432
438
  RETURN_TYPES = (f"LATENT",)
433
439
  RETURN_NAMES = (f"LATENT",)
434
- FUNCTION = "encode"
440
+ # FUNCTION = "encode"
435
441
  CATEGORY = f"{PREFIX}/latent/inpaint"
436
442
 
443
+ # deprecated
437
444
  def encode(self, vae, pixels, mask, grow_mask_by=6):
438
445
  new_vae: BizyAirNodeIO = vae.copy(self.assigned_id)
439
446
  new_vae.add_node_data(
@@ -477,7 +484,7 @@ class BizyAir_ControlNetLoader(BizyAirBaseNode):
477
484
  return True
478
485
  return True
479
486
 
480
- def load_controlnet(self, control_net_name, model_version_id):
487
+ def load_controlnet(self, control_net_name, model_version_id, **kwargs):
481
488
  if model_version_id is not None and model_version_id != "":
482
489
  control_net_name = (
483
490
  f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
@@ -506,11 +513,12 @@ class BizyAir_ControlNetLoader_Legacy(BizyAirBaseNode):
506
513
 
507
514
  RETURN_TYPES = (data_types.CONTROL_NET,)
508
515
  RETURN_NAMES = ("CONTROL_NET",)
516
+ # 似乎不能用default实现
509
517
  FUNCTION = "load_controlnet"
510
518
 
511
519
  CATEGORY = f"{PREFIX}/loaders"
512
520
 
513
- def load_controlnet(self, control_net_name):
521
+ def load_controlnet(self, control_net_name, **kwargs):
514
522
 
515
523
  node_data = create_node_data(
516
524
  class_type="ControlNetLoader",
@@ -550,7 +558,7 @@ class BizyAir_ControlNetApplyAdvanced(BizyAirBaseNode):
550
558
 
551
559
  RETURN_TYPES = (data_types.CONDITIONING, data_types.CONDITIONING)
552
560
  RETURN_NAMES = ("positive", "negative")
553
- FUNCTION = "apply_controlnet"
561
+ # FUNCTION = "apply_controlnet"
554
562
 
555
563
  CATEGORY = "conditioning"
556
564
 
@@ -587,7 +595,7 @@ class BizyAir_ControlNetApply(BizyAirBaseNode):
587
595
 
588
596
  RETURN_TYPES = (data_types.CONDITIONING,)
589
597
  RETURN_NAMES = ("CONDITIONING",)
590
- FUNCTION = "apply_controlnet"
598
+ # FUNCTION = "apply_controlnet"
591
599
 
592
600
  CATEGORY = f"{PREFIX}/conditioning/controlnet"
593
601
 
@@ -629,7 +637,7 @@ class BizyAir_CLIPVisionLoader(BizyAirBaseNode):
629
637
  }
630
638
 
631
639
  RETURN_TYPES = ("CLIP_VISION",)
632
- FUNCTION = "load_clip"
640
+ # FUNCTION = "load_clip"
633
641
 
634
642
  CATEGORY = "loaders"
635
643
 
@@ -689,7 +697,7 @@ class VAELoader(BizyAirBaseNode):
689
697
 
690
698
  RETURN_TYPES = (data_types.VAE,)
691
699
  RETURN_NAMES = ("vae",)
692
- FUNCTION = "load_vae"
700
+ # FUNCTION = "load_vae"
693
701
 
694
702
  CATEGORY = "loaders"
695
703
 
@@ -757,7 +765,7 @@ class UNETLoader(BizyAirBaseNode):
757
765
  }
758
766
 
759
767
  RETURN_TYPES = (data_types.MODEL,)
760
- FUNCTION = "load_unet"
768
+ # FUNCTION = "load_unet"
761
769
 
762
770
  CATEGORY = "advanced/loaders"
763
771
 
@@ -851,7 +859,7 @@ class BasicGuider(BizyAirBaseNode):
851
859
 
852
860
  RETURN_TYPES = ("GUIDER",)
853
861
 
854
- FUNCTION = "get_guider"
862
+ # FUNCTION = "get_guider"
855
863
  CATEGORY = "sampling/custom_sampling/guiders"
856
864
 
857
865
  def get_guider(self, model: BizyAirNodeIO, conditioning):
@@ -885,7 +893,7 @@ class BasicScheduler(BizyAirBaseNode):
885
893
  RETURN_TYPES = ("SIGMAS",)
886
894
  CATEGORY = "sampling/custom_sampling/schedulers"
887
895
 
888
- FUNCTION = "get_sigmas"
896
+ # FUNCTION = "get_sigmas"
889
897
 
890
898
  def get_sigmas(self, **kwargs):
891
899
  new_model: BizyAirNodeIO = kwargs["model"].copy(self.assigned_id)
@@ -931,7 +939,7 @@ class DualCLIPLoader(BizyAirBaseNode):
931
939
  }
932
940
 
933
941
  RETURN_TYPES = (data_types.CLIP,)
934
- FUNCTION = "load_clip"
942
+ # FUNCTION = "load_clip"
935
943
 
936
944
  CATEGORY = "advanced/loaders"
937
945
 
@@ -1004,7 +1012,7 @@ class KSamplerSelect(BizyAirBaseNode):
1004
1012
  RETURN_TYPES = ("SAMPLER",)
1005
1013
  CATEGORY = "sampling/custom_sampling/samplers"
1006
1014
 
1007
- FUNCTION = "get_sampler"
1015
+ # FUNCTION = "get_sampler"
1008
1016
 
1009
1017
  def get_sampler(self, **kwargs):
1010
1018
  node_data = create_node_data(
@@ -1032,7 +1040,7 @@ class RandomNoise(BizyAirBaseNode):
1032
1040
  }
1033
1041
 
1034
1042
  RETURN_TYPES = ("NOISE",)
1035
- FUNCTION = "get_noise"
1043
+ # FUNCTION = "get_noise"
1036
1044
  CATEGORY = "sampling/custom_sampling/noise"
1037
1045
 
1038
1046
  def get_noise(self, noise_seed):
@@ -1064,7 +1072,7 @@ class CLIPSetLastLayer(BizyAirBaseNode):
1064
1072
  }
1065
1073
 
1066
1074
  RETURN_TYPES = (data_types.CLIP,)
1067
- FUNCTION = "set_last_layer"
1075
+ # FUNCTION = "set_last_layer"
1068
1076
 
1069
1077
  CATEGORY = "conditioning"
1070
1078
 
@@ -1141,6 +1149,7 @@ class SharedLoraLoader(BizyAir_LoraLoader_Legacy):
1141
1149
 
1142
1150
  RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
1143
1151
  RETURN_NAMES = ("MODEL", "CLIP")
1152
+ # 似乎不能用default实现
1144
1153
  FUNCTION = "shared_load_lora"
1145
1154
  CATEGORY = f"{PREFIX}/loaders"
1146
1155
  NODE_DISPLAY_NAME = "Shared Lora Loader"
@@ -1422,3 +1431,14 @@ class StyleModelApply(BizyAirBaseNode):
1422
1431
  # FUNCTION = "apply_stylemodel"
1423
1432
 
1424
1433
  CATEGORY = "conditioning/style_model"
1434
+
1435
+
1436
+ # 仅用于使用meta传参,故没有输入输出
1437
+ class PassParameter(BizyAirBaseNode):
1438
+ @classmethod
1439
+ def INPUT_TYPES(s):
1440
+ return {}
1441
+
1442
+ RETURN_TYPES = ()
1443
+
1444
+ CATEGORY = "☁️BizyAir"
@@ -1,22 +1,20 @@
1
+ import json
1
2
  import os
2
3
 
3
4
  import numpy as np
4
5
  import torch
6
+ from bizyengine.core import BizyAirMiscBaseNode, pop_api_key_and_prompt_id
7
+ from bizyengine.core.common import client
5
8
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
6
9
 
7
- from .utils import (
8
- decode_and_deserialize,
9
- get_api_key,
10
- send_post_request,
11
- serialize_and_encode,
12
- )
10
+ from .utils import decode_and_deserialize, serialize_and_encode
13
11
 
14
12
  # Sync with theoritical limit from Comfy base
15
13
  # https://github.com/comfyanonymous/ComfyUI/blob/eecd69b53a896343775bcb02a4f8349e7442ffd1/nodes.py#L45
16
14
  MAX_RESOLUTION = 1024
17
15
 
18
16
 
19
- class BasePreprocessor:
17
+ class BasePreprocessor(BizyAirMiscBaseNode):
20
18
  def __init_subclass__(cls, **kwargs):
21
19
  super().__init_subclass__(**kwargs)
22
20
  if not hasattr(cls, "model_name"):
@@ -24,27 +22,29 @@ class BasePreprocessor:
24
22
  cls.API_URL = f"{BIZYAIR_SERVER_ADDRESS}{cls.model_name}"
25
23
  cls.CATEGORY = f"☁️BizyAir/{cls.CATEGORY}"
26
24
 
27
- @staticmethod
28
- def get_headers():
29
- return {
30
- "accept": "application/json",
31
- "content-type": "application/json",
32
- "authorization": f"Bearer {get_api_key()}",
33
- }
34
-
35
25
  RETURN_TYPES = ("IMAGE",)
36
26
  FUNCTION = "execute"
37
27
 
38
28
  def execute(self, **kwargs):
29
+ extra_data = pop_api_key_and_prompt_id(kwargs)
30
+ headers = client.headers(api_key=extra_data["api_key"])
31
+
39
32
  compress = True
40
33
  image: torch.Tensor = kwargs.pop("image")
41
34
  device = image.device
42
35
  kwargs["image"] = serialize_and_encode(image, compress)[0]
43
36
  kwargs["is_compress"] = compress
44
- response: str = send_post_request(
45
- self.API_URL, payload=kwargs, headers=self.get_headers()
37
+ if "prompt_id" in extra_data:
38
+ kwargs["prompt_id"] = extra_data["prompt_id"]
39
+ data = json.dumps(kwargs).encode("utf-8")
40
+
41
+ image_np = client.send_request(
42
+ url=self.API_URL,
43
+ data=data,
44
+ headers=headers,
45
+ callback=None,
46
+ response_handler=decode_and_deserialize,
46
47
  )
47
- image_np = decode_and_deserialize(response)
48
48
  image_torch = torch.from_numpy(image_np).to(device)
49
49
  return (image_torch,)
50
50