bizyengine 1.2.6__py3-none-any.whl → 1.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. bizyengine/bizy_server/api_client.py +125 -57
  2. bizyengine/bizy_server/errno.py +9 -0
  3. bizyengine/bizy_server/server.py +353 -239
  4. bizyengine/bizyair_extras/__init__.py +1 -0
  5. bizyengine/bizyair_extras/nodes_flux.py +1 -1
  6. bizyengine/bizyair_extras/nodes_image_utils.py +2 -2
  7. bizyengine/bizyair_extras/nodes_nunchaku.py +1 -5
  8. bizyengine/bizyair_extras/nodes_segment_anything.py +1 -0
  9. bizyengine/bizyair_extras/nodes_trellis.py +1 -1
  10. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +1 -1
  11. bizyengine/bizyair_extras/nodes_wan_i2v.py +222 -0
  12. bizyengine/core/__init__.py +2 -0
  13. bizyengine/core/commands/processors/prompt_processor.py +21 -18
  14. bizyengine/core/commands/servers/prompt_server.py +28 -13
  15. bizyengine/core/common/client.py +14 -2
  16. bizyengine/core/common/env_var.py +2 -0
  17. bizyengine/core/nodes_base.py +85 -7
  18. bizyengine/core/nodes_io.py +2 -2
  19. bizyengine/misc/llm.py +48 -85
  20. bizyengine/misc/mzkolors.py +27 -19
  21. bizyengine/misc/nodes.py +41 -21
  22. bizyengine/misc/nodes_controlnet_aux.py +18 -18
  23. bizyengine/misc/nodes_controlnet_union_sdxl.py +5 -12
  24. bizyengine/misc/segment_anything.py +29 -25
  25. bizyengine/misc/supernode.py +36 -30
  26. bizyengine/misc/utils.py +33 -21
  27. bizyengine/version.txt +1 -1
  28. bizyengine-1.2.8.dist-info/METADATA +211 -0
  29. {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/RECORD +31 -30
  30. {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/WHEEL +1 -1
  31. bizyengine-1.2.6.dist-info/METADATA +0 -19
  32. {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/top_level.txt +0 -0
bizyengine/misc/nodes.py CHANGED
@@ -55,10 +55,11 @@ class BizyAir_KSampler(BizyAirBaseNode):
55
55
  }
56
56
 
57
57
  RETURN_TYPES = ("LATENT",)
58
- FUNCTION = "sample"
58
+ # FUNCTION = "sample"
59
59
  RETURN_NAMES = (f"LATENT",)
60
60
  CATEGORY = f"{PREFIX}/sampling"
61
61
 
62
+ # deprecated
62
63
  def sample(
63
64
  self,
64
65
  model,
@@ -127,7 +128,7 @@ class KSamplerAdvanced(BizyAirBaseNode):
127
128
  }
128
129
 
129
130
  RETURN_TYPES = ("LATENT",)
130
- FUNCTION = "sample"
131
+ # FUNCTION = "sample"
131
132
 
132
133
  CATEGORY = "sampling"
133
134
 
@@ -159,7 +160,7 @@ class BizyAir_CheckpointLoaderSimple(BizyAirBaseNode):
159
160
  }
160
161
 
161
162
  RETURN_TYPES = (data_types.MODEL, data_types.CLIP, data_types.VAE)
162
- FUNCTION = "load_checkpoint"
163
+ # FUNCTION = "load_checkpoint"
163
164
  CATEGORY = f"{PREFIX}/loaders"
164
165
  RETURN_NAMES = (
165
166
  f"model",
@@ -224,7 +225,7 @@ class BizyAir_CLIPTextEncode(BizyAirBaseNode):
224
225
 
225
226
  RETURN_TYPES = (data_types.CONDITIONING,)
226
227
  RETURN_NAMES = ("CONDITIONING",)
227
- FUNCTION = "encode"
228
+ # FUNCTION = "encode"
228
229
 
229
230
  CATEGORY = f"{PREFIX}/conditioning"
230
231
 
@@ -249,7 +250,7 @@ class BizyAir_VAEDecode(BizyAirBaseNode):
249
250
 
250
251
  RETURN_TYPES = ("IMAGE",)
251
252
  RETURN_NAMES = (f"IMAGE",)
252
- FUNCTION = "decode"
253
+ # FUNCTION = "decode"
253
254
 
254
255
  CATEGORY = f"{PREFIX}/latent"
255
256
 
@@ -298,6 +299,7 @@ class BizyAir_LoraLoader(BizyAirBaseNode):
298
299
  RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
299
300
  RETURN_NAMES = ("MODEL", "CLIP")
300
301
 
302
+ # 不能使用default_function
301
303
  FUNCTION = "load_lora"
302
304
  CATEGORY = f"{PREFIX}/loaders"
303
305
 
@@ -309,6 +311,7 @@ class BizyAir_LoraLoader(BizyAirBaseNode):
309
311
  strength_model,
310
312
  strength_clip,
311
313
  model_version_id: str = None,
314
+ **kwargs,
312
315
  ):
313
316
  assigned_id = self.assigned_id
314
317
  new_model: BizyAirNodeIO = model.copy(assigned_id)
@@ -371,7 +374,9 @@ class BizyAir_LoraLoader_Legacy(BizyAirBaseNode):
371
374
 
372
375
  CATEGORY = f"{PREFIX}/loaders"
373
376
 
374
- def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
377
+ def load_lora(
378
+ self, model, clip, lora_name, strength_model, strength_clip, **kwargs
379
+ ):
375
380
  assigned_id = self.assigned_id
376
381
  new_model: BizyAirNodeIO = model.copy(assigned_id)
377
382
  new_clip: BizyAirNodeIO = clip.copy(assigned_id)
@@ -401,9 +406,10 @@ class BizyAir_VAEEncode(BizyAirBaseNode):
401
406
 
402
407
  RETURN_TYPES = ("LATENT",)
403
408
  RETURN_NAMES = (f"LATENT",)
404
- FUNCTION = "encode"
409
+ # FUNCTION = "encode"
405
410
  CATEGORY = f"{PREFIX}/latent"
406
411
 
412
+ # deprecated
407
413
  def encode(self, vae, pixels):
408
414
  new_vae: BizyAirNodeIO = vae.copy(self.assigned_id)
409
415
  new_vae.add_node_data(
@@ -431,9 +437,10 @@ class BizyAir_VAEEncodeForInpaint(BizyAirBaseNode):
431
437
 
432
438
  RETURN_TYPES = (f"LATENT",)
433
439
  RETURN_NAMES = (f"LATENT",)
434
- FUNCTION = "encode"
440
+ # FUNCTION = "encode"
435
441
  CATEGORY = f"{PREFIX}/latent/inpaint"
436
442
 
443
+ # deprecated
437
444
  def encode(self, vae, pixels, mask, grow_mask_by=6):
438
445
  new_vae: BizyAirNodeIO = vae.copy(self.assigned_id)
439
446
  new_vae.add_node_data(
@@ -465,7 +472,7 @@ class BizyAir_ControlNetLoader(BizyAirBaseNode):
465
472
 
466
473
  RETURN_TYPES = (data_types.CONTROL_NET,)
467
474
  RETURN_NAMES = ("CONTROL_NET",)
468
- FUNCTION = "load_controlnet"
475
+ # FUNCTION = "load_controlnet"
469
476
 
470
477
  CATEGORY = f"{PREFIX}/loaders"
471
478
 
@@ -506,11 +513,12 @@ class BizyAir_ControlNetLoader_Legacy(BizyAirBaseNode):
506
513
 
507
514
  RETURN_TYPES = (data_types.CONTROL_NET,)
508
515
  RETURN_NAMES = ("CONTROL_NET",)
516
+ # 似乎不能用default实现
509
517
  FUNCTION = "load_controlnet"
510
518
 
511
519
  CATEGORY = f"{PREFIX}/loaders"
512
520
 
513
- def load_controlnet(self, control_net_name):
521
+ def load_controlnet(self, control_net_name, **kwargs):
514
522
 
515
523
  node_data = create_node_data(
516
524
  class_type="ControlNetLoader",
@@ -550,7 +558,7 @@ class BizyAir_ControlNetApplyAdvanced(BizyAirBaseNode):
550
558
 
551
559
  RETURN_TYPES = (data_types.CONDITIONING, data_types.CONDITIONING)
552
560
  RETURN_NAMES = ("positive", "negative")
553
- FUNCTION = "apply_controlnet"
561
+ # FUNCTION = "apply_controlnet"
554
562
 
555
563
  CATEGORY = "conditioning"
556
564
 
@@ -587,7 +595,7 @@ class BizyAir_ControlNetApply(BizyAirBaseNode):
587
595
 
588
596
  RETURN_TYPES = (data_types.CONDITIONING,)
589
597
  RETURN_NAMES = ("CONDITIONING",)
590
- FUNCTION = "apply_controlnet"
598
+ # FUNCTION = "apply_controlnet"
591
599
 
592
600
  CATEGORY = f"{PREFIX}/conditioning/controlnet"
593
601
 
@@ -629,7 +637,7 @@ class BizyAir_CLIPVisionLoader(BizyAirBaseNode):
629
637
  }
630
638
 
631
639
  RETURN_TYPES = ("CLIP_VISION",)
632
- FUNCTION = "load_clip"
640
+ # FUNCTION = "load_clip"
633
641
 
634
642
  CATEGORY = "loaders"
635
643
 
@@ -689,7 +697,7 @@ class VAELoader(BizyAirBaseNode):
689
697
 
690
698
  RETURN_TYPES = (data_types.VAE,)
691
699
  RETURN_NAMES = ("vae",)
692
- FUNCTION = "load_vae"
700
+ # FUNCTION = "load_vae"
693
701
 
694
702
  CATEGORY = "loaders"
695
703
 
@@ -757,7 +765,7 @@ class UNETLoader(BizyAirBaseNode):
757
765
  }
758
766
 
759
767
  RETURN_TYPES = (data_types.MODEL,)
760
- FUNCTION = "load_unet"
768
+ # FUNCTION = "load_unet"
761
769
 
762
770
  CATEGORY = "advanced/loaders"
763
771
 
@@ -851,7 +859,7 @@ class BasicGuider(BizyAirBaseNode):
851
859
 
852
860
  RETURN_TYPES = ("GUIDER",)
853
861
 
854
- FUNCTION = "get_guider"
862
+ # FUNCTION = "get_guider"
855
863
  CATEGORY = "sampling/custom_sampling/guiders"
856
864
 
857
865
  def get_guider(self, model: BizyAirNodeIO, conditioning):
@@ -885,7 +893,7 @@ class BasicScheduler(BizyAirBaseNode):
885
893
  RETURN_TYPES = ("SIGMAS",)
886
894
  CATEGORY = "sampling/custom_sampling/schedulers"
887
895
 
888
- FUNCTION = "get_sigmas"
896
+ # FUNCTION = "get_sigmas"
889
897
 
890
898
  def get_sigmas(self, **kwargs):
891
899
  new_model: BizyAirNodeIO = kwargs["model"].copy(self.assigned_id)
@@ -931,7 +939,7 @@ class DualCLIPLoader(BizyAirBaseNode):
931
939
  }
932
940
 
933
941
  RETURN_TYPES = (data_types.CLIP,)
934
- FUNCTION = "load_clip"
942
+ # FUNCTION = "load_clip"
935
943
 
936
944
  CATEGORY = "advanced/loaders"
937
945
 
@@ -1004,7 +1012,7 @@ class KSamplerSelect(BizyAirBaseNode):
1004
1012
  RETURN_TYPES = ("SAMPLER",)
1005
1013
  CATEGORY = "sampling/custom_sampling/samplers"
1006
1014
 
1007
- FUNCTION = "get_sampler"
1015
+ # FUNCTION = "get_sampler"
1008
1016
 
1009
1017
  def get_sampler(self, **kwargs):
1010
1018
  node_data = create_node_data(
@@ -1032,7 +1040,7 @@ class RandomNoise(BizyAirBaseNode):
1032
1040
  }
1033
1041
 
1034
1042
  RETURN_TYPES = ("NOISE",)
1035
- FUNCTION = "get_noise"
1043
+ # FUNCTION = "get_noise"
1036
1044
  CATEGORY = "sampling/custom_sampling/noise"
1037
1045
 
1038
1046
  def get_noise(self, noise_seed):
@@ -1064,7 +1072,7 @@ class CLIPSetLastLayer(BizyAirBaseNode):
1064
1072
  }
