bizyengine 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/__init__.py +35 -0
- bizyengine/bizy_server/__init__.py +7 -0
- bizyengine/bizy_server/api_client.py +763 -0
- bizyengine/bizy_server/errno.py +122 -0
- bizyengine/bizy_server/error_handler.py +3 -0
- bizyengine/bizy_server/execution.py +55 -0
- bizyengine/bizy_server/resp.py +24 -0
- bizyengine/bizy_server/server.py +898 -0
- bizyengine/bizy_server/utils.py +93 -0
- bizyengine/bizyair_extras/__init__.py +24 -0
- bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
- bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
- bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
- bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
- bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
- bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
- bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
- bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
- bizyengine/bizyair_extras/nodes_dataset.py +99 -0
- bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
- bizyengine/bizyair_extras/nodes_flux.py +69 -0
- bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
- bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
- bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
- bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
- bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
- bizyengine/bizyair_extras/nodes_sd3.py +52 -0
- bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
- bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
- bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
- bizyengine/bizyair_extras/nodes_trellis.py +199 -0
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
- bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
- bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
- bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
- bizyengine/core/__init__.py +8 -0
- bizyengine/core/commands/__init__.py +1 -0
- bizyengine/core/commands/base.py +27 -0
- bizyengine/core/commands/invoker.py +4 -0
- bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
- bizyengine/core/commands/processors/prompt_processor.py +123 -0
- bizyengine/core/commands/servers/model_server.py +0 -0
- bizyengine/core/commands/servers/prompt_server.py +234 -0
- bizyengine/core/common/__init__.py +8 -0
- bizyengine/core/common/caching.py +198 -0
- bizyengine/core/common/client.py +262 -0
- bizyengine/core/common/env_var.py +101 -0
- bizyengine/core/common/utils.py +93 -0
- bizyengine/core/configs/conf.py +112 -0
- bizyengine/core/configs/models.json +101 -0
- bizyengine/core/configs/models.yaml +329 -0
- bizyengine/core/data_types.py +20 -0
- bizyengine/core/image_utils.py +288 -0
- bizyengine/core/nodes_base.py +159 -0
- bizyengine/core/nodes_io.py +97 -0
- bizyengine/core/path_utils/__init__.py +9 -0
- bizyengine/core/path_utils/path_manager.py +276 -0
- bizyengine/core/path_utils/utils.py +34 -0
- bizyengine/misc/__init__.py +0 -0
- bizyengine/misc/auth.py +83 -0
- bizyengine/misc/llm.py +431 -0
- bizyengine/misc/mzkolors.py +93 -0
- bizyengine/misc/nodes.py +1208 -0
- bizyengine/misc/nodes_controlnet_aux.py +491 -0
- bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
- bizyengine/misc/route_sam.py +60 -0
- bizyengine/misc/segment_anything.py +276 -0
- bizyengine/misc/supernode.py +182 -0
- bizyengine/misc/utils.py +218 -0
- bizyengine/version.txt +1 -0
- bizyengine-0.4.2.dist-info/METADATA +12 -0
- bizyengine-0.4.2.dist-info/RECORD +76 -0
- bizyengine-0.4.2.dist-info/WHEEL +5 -0
- bizyengine-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import pprint
|
|
4
|
+
import urllib.error
|
|
5
|
+
import urllib.request
|
|
6
|
+
import warnings
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from typing import Any, Union
|
|
10
|
+
|
|
11
|
+
import aiohttp
|
|
12
|
+
import comfy
|
|
13
|
+
|
|
14
|
+
from .caching import CacheManager
|
|
15
|
+
|
|
16
|
+
__all__ = ["send_request"]
|
|
17
|
+
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
|
|
20
|
+
from .env_var import BIZYAIR_API_KEY, BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
|
|
21
|
+
|
|
22
|
+
# Module-level flag initialised to None; the functions in this file track
# validity on ``api_key_state`` instead and never assign this name.
IS_API_KEY_VALID = None


@dataclass
class APIKeyState:
    """Last API key passed to ``validate_api_key`` and its validation outcome.

    Both fields start as None; ``is_valid`` stays None until a validation
    result has been recorded.
    """

    current_api_key: Union[str, None] = None
    is_valid: Union[bool, None] = None


