bizyengine 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. bizyengine/__init__.py +35 -0
  2. bizyengine/bizy_server/__init__.py +7 -0
  3. bizyengine/bizy_server/api_client.py +763 -0
  4. bizyengine/bizy_server/errno.py +122 -0
  5. bizyengine/bizy_server/error_handler.py +3 -0
  6. bizyengine/bizy_server/execution.py +55 -0
  7. bizyengine/bizy_server/resp.py +24 -0
  8. bizyengine/bizy_server/server.py +898 -0
  9. bizyengine/bizy_server/utils.py +93 -0
  10. bizyengine/bizyair_extras/__init__.py +24 -0
  11. bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
  12. bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
  13. bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
  14. bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
  15. bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
  16. bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
  17. bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
  18. bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
  19. bizyengine/bizyair_extras/nodes_dataset.py +99 -0
  20. bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
  21. bizyengine/bizyair_extras/nodes_flux.py +69 -0
  22. bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
  23. bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
  24. bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
  25. bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
  26. bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
  27. bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
  28. bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
  29. bizyengine/bizyair_extras/nodes_sd3.py +52 -0
  30. bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
  31. bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
  32. bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
  33. bizyengine/bizyair_extras/nodes_trellis.py +199 -0
  34. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
  35. bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
  36. bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
  37. bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
  38. bizyengine/core/__init__.py +8 -0
  39. bizyengine/core/commands/__init__.py +1 -0
  40. bizyengine/core/commands/base.py +27 -0
  41. bizyengine/core/commands/invoker.py +4 -0
  42. bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
  43. bizyengine/core/commands/processors/prompt_processor.py +123 -0
  44. bizyengine/core/commands/servers/model_server.py +0 -0
  45. bizyengine/core/commands/servers/prompt_server.py +234 -0
  46. bizyengine/core/common/__init__.py +8 -0
  47. bizyengine/core/common/caching.py +198 -0
  48. bizyengine/core/common/client.py +262 -0
  49. bizyengine/core/common/env_var.py +101 -0
  50. bizyengine/core/common/utils.py +93 -0
  51. bizyengine/core/configs/conf.py +112 -0
  52. bizyengine/core/configs/models.json +101 -0
  53. bizyengine/core/configs/models.yaml +329 -0
  54. bizyengine/core/data_types.py +20 -0
  55. bizyengine/core/image_utils.py +288 -0
  56. bizyengine/core/nodes_base.py +159 -0
  57. bizyengine/core/nodes_io.py +97 -0
  58. bizyengine/core/path_utils/__init__.py +9 -0
  59. bizyengine/core/path_utils/path_manager.py +276 -0
  60. bizyengine/core/path_utils/utils.py +34 -0
  61. bizyengine/misc/__init__.py +0 -0
  62. bizyengine/misc/auth.py +83 -0
  63. bizyengine/misc/llm.py +431 -0
  64. bizyengine/misc/mzkolors.py +93 -0
  65. bizyengine/misc/nodes.py +1208 -0
  66. bizyengine/misc/nodes_controlnet_aux.py +491 -0
  67. bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
  68. bizyengine/misc/route_sam.py +60 -0
  69. bizyengine/misc/segment_anything.py +276 -0
  70. bizyengine/misc/supernode.py +182 -0
  71. bizyengine/misc/utils.py +218 -0
  72. bizyengine/version.txt +1 -0
  73. bizyengine-0.4.2.dist-info/METADATA +12 -0
  74. bizyengine-0.4.2.dist-info/RECORD +76 -0
  75. bizyengine-0.4.2.dist-info/WHEEL +5 -0
  76. bizyengine-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,262 @@
1
+ import json
2
+ import os
3
+ import pprint
4
+ import urllib.error
5
+ import urllib.request
6
+ import warnings
7
+ from abc import ABC, abstractmethod
8
+ from collections import defaultdict
9
+ from typing import Any, Union
10
+
11
+ import aiohttp
12
+ import comfy
13
+
14
+ from .caching import CacheManager
15
+
16
+ __all__ = ["send_request"]
17
+
18
+ from dataclasses import dataclass, field
19
+
20
+ from .env_var import BIZYAIR_API_KEY, BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
21
+
22
+ IS_API_KEY_VALID = None
23
+
24
+
25
@dataclass
class APIKeyState:
    """Mutable record of the most recently checked API key and its outcome."""

    # Last key handed to validate_api_key(); stays None until one is checked.
    current_api_key: str = None
    # True/False once a validation attempt has concluded, None before that.
    is_valid: bool = None
