bizyengine 1.2.6__py3-none-any.whl → 1.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/bizy_server/api_client.py +125 -57
- bizyengine/bizy_server/errno.py +9 -0
- bizyengine/bizy_server/server.py +353 -239
- bizyengine/bizyair_extras/__init__.py +1 -0
- bizyengine/bizyair_extras/nodes_flux.py +1 -1
- bizyengine/bizyair_extras/nodes_image_utils.py +2 -2
- bizyengine/bizyair_extras/nodes_nunchaku.py +1 -5
- bizyengine/bizyair_extras/nodes_segment_anything.py +1 -0
- bizyengine/bizyair_extras/nodes_trellis.py +1 -1
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +1 -1
- bizyengine/bizyair_extras/nodes_wan_i2v.py +222 -0
- bizyengine/core/__init__.py +2 -0
- bizyengine/core/commands/processors/prompt_processor.py +21 -18
- bizyengine/core/commands/servers/prompt_server.py +28 -13
- bizyengine/core/common/client.py +14 -2
- bizyengine/core/common/env_var.py +2 -0
- bizyengine/core/nodes_base.py +85 -7
- bizyengine/core/nodes_io.py +2 -2
- bizyengine/misc/llm.py +48 -85
- bizyengine/misc/mzkolors.py +27 -19
- bizyengine/misc/nodes.py +41 -21
- bizyengine/misc/nodes_controlnet_aux.py +18 -18
- bizyengine/misc/nodes_controlnet_union_sdxl.py +5 -12
- bizyengine/misc/segment_anything.py +29 -25
- bizyengine/misc/supernode.py +36 -30
- bizyengine/misc/utils.py +33 -21
- bizyengine/version.txt +1 -1
- bizyengine-1.2.8.dist-info/METADATA +211 -0
- {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/RECORD +31 -30
- {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/WHEEL +1 -1
- bizyengine-1.2.6.dist-info/METADATA +0 -19
- {bizyengine-1.2.6.dist-info → bizyengine-1.2.8.dist-info}/top_level.txt +0 -0
|
@@ -56,7 +56,7 @@ class FluxGuidance(BizyAirBaseNode):
|
|
|
56
56
|
|
|
57
57
|
CATEGORY = "advanced/conditioning/flux"
|
|
58
58
|
|
|
59
|
-
def append(self, conditioning: BizyAirNodeIO, guidance):
|
|
59
|
+
def append(self, conditioning: BizyAirNodeIO, guidance, **kwargs):
|
|
60
60
|
new_conditioning = conditioning.copy(self.assigned_id)
|
|
61
61
|
new_conditioning.add_node_data(
|
|
62
62
|
class_type="FluxGuidance",
|
|
@@ -31,7 +31,7 @@ class LoadImageURL(BizyAirBaseNode):
|
|
|
31
31
|
FUNCTION = "apply"
|
|
32
32
|
NODE_DISPLAY_NAME = "Load Image (URL)"
|
|
33
33
|
|
|
34
|
-
def apply(self, url: str):
|
|
34
|
+
def apply(self, url: str, **kwargs):
|
|
35
35
|
url = url.strip()
|
|
36
36
|
input_dir = folder_paths.get_input_directory()
|
|
37
37
|
|
|
@@ -88,6 +88,6 @@ class Image_Encode(BizyAirBaseNode):
|
|
|
88
88
|
RETURN_TYPES = ("IMAGE",)
|
|
89
89
|
NODE_DISPLAY_NAME = "Image Encode"
|
|
90
90
|
|
|
91
|
-
def apply(self, image, lossless=False):
|
|
91
|
+
def apply(self, image, lossless=False, **kwargs):
|
|
92
92
|
out = encode_data(image, lossless=lossless)
|
|
93
93
|
return (out,)
|
|
@@ -167,11 +167,7 @@ class NunchakuFluxLoraLoader(BizyAirBaseNode):
|
|
|
167
167
|
return True
|
|
168
168
|
|
|
169
169
|
def load_lora(
|
|
170
|
-
self,
|
|
171
|
-
model,
|
|
172
|
-
lora_name,
|
|
173
|
-
lora_strength,
|
|
174
|
-
model_version_id: str = None,
|
|
170
|
+
self, model, lora_name, lora_strength, model_version_id: str = None, **kwargs
|
|
175
171
|
):
|
|
176
172
|
assigned_id = self.assigned_id
|
|
177
173
|
new_model: BizyAirNodeIO = model.copy(assigned_id)
|
|
@@ -178,7 +178,7 @@ class BizyAirDownloadFile(BizyAirBaseNode):
|
|
|
178
178
|
OUTPUT_NODE = True
|
|
179
179
|
OUTPUT_IS_LIST = (False,)
|
|
180
180
|
|
|
181
|
-
def main(self, url, file_name):
|
|
181
|
+
def main(self, url, file_name, **kwargs):
|
|
182
182
|
assert url is not None
|
|
183
183
|
file_name = file_name + ".glb"
|
|
184
184
|
out_dir = os.path.join(folder_paths.get_output_directory(), "trellis_output")
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
try:
    from comfy_api_nodes.apinode_utils import (
        download_url_to_video_output,
        tensor_to_base64_string,
    )
except ImportError as e:
    # Catch ImportError (the parent of ModuleNotFoundError) so that an older
    # ComfyUI whose apinode_utils module exists but lacks one of these helpers
    # degrades the same way as a missing module, instead of making this whole
    # extras module fail to import.
    download_url_to_video_output = None
    tensor_to_base64_string = None

    # Only defined on the failure path; nodes re-raise it lazily when they
    # find the helpers above set to None.
    ERROR_MSG = f"Error {e} ComfyUI API nodes module not found. Please ensure you have ComfyUI version 0.3.36 or later installed."

    warnings.warn(ERROR_MSG)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from ..core.common import client
|
|
22
|
+
from ..core.common.env_var import BIZYAIR_DEBUG
|
|
23
|
+
from ..core.nodes_base import BizyAirBaseNode
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class WanApiNodeBase:
    """Mixin that supplies the remote endpoints used by WAN API-backed nodes."""

    # model id (as offered in a node's ``model_id`` choices) -> BizyAir endpoint URL
    MODEL_ENDPOINTS = dict(
        [
            (
                "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
                "https://bizyair-api.siliconflow.cn/x/v1/supernode/faas-wan-i2v-14b-480p-server",
            ),
        ]
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Wan_LoraLoader(BizyAirBaseNode):
    """Build a LoRA configuration for the WAN image-to-video pipeline.

    The output is a one-element list holding a ``(lora_name, lora_weight)``
    tuple. It is typed ``LORA_CONFIG``, which is what the ``lora_config``
    input of ``Wan_ImageToVideoPipeline`` accepts.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "lora_name": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "https://huggingface.co/Remade-AI/Squish/resolve/main/squish_18.safetensors",
                        # Tooltip text: "LoRA model download URL".
                        "tooltip": "LoRA 模型下载地址",
                    },
                ),
            },
            "optional": {
                "lora_weight": (
                    "FLOAT",
                    {
                        # Kept in sync with the Python default of apply_lora.
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.05,
                        # Tooltip text: "LoRA weight strength".
                        "tooltip": "LoRA权重强度",
                    },
                )
            },
        }

    RETURN_TYPES = ("LORA_CONFIG",)
    RETURN_NAMES = ("lora_config",)
    FUNCTION = "apply_lora"
    CATEGORY = "Diffusers/WAN Video Generation"

    def apply_lora(self, lora_name, lora_weight=1.0, **kwargs):
        """Return ``([(lora_name, lora_weight)],)``.

        Args:
            lora_name: LoRA download URL. It comes from a multiline text
                widget, so surrounding whitespace/newlines are stripped
                (as LoadImageURL does for its URL input).
            lora_weight: LoRA strength. The default matches the widget's
                declared default.
            **kwargs: Extra keyword arguments supplied by the node framework;
                unused.
        """
        return ([(lora_name.strip(), lora_weight)],)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Wan_ImageToVideoPipeline(WanApiNodeBase, BizyAirBaseNode):
    """Image-to-video node backed by a remote BizyAir WAN 2.1 I2V service.

    The node submits a non-blocking request to the endpoint for ``model_id``,
    polls the returned ``query_url`` until the task reports ``COMPLETED``,
    then downloads the resulting video.
    """

    POLLING_INTERVAL = 10  # sec
    MAX_POLLING_TIME = 60 * 20  # sec
    # Per-poll HTTP timeout, so that one stalled connection cannot block
    # beyond the overall MAX_POLLING_TIME deadline.
    REQUEST_TIMEOUT = 30  # sec

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": (
                    "IMAGE",
                    {
                        "default": None,
                        "tooltip": "Optional reference image to guide video generation",
                    },
                ),
                "model_id": (["Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"],),
                "prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text description of the video",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Negative text prompt to guide what to avoid in the video",
                    },
                ),
                "steps": ("INT", {"default": 30, "min": 1, "max": 40}),
                "cfg": (
                    "FLOAT",
                    {
                        "default": 6.0,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                    },
                ),
                "seed": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFF,
                        "step": 1,
                        "display": "number",
                        "control_after_generate": True,
                        "tooltip": "Seed for video generation (0 for random)",
                    },
                ),
                "use_teacache": (["enable", "disable"],),
                "lora_config": ("LORA_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("VIDEO",)

    FUNCTION = "generate_video"
    CATEGORY = "Diffusers/WAN Video Generation"

    def _encode_image(self, image_tensor):
        """Encode ``image_tensor`` as a ``data:image/webp;base64,...`` URI."""
        # https://docs.comfy.org/custom-nodes/backend/snippets
        # Encode exactly once; the previous version encoded twice and
        # discarded the first result.
        encoded = tensor_to_base64_string(
            image_tensor=image_tensor, mime_type="image/webp"
        )
        return "data:image/webp;base64," + encoded

    def _prepare_headers(self, api_key=None):
        """Return request headers with non-blocking task mode enabled.

        ``api_key`` is forwarded to ``client.headers``. When it is falsy, the
        client falls back to its globally configured key.
        """
        headers = client.headers(api_key=api_key)
        headers["X-Fn-Task-Mode"] = "non-blocking"
        return headers

    def _send_initial_request(self, endpoint, request_data, api_key=None):
        """Submit the generation request and return the task's ``query_url``."""
        headers = self._prepare_headers(api_key=api_key)
        response = client.send_request(
            url=endpoint,
            data=json.dumps({"prompt": request_data}).encode(),
            headers=headers,
        )
        return response["query_url"]

    def _poll_for_completion(self, query_url, api_key=None):
        """Poll ``query_url`` until the task reports ``data_status == "COMPLETED"``.

        Network errors and unparseable responses are logged in debug mode and
        retried on the next poll.

        Returns:
            The parsed JSON body of the completed task.

        Raises:
            TimeoutError: if the task has not completed within
                ``MAX_POLLING_TIME`` seconds. It is chained to the last error
                seen, if any.
        """
        # NOTE(review): failure statuses (if the service reports any) are not
        # detected here and will poll until the timeout — confirm the
        # service's status vocabulary.
        start_time = time.time()
        headers = self._prepare_headers(api_key=api_key)
        last_error = None

        while time.time() - start_time < self.MAX_POLLING_TIME:
            try:
                response = requests.get(
                    query_url, headers=headers, timeout=self.REQUEST_TIMEOUT
                )
            except requests.RequestException as e:
                # Transient network failure: retry on the next poll.
                last_error = e
                if BIZYAIR_DEBUG:
                    print(f"Polling request failed: {e}")
            else:
                try:
                    response_data = response.json()
                    if response_data["data_status"] == "COMPLETED":
                        return response_data
                except (ValueError, KeyError, TypeError) as e:
                    # ValueError covers json.JSONDecodeError and requests'
                    # JSONDecodeError; TypeError covers a non-dict JSON body.
                    last_error = e
                    if BIZYAIR_DEBUG:
                        print(
                            f"Response parsing error: {e} | Raw response: {response.text}"
                        )
            time.sleep(self.POLLING_INTERVAL)

        raise TimeoutError("Task processing timeout") from last_error

    def _process_result(self, result_data):
        """Download the video whose URL is at ``result_data["data"]["payload"]``."""
        video_url = result_data["data"]["payload"]
        return (download_url_to_video_output(video_url),)

    def generate_video(
        self,
        model_id: str,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 0,
        image=None,
        lora_config=None,
        use_teacache="enable",
        **kwargs,
    ):
        """Generate a video from ``image`` and ``prompt`` via the remote service.

        Args:
            model_id: Key into ``MODEL_ENDPOINTS``.
            prompt / negative_prompt / seed: Forwarded to the service.
            image: Input image tensor; required.
            lora_config: Optional list of ``(name, weight)`` pairs; at most
                one entry is currently supported.
            use_teacache: ``"enable"`` sends ``teacache=0.3``; anything else
                sends ``0``.
            **kwargs: May carry ``cfg``, ``steps`` and ``api_key``.

        Returns:
            A 1-tuple containing the downloaded video output.

        Raises:
            ImportError: if the ComfyUI API-node helpers are unavailable.
            ValueError: if no image is supplied.
            NotImplementedError: if more than one LoRA is configured.
        """
        if download_url_to_video_output is None or tensor_to_base64_string is None:
            raise ImportError(ERROR_MSG)
        if image is None:
            raise ValueError("An input image is required for image-to-video generation")

        # NOTE(review): api_key is presumably injected into node kwargs by the
        # framework (as PromptProcessor receives it). When absent, the client
        # falls back to its global key — confirm against BizyAirBaseNode.
        api_key = kwargs.get("api_key")
        if lora_config is None:
            lora_config = []

        # Key order matches the original request payload.
        req_dict = {
            "guidance_scale": kwargs.pop("cfg", 6.0),
            "num_inference_steps": kwargs.pop("steps", 30),
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": seed,
        }
        if len(lora_config) > 1:
            raise NotImplementedError("TODO, tmp only support one lora")
        req_dict["lora_name_list"] = [x[0] for x in lora_config]
        req_dict["lora_weight_list"] = [x[1] for x in lora_config]

        req_dict["teacache"] = 0.3 if use_teacache == "enable" else 0

        req_dict["image"] = self._encode_image(image_tensor=image)
        endpoint = self.MODEL_ENDPOINTS[model_id]
        query_url = self._send_initial_request(
            endpoint, request_data=req_dict, api_key=api_key
        )
        result = self._poll_for_completion(query_url, api_key=api_key)
        return self._process_result(result)
|
bizyengine/core/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ from collections import deque
|
|
|
3
3
|
from typing import Any, Dict, List
|
|
4
4
|
|
|
5
5
|
from bizyengine.core.commands.base import Processor # type: ignore
|
|
6
|
-
from bizyengine.core.common import client
|
|
6
|
+
from bizyengine.core.common import client
|
|
7
7
|
from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
|
|
8
8
|
from bizyengine.core.common.env_var import (
|
|
9
9
|
BIZYAIR_DEBUG,
|
|
@@ -15,7 +15,6 @@ from bizyengine.core.path_utils import (
|
|
|
15
15
|
convert_prompt_label_path_to_real_path,
|
|
16
16
|
guess_url_from_node,
|
|
17
17
|
)
|
|
18
|
-
from server import PromptServer
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
def is_link(obj):
|
|
@@ -34,7 +33,9 @@ from dataclasses import dataclass
|
|
|
34
33
|
|
|
35
34
|
|
|
36
35
|
class SearchServiceRouter(Processor):
|
|
37
|
-
def process(
|
|
36
|
+
def process(
|
|
37
|
+
self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str], **kwargs
|
|
38
|
+
):
|
|
38
39
|
if BIZYAIR_DEV_REQUEST_URL:
|
|
39
40
|
return BIZYAIR_DEV_REQUEST_URL
|
|
40
41
|
|
|
@@ -83,18 +84,15 @@ class SearchServiceRouter(Processor):
|
|
|
83
84
|
return f"{BIZYAIR_SERVER_ADDRESS}{out_route}"
|
|
84
85
|
|
|
85
86
|
def validate_input(
|
|
86
|
-
self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
|
|
87
|
+
self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str], **kwargs
|
|
87
88
|
):
|
|
88
89
|
assert len(last_node_ids) == 1
|
|
89
90
|
return True
|
|
90
91
|
|
|
91
92
|
|
|
92
93
|
class PromptProcessor(Processor):
|
|
93
|
-
def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
|
|
94
|
-
exec_info = {
|
|
95
|
-
"model_version_ids": [],
|
|
96
|
-
"api_key": get_api_key(),
|
|
97
|
-
}
|
|
94
|
+
def _exec_info(self, prompt: Dict[str, Dict[str, Any]], api_key: str):
|
|
95
|
+
exec_info = {"model_version_ids": [], "api_key": api_key}
|
|
98
96
|
|
|
99
97
|
model_version_id_prefix = config_manager.get_model_version_id_prefix()
|
|
100
98
|
for node_id, node_data in prompt.items():
|
|
@@ -105,26 +103,31 @@ class PromptProcessor(Processor):
|
|
|
105
103
|
return exec_info
|
|
106
104
|
|
|
107
105
|
def process(
|
|
108
|
-
self,
|
|
106
|
+
self,
|
|
107
|
+
url: str,
|
|
108
|
+
prompt: Dict[str, Dict[str, Any]],
|
|
109
|
+
last_node_ids: List[str],
|
|
110
|
+
**kwargs,
|
|
109
111
|
):
|
|
110
112
|
dict = {
|
|
111
113
|
"prompt": prompt,
|
|
112
114
|
"last_node_id": last_node_ids[0],
|
|
113
|
-
"exec_info": self._exec_info(prompt),
|
|
115
|
+
"exec_info": self._exec_info(prompt, kwargs["api_key"]),
|
|
114
116
|
}
|
|
115
|
-
if
|
|
116
|
-
|
|
117
|
-
and PromptServer.instance.last_prompt_id is not None
|
|
118
|
-
):
|
|
119
|
-
dict["prompt_id"] = PromptServer.instance.last_prompt_id
|
|
120
|
-
print("Processing prompt with ID: " + PromptServer.instance.last_prompt_id)
|
|
117
|
+
if "prompt_id" in kwargs:
|
|
118
|
+
dict["prompt_id"] = kwargs["prompt_id"]
|
|
121
119
|
|
|
122
120
|
return client.send_request(
|
|
123
121
|
url=url,
|
|
124
122
|
data=json.dumps(dict).encode("utf-8"),
|
|
123
|
+
headers=client.headers(api_key=kwargs["api_key"]),
|
|
125
124
|
)
|
|
126
125
|
|
|
127
126
|
def validate_input(
|
|
128
|
-
self,
|
|
127
|
+
self,
|
|
128
|
+
url: str,
|
|
129
|
+
prompt: Dict[str, Dict[str, Any]],
|
|
130
|
+
last_node_ids: List[str],
|
|
131
|
+
**kwargs,
|
|
129
132
|
):
|
|
130
133
|
return True
|
|
@@ -9,7 +9,7 @@ from typing import Any, Dict, List
|
|
|
9
9
|
import comfy
|
|
10
10
|
from bizyengine.core.commands.base import Command, Processor # type: ignore
|
|
11
11
|
from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
|
|
12
|
-
from bizyengine.core.common.client import send_request
|
|
12
|
+
from bizyengine.core.common.client import headers, send_request
|
|
13
13
|
from bizyengine.core.common.env_var import (
|
|
14
14
|
BIZYAIR_DEBUG,
|
|
15
15
|
BIZYAIR_DEV_GET_TASK_RESULT_SERVER,
|
|
@@ -20,7 +20,7 @@ from bizyengine.core.configs.conf import config_manager
|
|
|
20
20
|
from bizyengine.core.image_utils import decode_data, encode_data
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
def get_task_result(task_id: str, offset: int = 0) -> dict:
|
|
23
|
+
def get_task_result(task_id: str, offset: int = 0, api_key: str = None) -> dict:
|
|
24
24
|
"""
|
|
25
25
|
Get the result of a task.
|
|
26
26
|
"""
|
|
@@ -34,8 +34,12 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
|
|
|
34
34
|
|
|
35
35
|
if BIZYAIR_DEBUG:
|
|
36
36
|
print(f"Debug: get task result url: {url}")
|
|
37
|
+
_headers = headers(api_key=api_key)
|
|
37
38
|
response_json = send_request(
|
|
38
|
-
method="GET",
|
|
39
|
+
method="GET",
|
|
40
|
+
url=url,
|
|
41
|
+
data=json.dumps({"offset": offset}).encode("utf-8"),
|
|
42
|
+
headers=_headers,
|
|
39
43
|
)
|
|
40
44
|
out = response_json
|
|
41
45
|
events = out.get("data", {}).get("events", [])
|
|
@@ -47,7 +51,9 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
|
|
|
47
51
|
and event["data"].startswith("https://")
|
|
48
52
|
):
|
|
49
53
|
# event["data"] = requests.get(event["data"]).json()
|
|
50
|
-
event["data"] = send_request(
|
|
54
|
+
event["data"] = send_request(
|
|
55
|
+
method="GET", url=event["data"], headers=_headers
|
|
56
|
+
)
|
|
51
57
|
new_events.append(event)
|
|
52
58
|
out["data"]["events"] = new_events
|
|
53
59
|
return out
|
|
@@ -59,6 +65,7 @@ class BizyAirTask:
|
|
|
59
65
|
task_id: str
|
|
60
66
|
data_pool: list[dict] = field(default_factory=list)
|
|
61
67
|
data_status: str = None
|
|
68
|
+
api_key: str = None
|
|
62
69
|
|
|
63
70
|
@staticmethod
|
|
64
71
|
def check_inputs(inputs: dict) -> bool:
|
|
@@ -69,12 +76,19 @@ class BizyAirTask:
|
|
|
69
76
|
)
|
|
70
77
|
|
|
71
78
|
@classmethod
|
|
72
|
-
def from_data(
|
|
79
|
+
def from_data(
|
|
80
|
+
cls, inputs: dict, check_inputs: bool = True, **kwargs
|
|
81
|
+
) -> "BizyAirTask":
|
|
73
82
|
if check_inputs and not cls.check_inputs(inputs):
|
|
74
83
|
raise ValueError(f"Invalid inputs: {inputs}")
|
|
75
84
|
data = inputs.get("data", {})
|
|
76
85
|
task_id = data.get("task_id", "")
|
|
77
|
-
return cls(
|
|
86
|
+
return cls(
|
|
87
|
+
task_id=task_id,
|
|
88
|
+
data_pool=[],
|
|
89
|
+
data_status="started",
|
|
90
|
+
api_key=kwargs["api_key"],
|
|
91
|
+
)
|
|
78
92
|
|
|
79
93
|
def is_finished(self) -> bool:
|
|
80
94
|
if not self.data_pool:
|
|
@@ -85,7 +99,7 @@ class BizyAirTask:
|
|
|
85
99
|
|
|
86
100
|
def send_request(self, offset: int = 0) -> dict:
|
|
87
101
|
if offset >= len(self.data_pool):
|
|
88
|
-
return get_task_result(self.task_id, offset)
|
|
102
|
+
return get_task_result(self.task_id, offset, self.api_key)
|
|
89
103
|
else:
|
|
90
104
|
return self.data_pool[offset]
|
|
91
105
|
|
|
@@ -163,12 +177,12 @@ class PromptServer(Command):
|
|
|
163
177
|
and "task_id" in result.get("data", {})
|
|
164
178
|
)
|
|
165
179
|
|
|
166
|
-
def _get_result(self, result: Dict[str, Any], *, cache_key: str = None):
|
|
180
|
+
def _get_result(self, result: Dict[str, Any], *, cache_key: str = None, **kwargs):
|
|
167
181
|
try:
|
|
168
182
|
response_data = result["data"]
|
|
169
183
|
if BizyAirTask.check_inputs(result):
|
|
170
184
|
self.cache_manager.set(cache_key, result)
|
|
171
|
-
bz_task = BizyAirTask.from_data(result, check_inputs=False)
|
|
185
|
+
bz_task = BizyAirTask.from_data(result, check_inputs=False, **kwargs)
|
|
172
186
|
bz_task.do_task_until_completed(timeout=60 * 60) # 60 minutes
|
|
173
187
|
last_data = bz_task.get_last_data()
|
|
174
188
|
response_data = last_data.get("data")
|
|
@@ -187,7 +201,6 @@ class PromptServer(Command):
|
|
|
187
201
|
*args,
|
|
188
202
|
**kwargs,
|
|
189
203
|
):
|
|
190
|
-
|
|
191
204
|
prompt = encode_data(prompt)
|
|
192
205
|
|
|
193
206
|
if BIZYAIR_DEBUG:
|
|
@@ -197,7 +210,7 @@ class PromptServer(Command):
|
|
|
197
210
|
}
|
|
198
211
|
pprint.pprint(debug_info, indent=4)
|
|
199
212
|
|
|
200
|
-
url = self.router(prompt=prompt, last_node_ids=last_node_ids)
|
|
213
|
+
url = self.router(prompt=prompt, last_node_ids=last_node_ids, **kwargs)
|
|
201
214
|
|
|
202
215
|
if BIZYAIR_DEBUG:
|
|
203
216
|
print(f"Generated URL: {url}")
|
|
@@ -218,8 +231,10 @@ class PromptServer(Command):
|
|
|
218
231
|
print(f"Cache hit for sh256-{sh256}")
|
|
219
232
|
out = cached_output
|
|
220
233
|
else:
|
|
221
|
-
result = self.processor(
|
|
222
|
-
|
|
234
|
+
result = self.processor(
|
|
235
|
+
url, prompt=prompt, last_node_ids=last_node_ids, **kwargs
|
|
236
|
+
)
|
|
237
|
+
out = self._get_result(result, cache_key=sh256, **kwargs)
|
|
223
238
|
|
|
224
239
|
if BIZYAIR_DEBUG:
|
|
225
240
|
pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4)
|
bizyengine/core/common/client.py
CHANGED
|
@@ -22,6 +22,7 @@ from .env_var import (
|
|
|
22
22
|
BIZYAIR_API_KEY,
|
|
23
23
|
BIZYAIR_DEBUG,
|
|
24
24
|
BIZYAIR_SERVER_ADDRESS,
|
|
25
|
+
BIZYAIR_SERVER_MODE,
|
|
25
26
|
create_api_key_file,
|
|
26
27
|
)
|
|
27
28
|
|
|
@@ -41,6 +42,8 @@ api_key_state = APIKeyState()
|
|
|
41
42
|
|
|
42
43
|
|
|
43
44
|
def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False) -> bool:
|
|
45
|
+
if BIZYAIR_SERVER_MODE:
|
|
46
|
+
return
|
|
44
47
|
logging.debug("client.py set_api_key called")
|
|
45
48
|
global api_key_state
|
|
46
49
|
if api_key_state.is_valid and not override:
|
|
@@ -58,6 +61,9 @@ def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False) -> bool:
|
|
|
58
61
|
|
|
59
62
|
|
|
60
63
|
def validate_api_key(api_key: str = None) -> bool:
|
|
64
|
+
if BIZYAIR_SERVER_MODE:
|
|
65
|
+
return False
|
|
66
|
+
|
|
61
67
|
logging.debug("validating api key...")
|
|
62
68
|
if not api_key or not isinstance(api_key, str):
|
|
63
69
|
warnings.warn("invalid api_key")
|
|
@@ -92,6 +98,8 @@ def validate_api_key(api_key: str = None) -> bool:
|
|
|
92
98
|
|
|
93
99
|
|
|
94
100
|
def get_api_key() -> str:
|
|
101
|
+
if BIZYAIR_SERVER_MODE:
|
|
102
|
+
return None
|
|
95
103
|
logging.debug("client.py get_api_key called")
|
|
96
104
|
global api_key_state
|
|
97
105
|
try:
|
|
@@ -106,11 +114,15 @@ def get_api_key() -> str:
|
|
|
106
114
|
return api_key_state.current_api_key
|
|
107
115
|
|
|
108
116
|
|
|
109
|
-
def
|
|
117
|
+
def headers(api_key: str = None):
    """Public entry point for building BizyAir request headers.

    Delegates to ``_headers``, passing ``api_key`` through unchanged.
    """
    built = _headers(api_key=api_key)
    return built
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _headers(api_key: str = None):
|
|
110
122
|
headers = {
|
|
111
123
|
"accept": "application/json",
|
|
112
124
|
"content-type": "application/json",
|
|
113
|
-
"authorization": f"Bearer {get_api_key()}",
|
|
125
|
+
"authorization": f"Bearer {api_key if api_key else get_api_key()}",
|
|
114
126
|
}
|
|
115
127
|
return headers
|
|
116
128
|
|