# Shared mutable state read and written by set_api_key / validate_api_key.
api_key_state = APIKeyState()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False):
    """Validate *api_key* and, if accepted, install it as ``BIZYAIR_API_KEY``.

    When ``api_key_state`` already holds a validation result (valid or not)
    and ``override`` is false, only a warning is emitted and nothing changes.
    Exceptions raised by ``validate_api_key`` propagate to the caller.
    """
    global BIZYAIR_API_KEY, api_key_state

    already_checked = api_key_state.is_valid is not None
    if already_checked and not override:
        warnings.warn("API key has already been set and will not be overridden.")
        return

    if not validate_api_key(api_key):
        # Only reached for empty / non-string keys; server-side rejection raises.
        api_key_state.is_valid = False
        warnings.warn("Invalid API key provided.")
        return

    BIZYAIR_API_KEY = api_key
    api_key_state.is_valid = True
    print("\033[92mAPI key is set successfully.\033[0m")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def validate_api_key(api_key: str = None) -> bool:
|
|
49
|
+
global api_key_state
|
|
50
|
+
if not api_key or not isinstance(api_key, str):
|
|
51
|
+
warnings.warn("API key is not set.")
|
|
52
|
+
return False
|
|
53
|
+
# if api_key_state.current_api_key == api_key and api_key_state.is_valid is not None:
|
|
54
|
+
# return api_key_state.is_valid
|
|
55
|
+
api_key_state.current_api_key = api_key
|
|
56
|
+
url = f"{BIZYAIR_SERVER_ADDRESS}/user/info"
|
|
57
|
+
headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"}
|
|
58
|
+
try:
|
|
59
|
+
response_data = send_request(
|
|
60
|
+
method="GET", url=url, headers=headers, callback=None
|
|
61
|
+
)
|
|
62
|
+
if "message" not in response_data or response_data["message"] != "Ok":
|
|
63
|
+
api_key_state.is_valid = False
|
|
64
|
+
raise ValueError(
|
|
65
|
+
f"\033[91mAPI key validation failed. API Key: {api_key}\033[0m"
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
api_key_state.is_valid = True
|
|
69
|
+
except ConnectionError as ce:
|
|
70
|
+
api_key_state.is_valid = False
|
|
71
|
+
raise ValueError(f"\033[91mConnection error: {ce}\033[0m")
|
|
72
|
+
except PermissionError as pe:
|
|
73
|
+
api_key_state.is_valid = False
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"\033[91mError validating API key: {api_key}, error: {pe}\033[0m"
|
|
76
|
+
)
|
|
77
|
+
except Exception as e:
|
|
78
|
+
api_key_state.is_valid = False
|
|
79
|
+
raise ValueError(f"\033[91mOther error: {e}\033[0m")
|
|
80
|
+
return api_key_state.is_valid
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_api_key() -> str:
    """Validate the module-level ``BIZYAIR_API_KEY`` and return it.

    ``validate_api_key`` returns False (with a warning) instead of raising
    when the key is empty, so an unset key is returned unchanged.

    Raises:
        ValueError: when validation raises. The message is printed first,
            and the original exception is chained as ``__cause__`` so the
            traceback is preserved.
    """
    try:
        validate_api_key(BIZYAIR_API_KEY)
    except Exception as e:
        print(str(e))
        raise ValueError(str(e)) from e
    return BIZYAIR_API_KEY
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _headers():
    """Build the default JSON request headers carrying a bearer token.

    The token is obtained from ``get_api_key()``, which validates the key
    first and may raise ``ValueError``.
    """
    token = get_api_key()
    return {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": f"Bearer {token}",
    }
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def process_response_data(response_data: dict) -> dict:
    """Unwrap a BizyAir response payload.

    Cloud responses carry the real payload as a JSON string under
    ``"result"``. That string is decoded and returned. If ``"result"``
    already holds a decoded object (anything other than str/bytes/bytearray),
    it is returned as-is. Responses without a ``"result"`` key (local
    responses) are returned unchanged.

    Raises:
        ValueError: when ``"result"`` is a string that is not valid JSON. The
            ``json.JSONDecodeError`` is chained as ``__cause__``.
    """
    if "result" not in response_data:
        # Local response: already the payload.
        return response_data

    result = response_data["result"]
    if not isinstance(result, (str, bytes, bytearray)):
        # Already decoded upstream; json.loads would raise TypeError here.
        return result
    try:
        return json.loads(result)
    except json.JSONDecodeError as err:
        raise ValueError(
            f"Failed to decode JSON from response. {response_data=}"
        ) from err
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Public sites probed when the BizyAir server cannot be reached, so the error
# message can tell a local network outage apart from a server-side problem.
_CONNECTIVITY_PROBE_SITES = (
    "https://www.baidu.com",
    "https://www.bing.com",
    "https://www.alibaba.com",
)