29
+
30
+
31
+ api_key_state = APIKeyState()
32
+
33
+
34
def set_api_key(api_key: str = "YOUR_API_KEY", override: bool = False):
    """Validate ``api_key`` and, on success, install it as the module-level key.

    If a validation outcome has already been recorded and ``override`` is
    false, a warning is emitted and nothing else happens.

    :param api_key: The key to validate and store.
    :param override: Re-validate even when an outcome is already recorded.
    """
    global BIZYAIR_API_KEY, api_key_state
    outcome_recorded = api_key_state.is_valid is not None
    if outcome_recorded and not override:
        warnings.warn("API key has already been set and will not be overridden.")
        return

    if not validate_api_key(api_key):
        api_key_state.is_valid = False
        warnings.warn("Invalid API key provided.")
        return

    BIZYAIR_API_KEY = api_key
    api_key_state.is_valid = True
    print("\033[92mAPI key is set successfully.\033[0m")
46
+
47
+
48
def validate_api_key(api_key: str = None) -> bool:
    """Check ``api_key`` against the server's ``/user/info`` endpoint.

    The outcome is recorded on the module-level ``api_key_state``.

    :param api_key: The key to validate.
    :return: ``False`` (after a warning) when ``api_key`` is empty or not a
        string; ``True`` when the server answers with ``message == "Ok"``.
    :raises ValueError: When the request fails or the server does not confirm
        the key. The original exception, if any, is attached as ``__cause__``.
    """
    if not api_key or not isinstance(api_key, str):
        warnings.warn("API key is not set.")
        return False

    # The result is deliberately not cached per key: every call re-validates.
    api_key_state.current_api_key = api_key
    url = f"{BIZYAIR_SERVER_ADDRESS}/user/info"
    headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"}

    # Only the network call lives inside the try block, so the "validation
    # failed" error raised further down is not caught by the generic handler
    # and re-wrapped as "Other error: ...".
    try:
        response_data = send_request(
            method="GET", url=url, headers=headers, callback=None
        )
    except ConnectionError as ce:
        api_key_state.is_valid = False
        raise ValueError(f"\033[91mConnection error: {ce}\033[0m") from ce
    except PermissionError as pe:
        api_key_state.is_valid = False
        # NOTE(review): the full API key is embedded in this message and may
        # end up in logs; consider masking it.
        raise ValueError(
            f"\033[91mError validating API key: {api_key}, error: {pe}\033[0m"
        ) from pe
    except Exception as e:
        api_key_state.is_valid = False
        raise ValueError(f"\033[91mOther error: {e}\033[0m") from e

    # A non-dict payload cannot carry the expected "message" field, so it is
    # treated as a failed validation rather than surfacing a TypeError.
    confirmed = isinstance(response_data, dict) and response_data.get("message") == "Ok"
    if not confirmed:
        api_key_state.is_valid = False
        raise ValueError(
            f"\033[91mAPI key validation failed. API Key: {api_key}\033[0m"
        )

    api_key_state.is_valid = True
    return api_key_state.is_valid
81
+
82
+
83
def get_api_key() -> str:
    """Re-validate the module-level API key and return it.

    :return: The current value of ``BIZYAIR_API_KEY``.
    :raises ValueError: When ``validate_api_key`` raises; the message is
        printed first and then re-raised with the same text.
    """
    try:
        validate_api_key(BIZYAIR_API_KEY)
    except Exception as err:
        reason = str(err)
        print(reason)
        raise ValueError(reason)
    return BIZYAIR_API_KEY
91
+
92
+
93
def _headers():
    """Build the default JSON request headers.

    The bearer token comes from ``get_api_key()``, so building the headers
    triggers a key validation.
    """
    bearer = f"Bearer {get_api_key()}"
    return {
        "accept": "application/json",
        "content-type": "application/json",
        "authorization": bearer,
    }
100
+
101
+
102
def process_response_data(response_data: dict) -> dict:
    """Unwrap a response payload.

    A payload carrying a ``"result"`` key is treated as a cloud response whose
    ``result`` value is a JSON document; anything else is returned unchanged.

    :param response_data: The decoded response body.
    :return: The decoded ``result`` document, or ``response_data`` itself.
    :raises ValueError: When ``result`` is not valid JSON.
    """
    if "result" not in response_data:
        # Local response: nothing to unwrap.
        return response_data

    try:
        return json.loads(response_data["result"])
    except json.JSONDecodeError:
        raise ValueError(f"Failed to decode JSON from response. {response_data=}")
