bizyengine 1.2.7__py3-none-any.whl → 1.2.9__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -56,7 +56,7 @@ class FluxGuidance(BizyAirBaseNode):
56
56
 
57
57
  CATEGORY = "advanced/conditioning/flux"
58
58
 
59
- def append(self, conditioning: BizyAirNodeIO, guidance):
59
+ def append(self, conditioning: BizyAirNodeIO, guidance, **kwargs):
60
60
  new_conditioning = conditioning.copy(self.assigned_id)
61
61
  new_conditioning.add_node_data(
62
62
  class_type="FluxGuidance",
@@ -31,7 +31,7 @@ class LoadImageURL(BizyAirBaseNode):
31
31
  FUNCTION = "apply"
32
32
  NODE_DISPLAY_NAME = "Load Image (URL)"
33
33
 
34
- def apply(self, url: str):
34
+ def apply(self, url: str, **kwargs):
35
35
  url = url.strip()
36
36
  input_dir = folder_paths.get_input_directory()
37
37
 
@@ -88,6 +88,6 @@ class Image_Encode(BizyAirBaseNode):
88
88
  RETURN_TYPES = ("IMAGE",)
89
89
  NODE_DISPLAY_NAME = "Image Encode"
90
90
 
91
- def apply(self, image, lossless=False):
91
+ def apply(self, image, lossless=False, **kwargs):
92
92
  out = encode_data(image, lossless=lossless)
93
93
  return (out,)
@@ -167,11 +167,7 @@ class NunchakuFluxLoraLoader(BizyAirBaseNode):
167
167
  return True
168
168
 
169
169
  def load_lora(
170
- self,
171
- model,
172
- lora_name,
173
- lora_strength,
174
- model_version_id: str = None,
170
+ self, model, lora_name, lora_strength, model_version_id: str = None, **kwargs
175
171
  ):
176
172
  assigned_id = self.assigned_id
177
173
  new_model: BizyAirNodeIO = model.copy(assigned_id)
@@ -219,6 +219,7 @@ class BizyAirDetailMethodPredict(BizyAirBaseNode):
219
219
  detail_dilate,
220
220
  black_point,
221
221
  white_point,
222
+ **kwargs
222
223
  ):
223
224
 
224
225
  ret_images = []
@@ -178,7 +178,7 @@ class BizyAirDownloadFile(BizyAirBaseNode):
178
178
  OUTPUT_NODE = True
179
179
  OUTPUT_IS_LIST = (False,)
180
180
 
181
- def main(self, url, file_name):
181
+ def main(self, url, file_name, **kwargs):
182
182
  assert url is not None
183
183
  file_name = file_name + ".glb"
184
184
  out_dir = os.path.join(folder_paths.get_output_directory(), "trellis_output")
@@ -124,7 +124,7 @@ class UltimateSDUpscale(BizyAirBaseNode):
124
124
  return prepare_inputs(required, optional)
125
125
 
126
126
  RETURN_TYPES = ("IMAGE",)
127
- FUNCTION = "upscale"
127
+ # FUNCTION = "upscale"
128
128
  CATEGORY = "image/upscaling"
129
129
 
130
130
  def upscale(self, **kwargs):
@@ -63,7 +63,7 @@ class Wan_LoraLoader(BizyAirBaseNode):
63
63
  FUNCTION = "apply_lora"
64
64
  CATEGORY = "Diffusers/WAN Video Generation"
65
65
 
66
- def apply_lora(self, lora_name, lora_weight=0.75):
66
+ def apply_lora(self, lora_name, lora_weight=0.75, **kwargs):
67
67
  return ([(lora_name, lora_weight)],)
68
68
 
69
69
 
@@ -4,5 +4,7 @@ from bizyengine.core.nodes_base import (
4
4
  NODE_CLASS_MAPPINGS,
5
5
  NODE_DISPLAY_NAME_MAPPINGS,
6
6
  BizyAirBaseNode,
7
+ BizyAirMiscBaseNode,
8
+ pop_api_key_and_prompt_id,
7
9
  )