class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Refuse to follow redirects; urllib then raises the 3xx as ``HTTPError``."""

    def redirect_request(self, req, fp, code, msg, headers, newurl):
        return None


def _probe_public_sites(sites=_CONNECTIVITY_PROBE_SITES, timeout=5):
    """Try each URL in *sites*; return ``{site: "Success" | "Failed"}``.

    Best-effort diagnostic: any network error marks the site "Failed" and is
    never propagated. A 2xx or 3xx status (redirects are not followed) counts
    as reachable.
    """
    opener = urllib.request.build_opener(_NoRedirectHandler())
    results = {}
    for site in sites:
        success = False
        try:
            with opener.open(site, timeout=timeout) as probe:
                success = 200 <= probe.getcode() < 400
        except urllib.error.HTTPError as e:
            success = 200 <= e.code < 400
        except OSError:
            # URLError, socket timeouts and connection resets all derive from OSError.
            pass
        results[site] = "Success" if success else "Failed"
    return results


def _extract_error_fields(response_body, verbose=False):
    """Return ``(code, message)`` parsed from a JSON error body.

    Each value is ``"N/A"`` when the body is not a JSON object or lacks the key.
    """
    try:
        payload = json.loads(response_body)
    except json.JSONDecodeError:
        if verbose:
            print("Failed to decode response body as JSON.")
        return "N/A", "N/A"
    if not isinstance(payload, dict):
        return "N/A", "N/A"
    return payload.get("code", "N/A"), payload.get("message", "N/A")


def send_request(
    method: str = "POST",
    url: str = None,
    data: bytes = None,
    verbose=False,
    callback: callable = process_response_data,
    response_handler: callable = json.loads,
    cache_manager: CacheManager = None,
    **kwargs,
) -> Union[dict, Any]:
    """Send a blocking HTTP request to the BizyAir service and decode the reply.

    Args:
        method: HTTP method.
        url: Target URL.
        data: Optional request body.
        verbose: Print diagnostic details when the request fails.
        response_handler: Applied to the UTF-8 response text (``json.loads``
            by default); pass None to keep the raw text.
        callback: Applied to the handled response and its result returned;
            pass None to return the handled response directly.
        cache_manager: Not used by this function.
        **kwargs: ``headers`` replaces the default authenticated headers
            (copied, so the caller's dict is not mutated). All other keyword
            arguments are forwarded to ``urllib.request.Request``.

    Raises:
        PermissionError: the server answered "Unauthorized" (invalid key).
        ConnectionError: the server returned an error payload, or could not
            be reached (the message then lists public-site probe results).
    """
    if "headers" in kwargs:
        headers = dict(kwargs.pop("headers"))
    else:
        headers = _headers()
    headers["User-Agent"] = "BizyAir Client"
    req = urllib.request.Request(
        url, data=data, headers=headers, method=method, **kwargs
    )

    try:
        with urllib.request.urlopen(req) as response:
            response_data = response.read().decode("utf-8")
    except urllib.error.URLError as e:
        error_message = str(e)
        # Only HTTPError carries a body. Decode leniently so a non-UTF-8 body
        # cannot replace the real failure with a UnicodeDecodeError.
        response_body = (
            e.read().decode("utf-8", errors="replace") if hasattr(e, "read") else "N/A"
        )
        if verbose:
            print(f"URLError encountered: {error_message}")
            print(f"Response Body: {response_body}")
        code, message = _extract_error_fields(response_body, verbose=verbose)

        if "Unauthorized" in error_message:
            raise PermissionError(
                "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                "If you have the key, please click the 'API Key' button at the bottom right to set the key."
            ) from e
        if code != "N/A" and message != "N/A":
            raise ConnectionError(
                f"Failed to handle your request: {error_message}.\n"
                + f" Error code: {code}.\n"
                + f" Error message: {message}.\n"
                + "The cause of this issue may be incorrect parameter status or ongoing background tasks. \n"
                + "If retrying after waiting for a while still does not resolve the issue, please report it to "
                + "Bizyair's official support."
            ) from e

        results = _probe_public_sites()
        raise ConnectionError(
            f"Failed to connect to the server: {url}.\n"
            + "The connection attempts to the public sites return the following results:\n"
            + "\n".join(f" {site}: {result}" for site, result in results.items())
            + "\nPlease check the network connection."
        ) from e

    if response_handler:
        response_data = response_handler(response_data)
    if callback:
        return callback(response_data)
    return response_data
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
async def async_send_request(
    method: str = "POST",
    url: str = None,
    data: bytes = None,
    verbose=False,
    callback: callable = process_response_data,
    **kwargs,
) -> dict:
    """Send an HTTP request with aiohttp and return the decoded JSON reply.

    ``headers`` in *kwargs* replaces the default authenticated headers; the
    remaining keyword arguments go to ``ClientSession.request``. When
    *callback* is truthy, its result on the decoded JSON is returned.

    Any exception inside the request (including invalid JSON) is printed as
    ``Error fetching data: ...`` and ``{}`` is returned.

    NOTE(review): the PermissionError / ConnectionError raised below for
    non-200 statuses are caught by that same handler, so callers never see
    them — they only get ``{}``. Confirm this is intended.
    """
    if "headers" in kwargs:
        headers = kwargs.pop("headers")
    else:
        headers = _headers()

    try:
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method, url, data=data, headers=headers, **kwargs
            ) as response:
                body = await response.text()
                status = response.status
                if status != 200:
                    error_message = f"HTTP Status {status}"
                    if verbose:
                        print(f"Error encountered: {error_message}")
                    if status == 401:
                        raise PermissionError(
                            "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                            "If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
                        )
                    raise ConnectionError(
                        f"Failed to connect to the server: {error_message}.\n"
                        + "Please check your API key and ensure the server is reachable.\n"
                        + "Also, verify your network settings and disable any proxies if necessary.\n"
                        + "After checking, please restart the ComfyUI service."
                    )
                payload = json.loads(body)
                return callback(payload) if callback else payload
    except Exception as e:
        # aiohttp.ClientError and every other failure are reported identically.
        print(f"Error fetching data: {e}")
        return {}
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def fetch_models_by_type(
    url: str, model_type: str, *, method="GET", verbose=False
) -> dict:
    """Request the models of *model_type* from the endpoint at *url*.

    Returns ``{}`` when ``validate_api_key`` returns False (no usable key
    configured). Otherwise returns the result of ``send_request`` with its
    default callback. Exceptions raised during validation propagate.
    """
    if validate_api_key(BIZYAIR_API_KEY):
        payload = {"type": model_type}
        if BIZYAIR_DEBUG:
            pprint.pprint(payload)
        body = json.dumps(payload).encode("utf-8")
        return send_request(method=method, url=url, data=body, verbose=verbose)
    return {}
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import configparser
|
|
2
|
+
import os
|
|
3
|
+
from os import environ
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
# Plugin root directory (api_key.ini lives here); overridable through the
# BIZYAIR_COMFYUI_PATH environment variable, defaulting to the working dir.
BIZYAIR_COMFYUI_PATH = Path(os.getenv("BIZYAIR_COMFYUI_PATH", "./"))
print(
    "\033[92m[BizyAir]\033[0m BizyAir ComfyUI Plugin: "
    f"{str(BIZYAIR_COMFYUI_PATH)}"
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ServerAddress:
    """Process-wide singleton holding the BizyAir server base URL.

    Every ``ServerAddress(...)`` call returns the same instance. Because
    ``__init__`` still runs on each call, the latest constructor argument
    replaces the stored address.
    """

    _instance = None
    _address = None

    def __new__(cls, *args, **kwargs):
        # Create the shared instance on first use and hand it back afterwards.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, address):
        self.address = address

    @property
    def address(self):
        """The current server base URL."""
        return self._address

    @address.setter
    def address(self, new_address):
        self._address = new_address

    def __str__(self):
        # Lets f"{BIZYAIR_SERVER_ADDRESS}/path" interpolate the raw URL.
        return self._address
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def env(key, type_, default=None):
    """Read environment variable *key* and convert it to *type_*.

    Supported types:
        - ``str``: returned verbatim.
        - ``bool``: case-insensitive, surrounding whitespace ignored. Accepts
          ``1/true/yes/y/ok/on`` and ``0/false/no/n/nok/off``.
        - ``int`` and ``float``.

    Returns *default* when *key* is not set; *type_* is not checked then.

    Raises:
        ValueError: when the value cannot be converted, or *type_* is not
            one of the supported types.
    """
    if key not in environ:
        return default

    val = environ[key]

    if type_ is str:
        return val
    if type_ is bool:
        normalized = val.strip().lower()
        if normalized in {"1", "true", "yes", "y", "ok", "on"}:
            return True
        if normalized in {"0", "false", "no", "n", "nok", "off"}:
            return False
        raise ValueError(
            "Invalid environment variable '%s' (expected a boolean): '%s'" % (key, val)
        )
    if type_ is int or type_ is float:
        try:
            return type_(val)
        except ValueError:
            expected = "an integer" if type_ is int else "a float"
            raise ValueError(
                "Invalid environment variable '%s' (expected %s): '%s'"
                % (key, expected, val)
            ) from None
    raise ValueError("The requested type '%r' is not supported" % type_)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def load_api_key():
    """Read the stored API key from ``<BIZYAIR_COMFYUI_PATH>/api_key.ini``.

    Returns:
        ``(has_key, api_key)``. ``api_key`` is the ``[auth] api_key`` value
        with surrounding whitespace and quotes stripped (``""`` if absent).
        ``has_key`` is True when the key starts with ``"sk-"``.
        ``(False, None)`` is returned when the file does not exist or cannot
        be parsed; a parse failure is reported on stdout.
    """
    file_path = BIZYAIR_COMFYUI_PATH / "api_key.ini"
    if not file_path.is_file():
        return False, None

    config = configparser.ConfigParser()
    try:
        # create_api_key_file writes this file as UTF-8, so read it the same
        # way rather than with the platform's locale encoding.
        config.read(file_path, encoding="utf-8")
        raw_key = config.get("auth", "api_key", fallback="")
    except (configparser.Error, UnicodeDecodeError) as err:
        # A corrupt file should not abort the plugin import (this runs at
        # module load time); treat it as "no stored key".
        print(f"\033[91m[BizyAir]\033[0m Failed to read {file_path}: {err}")
        return False, None

    api_key: str = raw_key.strip().strip("'\"")
    return api_key.startswith("sk-"), api_key
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def create_api_key_file(api_key):
    """Write *api_key* to ``<BIZYAIR_COMFYUI_PATH>/api_key.ini`` under ``[auth]``.

    The file is written as UTF-8, overwriting any existing content.

    Raises:
        Exception: wrapping any error raised while opening or writing the
            file. The original error is chained as ``__cause__`` so its
            traceback is not lost.
    """
    config = configparser.ConfigParser()
    config["auth"] = {"api_key": api_key}
    file_path = BIZYAIR_COMFYUI_PATH / "api_key.ini"
    try:
        with open(file_path, "w", encoding="utf-8") as configfile:
            config.write(configfile)
    except Exception as e:
        raise Exception(f"An error occurred when save the key: {e}") from e
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Service endpoints:
#   production: https://bizyair-api.siliconflow.cn/x/v1
#   uat:        https://uat-bizyair-api.siliconflow.cn/x/v1
_BIZYAIR_SERVER_ADDRESS = environ.get(
    "BIZYAIR_SERVER_ADDRESS", "https://bizyair-api.siliconflow.cn/x/v1"
)
BIZYAIR_SERVER_ADDRESS = ServerAddress(_BIZYAIR_SERVER_ADDRESS)

# The BIZYAIR_API_KEY environment variable wins over the key stored in
# api_key.ini; load_api_key() is evaluated either way (None if the file is missing).
BIZYAIR_API_KEY = env("BIZYAIR_API_KEY", str, default=load_api_key()[1])

# Development settings.
BIZYAIR_DEV_REQUEST_URL = env("BIZYAIR_DEV_REQUEST_URL", str, default=None)
BIZYAIR_DEBUG = env("BIZYAIR_DEBUG", bool, default=False)
BIZYAIR_DEV_GET_TASK_RESULT_SERVER = env(
    "BIZYAIR_DEV_GET_TASK_RESULT_SERVER", str, default=None
)
BIZYAIR_PRODUCTION_TEST = env("BIZYAIR_PRODUCTION_TEST", str, default=None)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, List
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def truncate_long_strings(obj, max_length=50):
    """Return a log-friendly version of *obj* with long strings shortened.

    - Strings longer than *max_length* are cut to *max_length* characters
      followed by ``"..."``.
    - Dicts, lists and tuples are processed recursively; lists come back as
      plain lists and tuples as plain tuples.
    - A ``torch.Tensor`` is replaced by its ``(shape, dtype, device)`` tuple.
    - Anything else is returned unchanged.
    """
    if isinstance(obj, str):
        if len(obj) > max_length:
            return obj[:max_length] + "..."
        return obj
    if isinstance(obj, dict):
        return {key: truncate_long_strings(val, max_length) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        shortened = [truncate_long_strings(item, max_length) for item in obj]
        return shortened if isinstance(obj, list) else tuple(shortened)
    if isinstance(obj, torch.Tensor):
        return obj.shape, obj.dtype, obj.device
    return obj
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def deepcopy_except_tensor(obj, exclude_types=(torch.Tensor,)):
    """Deep-copy *obj* while sharing (not copying) any ``torch.Tensor`` inside it.

    ``exclude_types`` defaults to an immutable tuple, so the default cannot be
    mutated across calls. Any iterable of types is accepted.
    """
    return deepcopy_except_types(obj=obj, exclude_types=exclude_types)


def deepcopy_except_types(obj, exclude_types):
    """
    Recursively copy an object, excluding specified data types.

    Instances of any type in *exclude_types* are returned as-is (shared).
    Lists, tuples (including namedtuples) and dicts are rebuilt recursively,
    with dict keys and values both processed. Everything else goes through
    ``copy.deepcopy``.

    :param obj: The object to be copied
    :param exclude_types: An iterable of data types to be excluded from deep copying
    :return: The copied object
    """
    excluded = tuple(exclude_types)
    if isinstance(obj, excluded):
        return obj  # Return the object directly without deep copying
    if isinstance(obj, (list, tuple)):
        items = [deepcopy_except_types(item, excluded) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(obj)(*items)
        return type(obj)(items)
    if isinstance(obj, dict):
        return {
            deepcopy_except_types(key, excluded): deepcopy_except_types(value, excluded)
            for key, value in obj.items()
        }
    return copy.deepcopy(obj)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def recursive_extract_models(data: Any, prefix_path: str = "") -> List[str]:
    """Flatten a nested model-folder spec into ``"dir/sub/file"`` paths.

    - Dict keys and string list items extend the path prefix with ``/``.
    - Non-string keys or items (e.g. a dict inside a list) leave the prefix
      unchanged.
    - A string leaf is emitted only when the accumulated prefix ends with it.
    - Values of any other type contribute nothing.
    """

    def _join(base: str, part: Any) -> str:
        if isinstance(part, str):
            return f"{base}/{part}" if base else part
        return base

    if isinstance(data, str):
        return [prefix_path] if prefix_path.endswith(data) else []

    if isinstance(data, dict):
        pairs = data.items()
    elif isinstance(data, list):
        # A list item is both the path segment and the value to descend into.
        pairs = ((item, item) for item in data)
    else:
        return []

    found: List[str] = []
    for segment, child in pairs:
        found += recursive_extract_models(child, _join(prefix_path, segment))
    return found
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _load_yaml_config(file_path):
    """Parse a UTF-8 YAML file with ``yaml.safe_load`` (no arbitrary object construction)."""
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def _load_json_config(file_path: str) -> dict:
    """Parse a UTF-8 JSON file."""
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def load_config_file(file_path: str) -> dict:
    """Load a ``.json``, ``.yaml`` or ``.yml`` configuration file.

    *file_path* may be a ``str`` or any ``os.PathLike``. The extension is
    matched case-insensitively. Files are read as UTF-8 regardless of the
    platform's locale encoding.

    Raises:
        FileNotFoundError: when *file_path* does not exist.
        ValueError: when the extension is not supported.
    """
    path = os.fspath(file_path)
    if not os.path.exists(path):
        raise FileNotFoundError(f"The file {file_path} does not exist.")
    lowered = path.lower()
    if lowered.endswith(".json"):
        return _load_json_config(path)
    if lowered.endswith((".yaml", ".yml")):
        return _load_yaml_config(path)
    raise ValueError(f"Unsupported file extension: {file_path}")
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import List, Tuple
|
|
5
|
+
|
|
6
|
+
from bizyengine.core.common.utils import load_config_file, recursive_extract_models
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class ModelRule:
    """One routing rule resolved for a specific node ``class_type``.

    Field names mirror the keys of a ``model_rules`` config entry
    (``mode_type`` is the config's own spelling). ``inputs`` comes from the
    matching node entry.
    """

    mode_type: str
    base_model: str
    describe: str
    score: int
    route: str
    class_type: str
    inputs: dict


@dataclass
class TaskApi:
    """Settings taken from the rule config's ``task_api`` section."""

    task_result_endpoint: str


class ModelRuleManager:
    """Index ``model_rules`` config entries by the node class types they list."""

    # Keys every rule must define, checked in this order by _validate_rule.
    _REQUIRED_KEYS = ("mode_type", "base_model", "route", "nodes")

    def __init__(self, model_rules: list[dict]):
        self.model_rules = model_rules
        self.validate()
        self.gen_model_rule_index_mapping()

    def gen_model_rule_index_mapping(self):
        """Build ``class_type -> [(rule_index, node_index), ...]``."""
        self.model_rule_index_mapping = defaultdict(list)
        for rule_idx, rule in enumerate(self.model_rules):
            for node_idx, node in enumerate(rule["nodes"]):
                self.model_rule_index_mapping[node["class_type"]].append(
                    (rule_idx, node_idx)
                )

    def validate(self):
        """Validate every rule; raise ValueError on the first missing key."""
        for rule in self.model_rules:
            self._validate_rule(rule)

    def _validate_rule(self, rule: dict):
        for key in self._REQUIRED_KEYS:
            if key not in rule:
                raise ValueError(f"{key} is required")

    def find_rule_indexes(self, class_type: str) -> List[Tuple[int, int]]:
        """Return the ``(rule_index, node_index)`` pairs for *class_type*.

        Uses ``.get`` so that looking up an unknown class type does not insert
        an empty entry into the defaultdict. Returns a copy so callers cannot
        corrupt the index.
        """
        return list(self.model_rule_index_mapping.get(class_type, ()))

    def find_rules(self, class_type: str) -> List[ModelRule]:
        """Return a ``ModelRule`` for every rule node matching *class_type*.

        ``describe`` and ``score`` are not checked by validation, so a rule
        missing them raises KeyError here.
        """
        found = []
        for rule_idx, node_idx in self.find_rule_indexes(class_type):
            rule = self.model_rules[rule_idx]
            found.append(
                ModelRule(
                    mode_type=rule["mode_type"],
                    base_model=rule["base_model"],
                    describe=rule["describe"],
                    score=rule["score"],
                    route=rule["route"],
                    class_type=class_type,
                    inputs=rule["nodes"][node_idx].get("inputs", {}),
                )
            )
        return found
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ModelPathManager:
    """Map each model folder name to the flat file paths listed for it in a config file."""

    def __init__(self, config_path: str):
        raw_config = load_config_file(config_path)
        self.model_paths = {
            folder: recursive_extract_models(spec) for folder, spec in raw_config.items()
        }

    def get_filenames(self, folder_name: str) -> List[str]:
        """Return the paths listed for *folder_name*, or ``[]`` if it is unknown."""
        return self.model_paths.get(folder_name, [])
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ConfigManager:
    """Facade combining the model-path config and the model-rule config."""

    def __init__(self, model_path_config: str, model_rule_config: str):
        self.model_path_manager = ModelPathManager(config_path=model_path_config)
        self.model_rule_config = load_config_file(model_rule_config)
        self.model_rules = ModelRuleManager(
            model_rules=self.model_rule_config["model_rules"]
        )

    def get_filenames(self, folder_name: str) -> List[str]:
        """Return the model paths listed for *folder_name*."""
        return self.model_path_manager.get_filenames(folder_name)

    def get_rules(self, class_type: str) -> List[ModelRule]:
        """Return the rules for *class_type*, ignoring a leading ``BizyAir_``."""
        return self.model_rules.find_rules(class_type.removeprefix("BizyAir_"))

    def get_model_version_id_prefix(self):
        version_config = self.model_rule_config["model_version_config"]
        return version_config["model_version_id_prefix"]

    def get_cache_config(self):
        """Return the ``cache_config`` section, or ``{}`` when absent."""
        return self.model_rule_config.get("cache_config", {})

    def get_task_api(self):
        """Build a ``TaskApi`` from the ``task_api`` section."""
        return TaskApi(**self.model_rule_config["task_api"])
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Both config files ship next to this module.
model_path_config = os.path.join(os.path.dirname(__file__), "models.json")
model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml")

# Module-wide configuration, loaded once at import time.
config_manager = ConfigManager(model_path_config, model_rule_config)
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
{
|
|
2
|
+
"checkpoints": [{
|
|
3
|
+
"sdxl": [
|
|
4
|
+
"counterfeitxl_v25.safetensors",
|
|
5
|
+
"dreamshaperXL_lightningDPMSDE.safetensors",
|
|
6
|
+
"dreamshaperXL_v21TurboDPMSDE.safetensors",
|
|
7
|
+
"HelloWorldXL_v70.safetensors",
|
|
8
|
+
"juggernautXL_v9Rdphoto2Lightning.safetensors",
|
|
9
|
+
"Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
|
|
10
|
+
"Juggernaut_X_RunDiffusion_Hyper.safetensors",
|
|
11
|
+
"mannEDreams_v004.safetensors",
|
|
12
|
+
"realisticStockPhoto_v20.safetensors",
|
|
13
|
+
"samaritan3dCartoon_v40SDXL.safetensors"
|
|
14
|
+
],
|
|
15
|
+
"sd15": [
|
|
16
|
+
"dreamshaper_8.safetensors"
|
|
17
|
+
]
|
|
18
|
+
},
|
|
19
|
+
"sd3.5_large.safetensors", "sd3.5_large_turbo.safetensors"],
|
|
20
|
+
"clip_vision": [{
|
|
21
|
+
"models": [
|
|
22
|
+
"CLIP-ViT-H-14-laion2B-s32B-b79K.safetensors"
|
|
23
|
+
],
|
|
24
|
+
"kolors": [
|
|
25
|
+
"pytorch_model.bin"
|
|
26
|
+
]
|
|
27
|
+
}, "sigclip_vision_patch14_384.safetensors"],
|
|
28
|
+
"controlnet": [{
|
|
29
|
+
"kolors": [
|
|
30
|
+
"Kolors-ControlNet-Canny.safetensors",
|
|
31
|
+
"Kolors-ControlNet-Depth.safetensors"
|
|
32
|
+
],
|
|
33
|
+
"sdxl": [
|
|
34
|
+
"diffusion_pytorch_model_promax.safetensors"
|
|
35
|
+
],
|
|
36
|
+
"sd15": [
|
|
37
|
+
"control_v11f1e_sd15_tile.pth"
|
|
38
|
+
],
|
|
39
|
+
"instantid": [
|
|
40
|
+
"diffusion_pytorch_model.safetensors"
|
|
41
|
+
]
|
|
42
|
+
}, "sd3.5_large_controlnet_blur.safetensors", "sd3.5_large_controlnet_canny.safetensors", "sd3.5_large_controlnet_depth.safetensors"],
|
|
43
|
+
"ipadapter": {
|
|
44
|
+
"kolors": [
|
|
45
|
+
"ip_adapter_plus_general.bin"
|
|
46
|
+
]
|
|
47
|
+
},
|
|
48
|
+
"loras": {
|
|
49
|
+
"sdxl": [
|
|
50
|
+
"Cute_Animals.safetensors",
|
|
51
|
+
"watercolor_v1_sdxl_lora.safetensors"
|
|
52
|
+
],
|
|
53
|
+
"flux": [
|
|
54
|
+
"meijia_flux_lora_rank16_bf16.safetensors"
|
|
55
|
+
]
|
|
56
|
+
},
|
|
57
|
+
"unet": ["shuttle-3.1-aesthetic.safetensors",{
|
|
58
|
+
"kolors": [
|
|
59
|
+
"Kolors-Inpainting.safetensors",
|
|
60
|
+
"Kolors.safetensors"
|
|
61
|
+
],
|
|
62
|
+
"flux": [
|
|
63
|
+
"flux1-schnell.sft",
|
|
64
|
+
"flux1-dev.sft",
|
|
65
|
+
"pixelwave-flux1-dev.safetensors",
|
|
66
|
+
"flux1-canny-dev.safetensors",
|
|
67
|
+
"flux1-depth-dev.safetensors",
|
|
68
|
+
"flux1-fill-dev.safetensors"
|
|
69
|
+
]
|
|
70
|
+
}],
|
|
71
|
+
"vae": [{
|
|
72
|
+
"sdxl": [
|
|
73
|
+
"sdxl_vae.safetensors"
|
|
74
|
+
],
|
|
75
|
+
"flux": [
|
|
76
|
+
"ae.sft"
|
|
77
|
+
]
|
|
78
|
+
}, "flux.1-canny-vae.safetensors", "flux.1-depth-vae.safetensors", "flux.1-fill-vae.safetensors"],
|
|
79
|
+
"clip": [
|
|
80
|
+
"clip_l.safetensors",
|
|
81
|
+
"clip_g.safetensors",
|
|
82
|
+
"t5xxl_fp16.safetensors",
|
|
83
|
+
"t5xxl_fp8_e4m3fn.safetensors"
|
|
84
|
+
],
|
|
85
|
+
"upscale_models": [
|
|
86
|
+
"4x_NMKD-Siax_200k.pth",
|
|
87
|
+
"RealESRGAN_x4plus.pth",
|
|
88
|
+
"RealESRGAN_x4plus_anime_6B.pth",
|
|
89
|
+
"RealESRGAN_x2plus.pth",
|
|
90
|
+
"RealESRGAN_x4plus.pth"
|
|
91
|
+
],
|
|
92
|
+
"instantid": [
|
|
93
|
+
"ip-adapter.bin"
|
|
94
|
+
],
|
|
95
|
+
"pulid": [
|
|
96
|
+
"pulid_flux.safetensors"
|
|
97
|
+
],
|
|
98
|
+
"style_models": [
|
|
99
|
+
"flux1-redux-dev.safetensors"
|
|
100
|
+
]
|
|
101
|
+
}
|