114
+
115
+
116
def _probe_public_sites(sites, timeout=5):
    """Report whether each of ``sites`` is reachable, without following redirects.

    :param sites: URLs to probe.
    :param timeout: Per-site timeout in seconds.
    :return: Mapping of site URL to ``"Success"`` or ``"Failed"``.
    """

    class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
        def redirect_request(self, req, fp, code, msg, headers, newurl):
            # Returning None stops urllib from following the redirect.
            return None

    opener = urllib.request.build_opener(_NoRedirectHandler())
    results = {}
    for site in sites:
        success = False
        try:
            # Close the probe response instead of leaking the connection.
            with opener.open(site, timeout=timeout) as probe:
                success = 200 <= probe.getcode() < 400
        except urllib.error.HTTPError as http_err:
            # An un-followed redirect arrives here as an HTTPError with a 3xx code.
            success = 200 <= http_err.code < 400
        except (urllib.error.URLError, TimeoutError):
            pass
        results[site] = "Success" if success else "Failed"
    return results


def send_request(
    method: str = "POST",
    url: str = None,
    data: bytes = None,
    verbose=False,
    callback: callable = process_response_data,
    response_handler: callable = json.loads,
    cache_manager: CacheManager = None,
    **kwargs,
) -> Union[dict, Any]:
    """Send a blocking HTTP request and return the processed response.

    :param method: HTTP method.
    :param url: Target URL.
    :param data: Request body, already encoded.
    :param verbose: Print diagnostic details when the request fails.
    :param callback: Applied to the handled response; skipped when falsy.
    :param response_handler: Applied to the decoded body text; skipped when falsy.
    :param cache_manager: Accepted for interface compatibility; not used here.
    :param kwargs: ``headers`` overrides the default headers; everything else
        is forwarded to ``urllib.request.Request``.
    :return: The response after ``response_handler`` and ``callback``.
    :raises PermissionError: When the server reports the request as unauthorized.
    :raises ConnectionError: For any other ``URLError``, with the server's
        error code/message when available, otherwise a connectivity report.
    """
    try:
        headers = kwargs.pop("headers") if "headers" in kwargs else _headers()
        # Work on a copy so the caller's dict is not modified.
        headers = dict(headers)
        headers["User-Agent"] = "BizyAir Client"

        req = urllib.request.Request(
            url, data=data, headers=headers, method=method, **kwargs
        )
        with urllib.request.urlopen(req) as response:
            response_data = response.read().decode("utf-8")
    except urllib.error.URLError as err:
        error_message = str(err)
        # Only HTTPError carries a body; a plain URLError has no read().
        if hasattr(err, "read"):
            response_body = err.read().decode("utf-8", errors="replace")
        else:
            response_body = "N/A"
        if verbose:
            print(f"URLError encountered: {error_message}")
            # response_data is unbound on this path; the error body is what
            # was read above.
            print(f"Response Body: {response_body}")

        code, message = "N/A", "N/A"
        try:
            response_dict = json.loads(response_body)
            if isinstance(response_dict, dict):
                code = response_dict.get("code", "N/A")
                message = response_dict.get("message", "N/A")
        except json.JSONDecodeError:
            if verbose:
                print("Failed to decode response body as JSON.")

        if "Unauthorized" in error_message:
            raise PermissionError(
                "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                "If you have the key, please click the 'API Key' button at the bottom right to set the key."
            ) from err
        if code != "N/A" and message != "N/A":
            raise ConnectionError(
                f"Failed to handle your request: {error_message}.\n"
                + f"    Error code: {code}.\n"
                + f"    Error message: {message}\n."
                + "The cause of this issue may be incorrect parameter status or ongoing background tasks. \n"
                + "If retrying after waiting for a while still does not resolve the issue, please report it to "
                + "Bizyair's official support."
            ) from err

        # No structured error from the server: probe a few public sites so the
        # message can distinguish a local network outage from a server issue.
        results = _probe_public_sites(
            [
                "https://www.baidu.com",
                "https://www.bing.com",
                "https://www.alibaba.com",
            ]
        )
        raise ConnectionError(
            f"Failed to connect to the server: {url}.\n"
            + "The connection attempts to the public sites return the following results:\n"
            + "\n".join([f"    {site}: {result}" for site, result in results.items()])
            + "\nPlease check the network connection."
        ) from err

    if response_handler:
        response_data = response_handler(response_data)
    if callback:
        return callback(response_data)
    return response_data