1065
1073
 
1066
1074
  RETURN_TYPES = (data_types.CLIP,)
1067
- FUNCTION = "set_last_layer"
1075
+ # FUNCTION = "set_last_layer"
1068
1076
 
1069
1077
  CATEGORY = "conditioning"
1070
1078
 
@@ -1141,6 +1149,7 @@ class SharedLoraLoader(BizyAir_LoraLoader_Legacy):
1141
1149
 
1142
1150
  RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
1143
1151
  RETURN_NAMES = ("MODEL", "CLIP")
1152
+ # 似乎不能用default实现
1144
1153
  FUNCTION = "shared_load_lora"
1145
1154
  CATEGORY = f"{PREFIX}/loaders"
1146
1155
  NODE_DISPLAY_NAME = "Shared Lora Loader"
@@ -1422,3 +1431,14 @@ class StyleModelApply(BizyAirBaseNode):
1422
1431
  # FUNCTION = "apply_stylemodel"
1423
1432
 
1424
1433
  CATEGORY = "conditioning/style_model"
1434
+
1435
+
1436
+ # 仅用于使用meta传参,故没有输入输出
1437
+ class PassParameter(BizyAirBaseNode):
1438
+ @classmethod
1439
+ def INPUT_TYPES(s):
1440
+ return {}
1441
+
1442
+ RETURN_TYPES = ()
1443
+
1444
+ CATEGORY = "☁️BizyAir"
@@ -1,22 +1,20 @@
1
+ import json
1
2
  import os