8
10
  from bizyengine.core.nodes_io import BizyAirNodeIO, create_node_data
@@ -3,7 +3,7 @@ from collections import deque
3
3
  from typing import Any, Dict, List
4
4
 
5
5
  from bizyengine.core.commands.base import Processor # type: ignore
6
- from bizyengine.core.common import client, get_api_key
6
+ from bizyengine.core.common import client
7
7
  from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
8
8
  from bizyengine.core.common.env_var import (
9
9
  BIZYAIR_DEBUG,
@@ -15,7 +15,6 @@ from bizyengine.core.path_utils import (
15
15
  convert_prompt_label_path_to_real_path,
16
16
  guess_url_from_node,
17
17
  )
18
- from server import PromptServer
19
18
 
20
19
 
21
20
  def is_link(obj):
@@ -34,7 +33,9 @@ from dataclasses import dataclass
34
33
 
35
34
 
36
35
  class SearchServiceRouter(Processor):
37
- def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]):
36
+ def process(
37
+ self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str], **kwargs
38
+ ):
38
39
  if BIZYAIR_DEV_REQUEST_URL:
39
40
  return BIZYAIR_DEV_REQUEST_URL
40
41
 
@@ -83,18 +84,15 @@ class SearchServiceRouter(Processor):
83
84
  return f"{BIZYAIR_SERVER_ADDRESS}{out_route}"
84
85
 
85
86
  def validate_input(
86
- self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
87
+ self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str], **kwargs
87
88
  ):
88
89
  assert len(last_node_ids) == 1
89
90
  return True
90
91
 
91
92
 
92
93
  class PromptProcessor(Processor):
93
- def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
94
- exec_info = {
95
- "model_version_ids": [],
96
- "api_key": get_api_key(),
97
- }
94
+ def _exec_info(self, prompt: Dict[str, Dict[str, Any]], api_key: str):
95
+ exec_info = {"model_version_ids": [], "api_key": api_key}
98
96
 
99
97
  model_version_id_prefix = config_manager.get_model_version_id_prefix()
100
98
  for node_id, node_data in prompt.items():
@@ -105,26 +103,31 @@ class PromptProcessor(Processor):
105
103
  return exec_info
106
104
 
107
105
  def process(
108
- self, url: str, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
106
+ self,
107
+ url: str,
108
+ prompt: Dict[str, Dict[str, Any]],
109
+ last_node_ids: List[str],
110
+ **kwargs,
109
111
  ):
110
112
  dict = {
111
113
  "prompt": prompt,
112
114
  "last_node_id": last_node_ids[0],
113
- "exec_info": self._exec_info(prompt),
115
+ "exec_info": self._exec_info(prompt, kwargs["api_key"]),
114
116
  }
115
- if (
116
- PromptServer.instance is not None
117
- and PromptServer.instance.last_prompt_id is not None
118
- ):
119
- dict["prompt_id"] = PromptServer.instance.last_prompt_id
120
- print("Processing prompt with ID: " + PromptServer.instance.last_prompt_id)
117
+ if "prompt_id" in kwargs:
118
+ dict["prompt_id"] = kwargs["prompt_id"]
121
119
 
122
120
  return client.send_request(
123
121
  url=url,
124
122
  data=json.dumps(dict).encode("utf-8"),
123
+ headers=client.headers(api_key=kwargs["api_key"]),
125
124
  )
126
125
 
127
126
  def validate_input(
128
- self, url: str, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
127
+ self,
128
+ url: str,
129
+ prompt: Dict[str, Dict[str, Any]],
130
+ last_node_ids: List[str],
131
+ **kwargs,
129
132
  ):
130
133
  return True
@@ -9,7 +9,7 @@ from typing import Any, Dict, List
9
9
  import comfy
10
10
  from bizyengine.core.commands.base import Command, Processor # type: ignore
11
11
  from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