203
+
204
+
205
async def async_send_request(
    method: str = "POST",
    url: str = None,
    data: bytes = None,
    verbose=False,
    callback: callable = process_response_data,
    **kwargs,
) -> dict:
    """Send an HTTP request with aiohttp and return the decoded JSON response.

    :param method: HTTP method.
    :param url: Target URL.
    :param data: Request body.
    :param verbose: Print the HTTP status when it is not 200.
    :param callback: Applied to the decoded JSON; skipped when falsy.
    :param kwargs: ``headers`` overrides the default headers; everything else
        is forwarded to ``session.request``.
    :return: The (optionally post-processed) JSON payload, or ``{}`` when any
        exception occurs while sending or handling the request.
    """
    if "headers" in kwargs:
        headers = kwargs.pop("headers")
    else:
        headers = _headers()

    try:
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method, url, data=data, headers=headers, **kwargs
            ) as response:
                body_text = await response.text()
                status = response.status
                if status == 200:
                    payload = json.loads(body_text)
                    return callback(payload) if callback else payload

                error_message = f"HTTP Status {status}"
                if verbose:
                    print(f"Error encountered: {error_message}")
                # NOTE(review): both errors raised below are caught by the
                # handler at the bottom, so callers only ever see `{}`.
                if status == 401:
                    raise PermissionError(
                        "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                        "If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
                    )
                raise ConnectionError(
                    f"Failed to connect to the server: {error_message}.\n"
                    + "Please check your API key and ensure the server is reachable.\n"
                    + "Also, verify your network settings and disable any proxies if necessary.\n"
                    + "After checking, please restart the ComfyUI service."
                )
    except Exception as err:
        # Covers aiohttp.ClientError as well; every failure is reported the
        # same way and turned into an empty result.
        print(f"Error fetching data: {err}")
        return {}
245
+
246
+
247
def fetch_models_by_type(
    url: str, model_type: str, *, method="GET", verbose=False
) -> dict:
    """Request the models of ``model_type`` from ``url``.

    :param url: Endpoint to query.
    :param model_type: Sent to the server as the ``type`` field of the payload.
    :param method: HTTP method.
    :param verbose: Forwarded to ``send_request``.
    :return: The processed server response, or ``{}`` when the module-level
        API key does not validate.
    """
    if validate_api_key(BIZYAIR_API_KEY):
        payload = {"type": model_type}
        if BIZYAIR_DEBUG:
            pprint.pprint(payload)
        body = json.dumps(payload).encode("utf-8")
        return send_request(method=method, url=url, data=body, verbose=verbose)
    return {}
@@ -0,0 +1,101 @@
1
+ import configparser
2
+ import os
3
+ from os import environ
4
+ from pathlib import Path
5
+
6
+ BIZYAIR_COMFYUI_PATH = Path(os.environ.get("BIZYAIR_COMFYUI_PATH", "./"))
7
+ print(f"\033[92m[BizyAir]\033[0m BizyAir ComfyUI Plugin: {str(BIZYAIR_COMFYUI_PATH)}")
8
+
9
+
10
class ServerAddress:
    """Singleton wrapper around the server address.

    Every construction returns the same instance, and ``__init__`` runs each
    time, so constructing it again replaces the stored address.
    """

    _instance = None
    _address = None

    def __new__(cls, *args, **kwargs):
        # Create the shared instance on first use, then keep handing it out.
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self, address):
        self._address = address

    @property
    def address(self):
        """The currently configured address."""
        return self._address

    @address.setter
    def address(self, new_address):
        self._address = new_address

    def __str__(self):
        return self._address