2
3
 
3
4
  import numpy as np
4
5
  import torch
6
+ from bizyengine.core import BizyAirMiscBaseNode, pop_api_key_and_prompt_id
7
+ from bizyengine.core.common import client
5
8
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
6
9
 
7
- from .utils import (
8
- decode_and_deserialize,
9
- get_api_key,
10
- send_post_request,
11
- serialize_and_encode,
12
- )
10
+ from .utils import decode_and_deserialize, serialize_and_encode
13
11
 
14
12
  # Sync with theoritical limit from Comfy base
15
13
  # https://github.com/comfyanonymous/ComfyUI/blob/eecd69b53a896343775bcb02a4f8349e7442ffd1/nodes.py#L45
16
14
  MAX_RESOLUTION = 1024
17
15
 
18
16
 
19
- class BasePreprocessor:
17
+ class BasePreprocessor(BizyAirMiscBaseNode):
20
18
  def __init_subclass__(cls, **kwargs):
21
19
  super().__init_subclass__(**kwargs)
22
20
  if not hasattr(cls, "model_name"):
@@ -24,27 +22,29 @@ class BasePreprocessor:
24
22
  cls.API_URL = f"{BIZYAIR_SERVER_ADDRESS}{cls.model_name}"
25
23
  cls.CATEGORY = f"☁️BizyAir/{cls.CATEGORY}"