12
- from bizyengine.core.common.client import send_request
12
+ from bizyengine.core.common.client import headers, send_request
13
13
  from bizyengine.core.common.env_var import (
14
14
  BIZYAIR_DEBUG,
15
15
  BIZYAIR_DEV_GET_TASK_RESULT_SERVER,
@@ -20,7 +20,7 @@ from bizyengine.core.configs.conf import config_manager
20
20
  from bizyengine.core.image_utils import decode_data, encode_data
21
21
 
22
22
 
23
- def get_task_result(task_id: str, offset: int = 0) -> dict:
23
+ def get_task_result(task_id: str, offset: int = 0, api_key: str = None) -> dict:
24
24
  """
25
25
  Get the result of a task.
26
26
  """
@@ -34,8 +34,12 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
34
34
 
35
35
  if BIZYAIR_DEBUG:
36
36
  print(f"Debug: get task result url: {url}")
37
+ _headers = headers(api_key=api_key)
37
38
  response_json = send_request(
38
- method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8")
39
+ method="GET",
40
+ url=url,
41
+ data=json.dumps({"offset": offset}).encode("utf-8"),
42
+ headers=_headers,
39
43
  )
40
44
  out = response_json
41
45
  events = out.get("data", {}).get("events", [])
@@ -47,7 +51,9 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
47
51
  and event["data"].startswith("https://")
48
52
  ):
49
53
  # event["data"] = requests.get(event["data"]).json()
50
- event["data"] = send_request(method="GET", url=event["data"])
54
+ event["data"] = send_request(
55
+ method="GET", url=event["data"], headers=_headers
56
+ )
51
57
  new_events.append(event)
52
58
  out["data"]["events"] = new_events
53
59
  return out
@@ -59,6 +65,7 @@ class BizyAirTask:
59
65
  task_id: str
60
66
  data_pool: list[dict] = field(default_factory=list)
61
67
  data_status: str = None
68
+ api_key: str = None
62
69
 
63
70
  @staticmethod
64
71
  def check_inputs(inputs: dict) -> bool:
@@ -69,12 +76,19 @@ class BizyAirTask:
69
76
  )
70
77
 
71
78
  @classmethod
72
- def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask":
79
+ def from_data(
80
+ cls, inputs: dict, check_inputs: bool = True, **kwargs
81
+ ) -> "BizyAirTask":
73
82
  if check_inputs and not cls.check_inputs(inputs):
74
83
  raise ValueError(f"Invalid inputs: {inputs}")
75
84
  data = inputs.get("data", {})
76
85
  task_id = data.get("task_id", "")
77
- return cls(task_id=task_id, data_pool=[], data_status="started")
86
+ return cls(
87
+ task_id=task_id,
88
+ data_pool=[],
89
+ data_status="started",
90
+ api_key=kwargs["api_key"],
91
+ )
78
92
 
79
93
  def is_finished(self) -> bool:
80
94
  if not self.data_pool:
@@ -85,7 +99,7 @@ class BizyAirTask:
85
99
 
86
100
  def send_request(self, offset: int = 0) -> dict:
87
101
  if offset >= len(self.data_pool):
88
- return get_task_result(self.task_id, offset)
102
+ return get_task_result(self.task_id, offset, self.api_key)
89
103
  else:
90
104
  return self.data_pool[offset]
91
105
 
@@ -163,12 +177,12 @@ class PromptServer(Command):
163
177
  and "task_id" in result.get("data", {})
164
178
  )
165
179
 
166
- def _get_result(self, result: Dict[str, Any], *, cache_key: str = None):
180
+ def _get_result(self, result: Dict[str, Any], *, cache_key: str = None, **kwargs):
167
181
  try:
168
182
  response_data = result["data"]
169
183
  if BizyAirTask.check_inputs(result):
170
184
  self.cache_manager.set(cache_key, result)
171
- bz_task = BizyAirTask.from_data(result, check_inputs=False)
185
+ bz_task = BizyAirTask.from_data(result, check_inputs=False, **kwargs)
172
186
  bz_task.do_task_until_completed(timeout=60 * 60) # 60 minutes