32
+
33
+
34
def env(key, type_, default=None):
    """Read environment variable ``key`` and convert it to ``type_``.

    :param key: Name of the environment variable.
    :param type_: One of ``str``, ``bool`` or ``int``.
    :param default: Returned as-is when the variable is not set.
    :return: The converted value, or ``default``.
    :raises ValueError: When the value cannot be converted, or ``type_`` is
        not one of the supported types.
    """
    try:
        raw = environ[key]
    except KeyError:
        return default

    if type_ == str:
        return raw

    if type_ == bool:
        lowered = raw.lower()
        if lowered in ["1", "true", "yes", "y", "ok", "on"]:
            return True
        if lowered in ["0", "false", "no", "n", "nok", "off"]:
            return False
        raise ValueError(
            "Invalid environment variable '%s' (expected a boolean): '%s'" % (key, raw)
        )

    if type_ == int:
        try:
            return int(raw)
        except ValueError:
            # Hide the int() traceback; the message below says it all.
            raise ValueError(
                "Invalid environment variable '%s' (expected an integer): '%s'"
                % (key, raw)
            ) from None

    raise ValueError("The requested type '%r' is not supported" % type_)
59
+
60
+
61
def load_api_key():
    """Read the API key from ``api_key.ini`` under ``BIZYAIR_COMFYUI_PATH``.

    :return: ``(has_key, api_key)``. ``has_key`` is true when the stored value
        starts with ``"sk-"``. Returns ``(False, None)`` when the file does
        not exist.
    """
    file_path = BIZYAIR_COMFYUI_PATH / "api_key.ini"
    if not (file_path.is_file() and file_path.exists()):
        return False, None

    parser = configparser.ConfigParser()
    parser.read(file_path)
    raw_value = parser.get("auth", "api_key", fallback="")
    # Drop surrounding whitespace first, then any quotes wrapping the value.
    api_key: str = raw_value.strip().strip("'\"")
    return api_key.startswith("sk-"), api_key
73
+
74
+
75
def create_api_key_file(api_key):
    """Write ``api_key`` to ``api_key.ini`` under ``BIZYAIR_COMFYUI_PATH``.

    The key is stored as ``api_key`` in the ``[auth]`` section.

    :param api_key: The key to persist.
    :raises Exception: When the file cannot be written.
    """
    parser = configparser.ConfigParser()
    parser["auth"] = {"api_key": api_key}
    target = BIZYAIR_COMFYUI_PATH / "api_key.ini"
    try:
        with open(target, "w", encoding="utf-8") as handle:
            parser.write(handle)
    except Exception as err:
        raise Exception(f"An error occurred when save the key: {err}")