26
24
 
27
- @staticmethod
28
- def get_headers():
29
- return {
30
- "accept": "application/json",
31
- "content-type": "application/json",
32
- "authorization": f"Bearer {get_api_key()}",
33
- }
34
-
35
25
  RETURN_TYPES = ("IMAGE",)
36
26
  FUNCTION = "execute"
37
27
 
38
28
  def execute(self, **kwargs):
29
+ extra_data = pop_api_key_and_prompt_id(kwargs)
30
+ headers = client.headers(api_key=extra_data["api_key"])
31
+
39
32
  compress = True
40
33
  image: torch.Tensor = kwargs.pop("image")
41
34
  device = image.device
42
35
  kwargs["image"] = serialize_and_encode(image, compress)[0]
43
36
  kwargs["is_compress"] = compress
44
- response: str = send_post_request(
45
- self.API_URL, payload=kwargs, headers=self.get_headers()
37
+ if "prompt_id" in extra_data:
38
+ kwargs["prompt_id"] = extra_data["prompt_id"]
39
+ data = json.dumps(kwargs).encode("utf-8")
40
+
41
+ image_np = client.send_request(
42
+ url=self.API_URL,
43
+ data=data,
44
+ headers=headers,
45
+ callback=None,
46
+ response_handler=decode_and_deserialize,
46
47
  )
47
- image_np = decode_and_deserialize(response)
48
48
  image_torch = torch.from_numpy(image_np).to(device)
49
49
  return (image_torch,)
50
50
 
@@ -8,13 +8,13 @@ import os
8
8
 
9
9
  import numpy as np
10
10
  import requests
11
+ from bizyengine.core import BizyAirMiscBaseNode, pop_api_key_and_prompt_id
12
+ from bizyengine.core.common import client
11
13
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
12
14
  from bizyengine.core.image_utils import decode_comfy_image, encode_comfy_image
13
15
 
14
- from .utils import get_api_key
15
16
 
16
-
17
- class StableDiffusionXLControlNetUnionPipeline:
17
+ class StableDiffusionXLControlNetUnionPipeline(BizyAirMiscBaseNode):
18
18
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/diffusers-v1-stablediffusionxlcontrolnetunionpipeline"
19
19
 
20
20
  @classmethod
@@ -83,14 +83,6 @@ class StableDiffusionXLControlNetUnionPipeline:
83
83
  FUNCTION = "process"
84
84
  CATEGORY = "☁️BizyAir/ControlNet"
85
85
 
86
- @staticmethod
87
- def get_headers():
88
- return {
89
- "accept": "application/json",
90
- "content-type": "application/json",
91
- "authorization": f"Bearer {get_api_key()}",
92
- }
93
-
94
86
  def process(
95
87
  self,
96
88
  openpose_image=None,
@@ -101,6 +93,7 @@ class StableDiffusionXLControlNetUnionPipeline:
101
93
  segment_image=None,
102
94
  **kwargs,
103
95
  ):
96
+ extra_data = pop_api_key_and_prompt_id(kwargs)
104
97
  controlnet_img = {
105
98
  0: openpose_image,
106
99
  1: depth_image,
@@ -143,7 +136,7 @@ class StableDiffusionXLControlNetUnionPipeline:
143
136
  response = requests.post(
144
137
  self.API_URL,
145
138
  json=payload,
146
- headers=self.get_headers(),
139
+ headers=client.headers(api_key=extra_data["api_key"]),
147
140
  )
148
141
 
149
142
  result = response.json()
@@ -6,13 +6,14 @@ from enum import Enum
6
6
  import folder_paths
7
7
  import numpy as np
8
8
  import torch
9
+ from bizyengine.core import BizyAirMiscBaseNode, pop_api_key_and_prompt_id
10
+ from bizyengine.core.common import client
9
11
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
10
12
  from bizyengine.core.image_utils import decode_base64_to_np, encode_image_to_base64
11
13
  from nodes import LoadImage
12
14
  from PIL import Image, ImageOps, ImageSequence
13
15
 
14
16
  from .route_sam import SAM_COORDINATE
15
- from .utils import get_api_key, send_post_request
16
17
 
17
18
 
18
19
  class INFER_MODE(Enum):
@@ -27,7 +28,7 @@ class EDIT_MODE(Enum):
27
28
  point = 1
28
29
 
29
30
 
30
- class BizyAirSegmentAnythingText:
31
+ class BizyAirSegmentAnythingText(BizyAirMiscBaseNode):
31
32
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam"
32
33
 
33
34
  @classmethod
@@ -44,7 +45,7 @@ class BizyAirSegmentAnythingText:
44
45
  "FLOAT",
45
46
  {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
46
47
  ),
47
- }
48
+ },
48
49
  }