173
187
  last_data = bz_task.get_last_data()
174
188
  response_data = last_data.get("data")
@@ -187,7 +201,6 @@ class PromptServer(Command):
187
201
  *args,
188
202
  **kwargs,
189
203
  ):
190
-
191
204
  prompt = encode_data(prompt)
192
205
 
193
206
  if BIZYAIR_DEBUG:
@@ -197,7 +210,7 @@ class PromptServer(Command):
197
210
  }
198
211
  pprint.pprint(debug_info, indent=4)
199
212
 
200
- url = self.router(prompt=prompt, last_node_ids=last_node_ids)
213
+ url = self.router(prompt=prompt, last_node_ids=last_node_ids, **kwargs)
201
214
 
202
215
  if BIZYAIR_DEBUG:
203
216
  print(f"Generated URL: {url}")
@@ -218,8 +231,10 @@ class PromptServer(Command):
218
231
  print(f"Cache hit for sh256-{sh256}")
219
232
  out = cached_output
220
233
  else:
221
- result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids)
222
- out = self._get_result(result, cache_key=sh256)
234
+ result = self.processor(
235
+ url, prompt=prompt, last_node_ids=last_node_ids, **kwargs
236
+ )
237
+ out = self._get_result(result, cache_key=sh256, **kwargs)
223
238
 
224
239
  if BIZYAIR_DEBUG:
225
240
  pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4)
@@ -22,6 +22,7 @@ from .env_var import (
22
22
  BIZYAIR_API_KEY,
23
23
  BIZYAIR_DEBUG,
24
24
  BIZYAIR_SERVER_ADDRESS,
25
+ BIZYAIR_SERVER_MODE,
25
26
  create_api_key_file,
26
27
  )
27
28
 
@@ -41,6 +42,8 @@ api_key_state = APIKeyState()
41
42
 
42
43
 
43
44
  def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False) -> bool:
45
+ if BIZYAIR_SERVER_MODE:
46
+ return
44
47
  logging.debug("client.py set_api_key called")
45
48
  global api_key_state
46
49
  if api_key_state.is_valid and not override:
@@ -58,6 +61,9 @@ def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False) -> bool:
58
61
 
59
62
 
60
63
  def validate_api_key(api_key: str = None) -> bool:
64
+ if BIZYAIR_SERVER_MODE:
65
+ return False
66
+
61
67
  logging.debug("validating api key...")
62
68
  if not api_key or not isinstance(api_key, str):
63
69
  warnings.warn("invalid api_key")
@@ -92,6 +98,8 @@ def validate_api_key(api_key: str = None) -> bool:
92
98
 
93
99
 
94
100
  def get_api_key() -> str:
101
+ if BIZYAIR_SERVER_MODE:
102
+ return None
95
103
  logging.debug("client.py get_api_key called")
96
104
  global api_key_state
97
105
  try:
@@ -106,11 +114,15 @@ def get_api_key() -> str:
106
114
  return api_key_state.current_api_key
107
115
 
108
116
 
109
- def _headers():
117
+ def headers(api_key: str = None):
118
+ return _headers(api_key=api_key)
119
+
120
+
121
+ def _headers(api_key: str = None):
110
122
  headers = {
111
123
  "accept": "application/json",
112
124
  "content-type": "application/json",
113
- "authorization": f"Bearer {get_api_key()}",
125
+ "authorization": f"Bearer {api_key if api_key else get_api_key()}",
114
126
  }
115
127
  return headers
116
128
 
@@ -105,3 +105,5 @@ BIZYAIR_DEV_GET_TASK_RESULT_SERVER = env(
105
105
  "BIZYAIR_DEV_GET_TASK_RESULT_SERVER", str, None
106
106
  )
107
107
  BIZYAIR_PRODUCTION_TEST = env("BIZYAIR_PRODUCTION_TEST", str, None)
108
+ # Server Mode
109
+ BIZYAIR_SERVER_MODE = env("BIZYAIR_SERVER_MODE", bool, False)
@@ -4,8 +4,11 @@ import warnings
4
4
  from functools import wraps