84
+
85
+
86
+ # production:
87
+ # service_address: https://bizyair-api.siliconflow.cn/x/v1
88
+ # uat:
89
+ # service_address: https://uat-bizyair-api.siliconflow.cn/x/v1
90
+ _BIZYAIR_SERVER_ADDRESS = os.getenv(
91
+ "BIZYAIR_SERVER_ADDRESS", "https://bizyair-api.siliconflow.cn/x/v1"
92
+ )
93
+ BIZYAIR_SERVER_ADDRESS = ServerAddress(_BIZYAIR_SERVER_ADDRESS)
94
+ BIZYAIR_API_KEY = env("BIZYAIR_API_KEY", str, load_api_key()[1])
95
+ # Development Settings
96
+ BIZYAIR_DEV_REQUEST_URL = env("BIZYAIR_DEV_REQUEST_URL", str, None)
97
+ BIZYAIR_DEBUG = env("BIZYAIR_DEBUG", bool, False)
98
+ BIZYAIR_DEV_GET_TASK_RESULT_SERVER = env(
99
+ "BIZYAIR_DEV_GET_TASK_RESULT_SERVER", str, None
100
+ )
101
+ BIZYAIR_PRODUCTION_TEST = env("BIZYAIR_PRODUCTION_TEST", str, None)
@@ -0,0 +1,93 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from typing import Any, List
5
+
6
+ import torch
7
+ import yaml
8
+
9
+
10
def truncate_long_strings(obj, max_length=50):
    """Return a copy of ``obj`` that is compact enough to print.

    Strings longer than ``max_length`` are cut and suffixed with ``"..."``.
    Dicts, lists and tuples are processed recursively (dict keys are kept
    as-is). A ``torch.Tensor`` is replaced by its ``(shape, dtype, device)``.
    Any other value is returned unchanged.
    """
    if isinstance(obj, str):
        if len(obj) > max_length:
            return obj[:max_length] + "..."
        return obj
    if isinstance(obj, dict):
        return {
            key: truncate_long_strings(value, max_length)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [truncate_long_strings(item, max_length) for item in obj]
    if isinstance(obj, tuple):
        return tuple(truncate_long_strings(item, max_length) for item in obj)
    if isinstance(obj, torch.Tensor):
        return obj.shape, obj.dtype, obj.device
    return obj
23
+
24
+
25
def deepcopy_except_tensor(obj, exclude_types=(torch.Tensor,)):
    """Deep-copy ``obj`` while sharing any ``torch.Tensor`` it contains.

    :param obj: The object to be copied
    :param exclude_types: Types returned by reference instead of being copied
    :return: The copied object
    """
    return deepcopy_except_types(obj=obj, exclude_types=exclude_types)


def deepcopy_except_types(obj, exclude_types):
    """
    Recursively copy an object, excluding specified data types.

    Lists, tuples (including namedtuples) and dicts are rebuilt element by
    element; everything else goes through ``copy.deepcopy``.

    :param obj: The object to be copied
    :param exclude_types: An iterable of data types to be excluded from deep copying
    :return: The copied object
    """
    excluded = tuple(exclude_types)
    if isinstance(obj, excluded):
        return obj  # Return the object directly without deep copying
    if isinstance(obj, (list, tuple)):
        items = [deepcopy_except_types(item, excluded) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(obj)(*items)
        return type(obj)(items)
    if isinstance(obj, dict):
        return {
            deepcopy_except_types(key, excluded): deepcopy_except_types(value, excluded)
            for key, value in obj.items()
        }
    return copy.deepcopy(obj)
50
+
51
+
52
def recursive_extract_models(data: Any, prefix_path: str = "") -> List[str]:
    """Flatten a nested model listing into slash-separated paths.

    Dict keys and string list items are appended to ``prefix_path`` while
    descending. A string leaf yields ``prefix_path`` only when that path ends
    with the string itself.

    :param data: Nested structure of dicts, lists and strings.
    :param prefix_path: Path accumulated so far.
    :return: The collected paths, in traversal order.
    """

    def join(base: str, part: Any) -> str:
        # Non-string parts (e.g. a dict inside a list) leave the path as it is.
        if isinstance(part, str):
            return f"{base}/{part}" if base else part
        return base

    if isinstance(data, dict):
        children = data.items()
    elif isinstance(data, list):
        # A list item is both the path component and the node to descend into.
        children = ((item, item) for item in data)
    else:
        if isinstance(data, str) and prefix_path.endswith(data):
            return [prefix_path]
        return []

    found: List[str] = []
    for label, child in children:
        found.extend(recursive_extract_models(child, join(prefix_path, label)))
    return found
71
+
72
+
73
def _load_yaml_config(file_path):
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load``."""
    # Read as UTF-8 explicitly instead of relying on the locale default.
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def _load_json_config(file_path: str) -> dict:
    """Parse the JSON file at ``file_path``."""
    # Read as UTF-8 explicitly instead of relying on the locale default.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def load_config_file(file_path: str) -> dict:
    """Load a configuration file, choosing the parser by file extension.

    :param file_path: Path to a ``.json`` or ``.yaml`` file (``str`` or
        path-like).
    :return: The parsed content.
    :raises FileNotFoundError: When ``file_path`` does not exist.
    :raises ValueError: When the extension is neither ``.json`` nor ``.yaml``.
    """
    # Accept pathlib.Path as well; the extension checks below need a str.
    file_path = os.fspath(file_path)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file {file_path} does not exist.")
    if file_path.endswith(".json"):
        return _load_json_config(file_path)
    if file_path.endswith(".yaml"):
        return _load_yaml_config(file_path)
    raise ValueError(f"Unsupported file extension: {file_path}")
@@ -0,0 +1,112 @@
1
+ import os
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import List, Tuple
5
+
6
+ from bizyengine.core.common.utils import load_config_file, recursive_extract_models
7
+
8
+
9
@dataclass
class ModelRule:
    """One routing rule resolved for a single node class type.

    Instances are built by ``ModelRuleManager.find_rules``. ``class_type`` and
    ``inputs`` come from the matching node entry, the remaining fields from
    the rule that contains it.
    """

    # Fields copied from the enclosing rule entry.
    mode_type: str
    base_model: str
    describe: str
    score: int
    route: str
    # Fields taken from the matching node entry of that rule.
    class_type: str
    inputs: dict
18
+
19
+
20
@dataclass
class TaskApi:
    """Settings read from the ``task_api`` section of the rule configuration."""

    task_result_endpoint: str
23
+
24
+
25
class ModelRuleManager:
    """Validates model rules and indexes them by node class type."""

    def __init__(self, model_rules: list[dict]):
        """Store ``model_rules``, validate them and build the lookup index.

        :raises ValueError: When a rule lacks a required key.
        """
        self.model_rules = model_rules
        self.validate()
        self.gen_model_rule_index_mapping()

    def gen_model_rule_index_mapping(self):
        """Map each node ``class_type`` to its ``(rule index, node index)`` pairs."""
        self.model_rule_index_mapping = defaultdict(list)

        for rule_idx, rule in enumerate(self.model_rules):
            for node_idx, node in enumerate(rule["nodes"]):
                self.model_rule_index_mapping[node["class_type"]].append(
                    (rule_idx, node_idx)
                )

    def validate(self):
        """Check every rule for the required keys."""
        for rule in self.model_rules:
            self._validate_rule(rule)

    def _validate_rule(self, rule: dict):
        """Raise ``ValueError`` naming the first required key missing from ``rule``."""
        for required_key in ("mode_type", "base_model", "route", "nodes"):
            if required_key not in rule:
                raise ValueError(f"{required_key} is required")

    def find_rule_indexes(self, class_type: str) -> List[Tuple[int, int]]:
        """Return the index pairs recorded for ``class_type`` (empty if unknown)."""
        # Use .get(): indexing the defaultdict would insert an empty entry for
        # every unknown class type that is looked up.
        return self.model_rule_index_mapping.get(class_type, [])

    def find_rules(self, class_type: str) -> "List[ModelRule]":
        """Build a ``ModelRule`` for every node registered under ``class_type``.

        :raises KeyError: When a matching rule has no ``describe`` or ``score``
            entry; those two keys are not covered by ``validate``.
        """
        found = []
        for rule_idx, node_idx in self.find_rule_indexes(class_type):
            rule = self.model_rules[rule_idx]
            found.append(
                ModelRule(
                    mode_type=rule["mode_type"],
                    base_model=rule["base_model"],
                    describe=rule["describe"],
                    score=rule["score"],
                    route=rule["route"],
                    class_type=class_type,
                    inputs=rule["nodes"][node_idx].get("inputs", {}),
                )
            )
        return found
69
+
70
+
71
class ModelPathManager:
    """Holds the model file paths available for each model folder."""

    def __init__(self, config_path: str):
        """Load ``config_path`` and flatten each top-level entry into a path list."""
        self.model_paths = {
            folder_name: recursive_extract_models(entries)
            for folder_name, entries in load_config_file(config_path).items()
        }

    def get_filenames(self, folder_name: str) -> List[str]:
        """Return the paths listed for ``folder_name``, or ``[]`` if unknown."""
        try:
            return self.model_paths[folder_name]
        except KeyError:
            return []
80
+
81
+
82
class ConfigManager:
    """Single access point for the model path listing and the model rules."""

    def __init__(self, model_path_config: str, model_rule_config: str):
        """Load both configuration files.

        :param model_path_config: File describing the available model paths.
        :param model_rule_config: File holding ``model_rules`` and related settings.
        """
        self.model_path_manager = ModelPathManager(config_path=model_path_config)
        rule_config = load_config_file(model_rule_config)
        self.model_rule_config = rule_config
        self.model_rules = ModelRuleManager(model_rules=rule_config["model_rules"])

    def get_filenames(self, folder_name: str) -> List[str]:
        """Return the model paths listed for ``folder_name``."""
        return self.model_path_manager.get_filenames(folder_name)

    def get_rules(self, class_type: str) -> List[ModelRule]:
        """Return the rules for ``class_type``, ignoring a leading ``BizyAir_``."""
        prefix = "BizyAir_"
        if class_type.startswith(prefix):
            class_type = class_type[len(prefix):]
        return self.model_rules.find_rules(class_type)

    def get_model_version_id_prefix(self):
        """Return ``model_version_config.model_version_id_prefix`` from the rule config."""
        version_config = self.model_rule_config["model_version_config"]
        return version_config["model_version_id_prefix"]

    def get_cache_config(self):
        """Return the ``cache_config`` section, or ``{}`` when it is absent."""
        return self.model_rule_config.get("cache_config", {})

    def get_task_api(self):
        """Build a ``TaskApi`` from the ``task_api`` section of the rule config."""
        task_api_settings = self.model_rule_config["task_api"]
        return TaskApi(**task_api_settings)
106
+
107
+
108
+ model_path_config = os.path.join(os.path.dirname(__file__), "models.json")
109
+ model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml")
110
+ config_manager = ConfigManager(
111
+ model_path_config=model_path_config, model_rule_config=model_rule_config
112
+ )
@@ -0,0 +1,101 @@
1
+ {
2
+ "checkpoints": [{
3
+ "sdxl": [
4
+ "counterfeitxl_v25.safetensors",
5
+ "dreamshaperXL_lightningDPMSDE.safetensors",
6
+ "dreamshaperXL_v21TurboDPMSDE.safetensors",
7
+ "HelloWorldXL_v70.safetensors",
8
+ "juggernautXL_v9Rdphoto2Lightning.safetensors",
9
+ "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
10
+ "Juggernaut_X_RunDiffusion_Hyper.safetensors",
11
+ "mannEDreams_v004.safetensors",
12
+ "realisticStockPhoto_v20.safetensors",
13
+ "samaritan3dCartoon_v40SDXL.safetensors"
14
+ ],
15
+ "sd15": [
16
+ "dreamshaper_8.safetensors"
17
+ ]
18
+ },
19
+ "sd3.5_large.safetensors", "sd3.5_large_turbo.safetensors"],
20
+ "clip_vision": [{
21
+ "models": [
22
+ "CLIP-ViT-H-14-laion2B-s32B-b79K.safetensors"
23
+ ],
24
+ "kolors": [
25
+ "pytorch_model.bin"
26
+ ]
27
+ }, "sigclip_vision_patch14_384.safetensors"],
28
+ "controlnet": [{
29
+ "kolors": [
30
+ "Kolors-ControlNet-Canny.safetensors",
31
+ "Kolors-ControlNet-Depth.safetensors"
32
+ ],
33
+ "sdxl": [
34
+ "diffusion_pytorch_model_promax.safetensors"
35
+ ],
36
+ "sd15": [
37
+ "control_v11f1e_sd15_tile.pth"
38
+ ],
39
+ "instantid": [
40
+ "diffusion_pytorch_model.safetensors"
41
+ ]
42
+ }, "sd3.5_large_controlnet_blur.safetensors", "sd3.5_large_controlnet_canny.safetensors", "sd3.5_large_controlnet_depth.safetensors"],
43
+ "ipadapter": {
44
+ "kolors": [
45
+ "ip_adapter_plus_general.bin"
46
+ ]
47
+ },
48
+ "loras": {
49
+ "sdxl": [
50
+ "Cute_Animals.safetensors",
51
+ "watercolor_v1_sdxl_lora.safetensors"
52
+ ],
53
+ "flux": [
54
+ "meijia_flux_lora_rank16_bf16.safetensors"
55
+ ]
56
+ },
57
+ "unet": ["shuttle-3.1-aesthetic.safetensors",{
58
+ "kolors": [
59
+ "Kolors-Inpainting.safetensors",
60
+ "Kolors.safetensors"
61
+ ],
62
+ "flux": [
63
+ "flux1-schnell.sft",
64
+ "flux1-dev.sft",
65
+ "pixelwave-flux1-dev.safetensors",
66
+ "flux1-canny-dev.safetensors",
67
+ "flux1-depth-dev.safetensors",
68
+ "flux1-fill-dev.safetensors"
69
+ ]
70
+ }],
71
+ "vae": [{
72
+ "sdxl": [
73
+ "sdxl_vae.safetensors"
74
+ ],
75
+ "flux": [
76
+ "ae.sft"
77
+ ]
78
+ }, "flux.1-canny-vae.safetensors", "flux.1-depth-vae.safetensors", "flux.1-fill-vae.safetensors"],
79
+ "clip": [
80
+ "clip_l.safetensors",
81
+ "clip_g.safetensors",
82
+ "t5xxl_fp16.safetensors",
83
+ "t5xxl_fp8_e4m3fn.safetensors"
84
+ ],
85
+ "upscale_models": [
86
+ "4x_NMKD-Siax_200k.pth",
87
+ "RealESRGAN_x4plus.pth",
88
+ "RealESRGAN_x4plus_anime_6B.pth",
89
+ "RealESRGAN_x2plus.pth",
90
+ "RealESRGAN_x4plus.pth"
91
+ ],
92
+ "instantid": [
93
+ "ip-adapter.bin"
94
+ ],
95
+ "pulid": [
96
+ "pulid_flux.safetensors"
97
+ ],
98
+ "style_models": [
99
+ "flux1-redux-dev.safetensors"
100
+ ]
101
+ }