49
50
 
50
51
  RETURN_TYPES = ("IMAGE", "MASK")
@@ -52,8 +53,10 @@ class BizyAirSegmentAnythingText:
52
53
 
53
54
  CATEGORY = "☁️BizyAir/segment-anything"
54
55
 
55
- def text_sam(self, image, prompt, box_threshold, text_threshold):
56
- API_KEY = get_api_key()
56
+ def text_sam(self, image, prompt, box_threshold, text_threshold, **kwargs):
57
+ extra_data = pop_api_key_and_prompt_id(kwargs)
58
+ headers = client.headers(api_key=extra_data["api_key"])
59
+
57
60
  SIZE_LIMIT = 1536
58
61
  device = image.device
59
62
  _, w, h, c = image.shape
@@ -70,19 +73,20 @@ class BizyAirSegmentAnythingText:
70
73
  "text_threshold": text_threshold,
71
74
  },
72
75
  }
73
- auth = f"Bearer {API_KEY}"
74
- headers = {
75
- "accept": "application/json",
76
- "content-type": "application/json",
77
- "authorization": auth,
78
- }
79
76
  image = image.squeeze(0).numpy()
80
77
  image_pil = Image.fromarray((image * 255).astype(np.uint8))
81
78
  input_image = encode_image_to_base64(image_pil, format="webp")