5
5
  from typing import List
6
6
 
7
+ from bizyengine.core.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_MODE
7
8
  from bizyengine.core.configs.conf import config_manager
9
+ from server import PromptServer
8
10
 
11
+ from .common.client import get_api_key
9
12
  from .data_types import is_send_request_datatype
10
13
  from .nodes_io import BizyAirNodeIO, create_node_data
11
14
 
@@ -23,6 +26,37 @@ PREFIX = f"BizyAir"
23
26
  NODE_CLASS_MAPPINGS = {}
24
27
  NODE_DISPLAY_NAME_MAPPINGS = {}
25
28
 
29
+ BIZYAIR_PROMPT_KEY = "bizyair_prompt"
30
+ BIZYAIR_PARAM_MAGIC_NODE_ID = "bizyair_magic_node"
31
+
32
+
33
+ def pop_api_key_and_prompt_id(kwargs):
34
+ extra_data = {}
35
+ prompt = None
36
+ if BIZYAIR_PROMPT_KEY in kwargs:
37
+ prompt = kwargs.pop(BIZYAIR_PROMPT_KEY)
38
+ if BIZYAIR_SERVER_MODE:
39
+ if BIZYAIR_PARAM_MAGIC_NODE_ID in prompt:
40
+ extra_data["api_key"] = prompt[BIZYAIR_PARAM_MAGIC_NODE_ID]["_meta"][
41
+ "api_key"
42
+ ]
43
+ extra_data["prompt_id"] = prompt[BIZYAIR_PARAM_MAGIC_NODE_ID]["_meta"][
44
+ "prompt_id"
45
+ ]
46
+ logging.debug(
47
+ "Using server mode passed in prompt_id: " + extra_data["prompt_id"]
48
+ )
49
+ else:
50
+ extra_data["api_key"] = get_api_key()
51
+ if (
52
+ PromptServer.instance is not None
53
+ and PromptServer.instance.last_prompt_id is not None
54
+ ):
55
+ extra_data["prompt_id"] = PromptServer.instance.last_prompt_id
56
+ print("Processing prompt with ID: " + PromptServer.instance.last_prompt_id)
57
+
58
+ return extra_data
59
+
26
60
 
27
61
  def process_kwargs(kwargs):