82
79
  payload["image"] = input_image
83
-
84
- ret: str = send_post_request(self.API_URL, payload=payload, headers=headers)
85
- ret = json.loads(ret)
80
+ if "prompt_id" in extra_data:
81
+ payload["prompt_id"] = extra_data["prompt_id"]
82
+ data = json.dumps(payload).encode("utf-8")
83
+
84
+ ret = client.send_request(
85
+ url=self.API_URL,
86
+ data=data,
87
+ headers=headers,
88
+ callback=None,
89
+ )
86
90
 
87
91
  try:
88
92
  if "result" in ret:
@@ -117,7 +121,7 @@ class BizyAirSegmentAnythingText:
117
121
  return (img, img_mask)
118
122
 
119
123
 
120
- class BizyAirSegmentAnythingPointBox:
124
+ class BizyAirSegmentAnythingPointBox(BizyAirMiscBaseNode):
121
125
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam"
122
126
 
123
127
  @classmethod
@@ -141,8 +145,10 @@ class BizyAirSegmentAnythingPointBox:
141
145
 
142
146
  CATEGORY = "☁️BizyAir/segment-anything"
143
147
 
144
- def apply(self, image, is_point):
145
- API_KEY = get_api_key()
148
+ def apply(self, image, is_point, **kwargs):
149
+ extra_data = pop_api_key_and_prompt_id(kwargs)
150
+ headers = client.headers(api_key=extra_data["api_key"])
151
+
146
152
  SIZE_LIMIT = 1536
147
153
 
148
154
  # 加载原始图像
@@ -201,20 +207,18 @@ class BizyAirSegmentAnythingPointBox:
201
207
  },
202
208
  }
203
209
 
204
- auth = f"Bearer {API_KEY}"
205
- headers = {
206
- "accept": "application/json",
207
- "content-type": "application/json",
208
- "authorization": auth,
209
- }
210
210
  # 处理用于API的图像
211
211
  api_image = image_to_process.squeeze(0).numpy()
212
212
  image_pil = Image.fromarray((api_image * 255).astype(np.uint8))
213
213
  input_image = encode_image_to_base64(image_pil, format="webp")
214
214
  payload["image"] = input_image
215
+ if "prompt_id" in extra_data:
216
+ payload["prompt_id"] = extra_data["prompt_id"]
217
+ data = json.dumps(payload).encode("utf-8")
215
218
 
216
- ret: str = send_post_request(self.API_URL, payload=payload, headers=headers)
217
- ret = json.loads(ret)
219
+ ret = client.send_request(
220
+ url=self.API_URL, data=data, headers=headers, callback=None
221
+ )
218
222
 
219
223
  try:
220
224
  if "result" in ret:
@@ -8,6 +8,8 @@ import folder_paths
8
8
  import node_helpers
9
9
  import numpy as np
10
10
  import torch
11
+ from bizyengine.core import BizyAirMiscBaseNode
12
+ from bizyengine.core.common import client
11
13
  from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
12
14
  from bizyengine.core.image_utils import (
13
15
  decode_base64_to_np,
@@ -20,13 +22,12 @@ from PIL import Image, ImageOps, ImageSequence
20
22
 
21
23
  from .utils import (
22
24
  decode_and_deserialize,
23
- get_api_key,
24
- send_post_request,
25
+ pop_api_key_and_prompt_id,
25
26
  serialize_and_encode,
26
27
  )
27
28
 
28
29
 
29
- class RemoveBackground:
30
+ class RemoveBackground(BizyAirMiscBaseNode):
30
31
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/removebg"
31
32
 
32
33
  @classmethod
@@ -34,7 +35,7 @@ class RemoveBackground:
34
35
  return {
35
36
  "required": {
36
37
  "image": ("IMAGE",),
37
- }
38
+ },
38
39
  }