28
62
  possibleWidgetNames = [
@@ -91,7 +125,7 @@ def ensure_unique_id(org_func, original_has_unique_id=False):
91
125
  return new_func
92
126
 
93
127
 
94
- def ensure_hidden_unique_id(org_input_types_func):
128
+ def ensure_hidden_unique_id_and_prompt(org_input_types_func):
95
129
  original_has_unique_id = False
96
130
 
97
131
  @wraps(org_input_types_func)
@@ -105,6 +139,9 @@ def ensure_hidden_unique_id(org_input_types_func):
105
139
  result["hidden"]["unique_id"] = "UNIQUE_ID"
106
140
  else:
107
141
  original_has_unique_id = True
142
+ # Also set prompt, but to avoid naming conflict prefix with 'bizyair'
143
+ if BIZYAIR_PROMPT_KEY not in result["hidden"]:
144
+ result["hidden"][BIZYAIR_PROMPT_KEY] = "PROMPT"
108
145
  return result
109
146
 
110
147
  # Ensure original_has_unique_id is correctly set before returning
@@ -121,11 +158,19 @@ class BizyAirBaseNode:
121
158
  register_node(cls, PREFIX)
122
159
  cls.setup_input_types()
123
160
 
161
+ # 验证FUNCTION接受**kwargs
162
+ if BIZYAIR_DEBUG:
163
+ import inspect
164
+
165
+ sig = inspect.signature(getattr(cls, cls.FUNCTION))
166
+ params = sig.parameters.values()
167
+ assert any([True for p in params if p.kind == p.VAR_KEYWORD])
168
+
124
169
  @classmethod
125
170
  def setup_input_types(cls):
126
- # https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
127
- new_input_types_func, original_has_unique_id = ensure_hidden_unique_id(
128
- cls.INPUT_TYPES
171
+ # https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-and-flexible-inputs
172
+ new_input_types_func, original_has_unique_id = (
173
+ ensure_hidden_unique_id_and_prompt(cls.INPUT_TYPES)
129
174
  )
130
175
  cls.INPUT_TYPES = new_input_types_func
131
176
  setattr(
@@ -140,13 +185,14 @@ class BizyAirBaseNode:
140
185
  return str(self._assigned_id)
141
186
 
142
187
  def default_function(self, **kwargs):
188
+ extra_data = pop_api_key_and_prompt_id(kwargs)
143
189
  class_type = self._determine_class_type()
144
190
  kwargs = process_kwargs(kwargs)
145
191
  node_ios = self._process_non_send_request_types(class_type, kwargs)
146
192
  # TODO: add processing for send_request_types
147
193
  send_request_datatype_list = self._get_send_request_datatypes()
148
194
  if len(send_request_datatype_list) == len(self.RETURN_TYPES):
149
- return self._process_all_send_request_types(node_ios)
195
+ return self._process_all_send_request_types(node_ios, **extra_data)
150
196
  return node_ios
151
197
 
152
198
  def _get_send_request_datatypes(self):
@@ -172,7 +218,39 @@ class BizyAirBaseNode:
172
218
  outs.append(node)
173
219
  return tuple(outs)
174
220
 
175
- def _process_all_send_request_types(self, node_ios: List[BizyAirNodeIO]):
176
- out = node_ios[0].send_request()
221
+ def _process_all_send_request_types(self, node_ios: List[BizyAirNodeIO], **kwargs):
222
+ out = node_ios[0].send_request(**kwargs)
177
223
  assert len(out) == len(self.RETURN_TYPES)
178
224
  return out
225
+
226
+
227
+ class BizyAirMiscBaseNode:
228
+ # 作为Misc节点基类来保证hidden prompt, unique id存在
229
+ def __init_subclass__(cls, **kwargs):
230
+ if hasattr(cls, "CATEGORY"):
231
+ if not cls.CATEGORY.startswith(f"{LOGO}{PREFIX}"):
232
+ cls.CATEGORY = f"{LOGO}{PREFIX}/{cls.CATEGORY}"
233
+ cls.setup_input_types()
234
+
235
+ # 验证FUNCTION接受**kwargs
236
+ if BIZYAIR_DEBUG:
237
+ import inspect
238
+
239
+ sig = inspect.signature(getattr(cls, cls.FUNCTION))
240
+ params = sig.parameters.values()
241
+ assert any([True for p in params if p.kind == p.VAR_KEYWORD])
242
+
243
+ @classmethod
244
+ def setup_input_types(cls):
245
+ if not hasattr(cls, "INPUT_TYPES"):
246
+ cls.INPUT_TYPES = lambda: {}
247
+ # https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-and-flexible-inputs
248
+ new_input_types_func, original_has_unique_id = (
249
+ ensure_hidden_unique_id_and_prompt(cls.INPUT_TYPES)
250
+ )
251
+ cls.INPUT_TYPES = new_input_types_func
252
+ setattr(
253
+ cls,
254
+ cls.FUNCTION,
255
+ ensure_unique_id(getattr(cls, cls.FUNCTION), original_has_unique_id),
256
+ )
@@ -82,10 +82,10 @@ class BizyAirNodeIO:
82
82
  self.nodes.update(other.nodes)
83
83
 
84
84
  def send_request(
85
- self, url=None, headers=None, *, progress_callback=None, stream=False
85
+ self, url=None, headers=None, *, progress_callback=None, stream=False, **kwargs
86
86
  ) -> any:
87
87
  out = invoker.prompt_server.execute(
88
- prompt=self.nodes, last_node_ids=[self.node_id]
88
+ prompt=self.nodes, last_node_ids=[self.node_id], **kwargs
89
89
  )
90
90
  return out
91
91