39
40
 
40
41
  RETURN_TYPES = ("IMAGE", "MASK")
@@ -42,8 +43,10 @@ class RemoveBackground:
42
43
 
43
44
  CATEGORY = "☁️BizyAir"
44
45
 
45
- def remove_background(self, image):
46
- API_KEY = get_api_key()
46
+ def remove_background(self, image, **kwargs):
47
+ extra_data = pop_api_key_and_prompt_id(kwargs)
48
+ headers = client.headers(api_key=extra_data["api_key"])
49
+
47
50
  device = image.device
48
51
  _, h, w, _ = image.shape
49
52
  assert (
@@ -54,26 +57,27 @@ class RemoveBackground:
54
57
  "is_compress": True,
55
58
  "image": None,
56
59
  }
57
- auth = f"Bearer {API_KEY}"
58
- headers = {
59
- "accept": "application/json",
60
- "content-type": "application/json",
61
- "authorization": auth,
62
- }
63
60
  input_image, compress = serialize_and_encode(image, compress=True)
64
61
  payload["image"] = input_image
65
62
  payload["is_compress"] = compress
66
-
67
- response: str = send_post_request(
68
- self.API_URL, payload=payload, headers=headers
63
+ if "prompt_id" in extra_data:
64
+ payload["prompt_id"] = extra_data["prompt_id"]
65
+ data = json.dumps(payload).encode("utf-8")
66
+
67
+ tensors = client.send_request(
68
+ url=self.API_URL,
69
+ data=data,
70
+ headers=headers,
71
+ callback=None,
72
+ response_handler=decode_and_deserialize,
69
73
  )
70
- tensors = decode_and_deserialize(response)
74
+
71
75
  t_images = tensors["images"].to(device)
72
76
  t_mask = tensors["mask"].to(device)
73
77
  return (t_images, t_mask)
74
78
 
75
79
 
76
- class GenerateLightningImage:
80
+ class GenerateLightningImage(BizyAirMiscBaseNode):
77
81
  API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/realvis4lightning"
78
82
 
79
83
  @classmethod
@@ -98,7 +102,7 @@ class GenerateLightningImage:
98
102
  },
99
103
  ),
100
104
  "batch_size": ("INT", {"default": 1, "min": 1, "max": 4}),
101
- }
105
+ },
102
106
  }
103
107
 
104
108
  RETURN_TYPES = ("IMAGE",)
@@ -106,8 +110,10 @@ class GenerateLightningImage:
106
110
 
107
111
  CATEGORY = "☁️BizyAir"
108
112
 
109
- def generate_image(self, prompt, seed, width, height, cfg, batch_size):
110
- API_KEY = get_api_key()
113
+ def generate_image(self, prompt, seed, width, height, cfg, batch_size, **kwargs):
114
+ extra_data = pop_api_key_and_prompt_id(kwargs)
115
+ headers = client.headers(api_key=extra_data["api_key"])
116
+
111
117
  assert (
112
118
  width <= 1024 and height <= 1024
113
119
  ), f"width and height must be less than 1024, but got {width} and {height}"
@@ -120,17 +126,17 @@ class GenerateLightningImage:
120
126
  "cfg": cfg,
121
127
  "seed": seed,
122
128
  }
123
- auth = f"Bearer {API_KEY}"
124
- headers = {
125
- "accept": "application/json",
126
- "content-type": "application/json",
127
- "authorization": auth,
128
- }
129
-
130
- response: str = send_post_request(
131
- self.API_URL, payload=payload, headers=headers
129
+ if "prompt_id" in extra_data:
130
+ payload["prompt_id"] = extra_data["prompt_id"]
131
+ data = json.dumps(payload).encode("utf-8")
132
+
133
+ tensors_np = client.send_request(
134
+ url=self.API_URL,
135
+ data=data,
136
+ headers=headers,
137
+ callback=None,
138
+ response_handler=decode_and_deserialize,
132
139
  )
133
- tensors_np = decode_and_deserialize(response)
134
140
  tensors = torch.from_numpy(tensors_np)
135
141
 
136
142
  return (tensors,)