bizyengine 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. bizyengine/__init__.py +35 -0
  2. bizyengine/bizy_server/__init__.py +7 -0
  3. bizyengine/bizy_server/api_client.py +763 -0
  4. bizyengine/bizy_server/errno.py +122 -0
  5. bizyengine/bizy_server/error_handler.py +3 -0
  6. bizyengine/bizy_server/execution.py +55 -0
  7. bizyengine/bizy_server/resp.py +24 -0
  8. bizyengine/bizy_server/server.py +898 -0
  9. bizyengine/bizy_server/utils.py +93 -0
  10. bizyengine/bizyair_extras/__init__.py +24 -0
  11. bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
  12. bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
  13. bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
  14. bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
  15. bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
  16. bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
  17. bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
  18. bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
  19. bizyengine/bizyair_extras/nodes_dataset.py +99 -0
  20. bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
  21. bizyengine/bizyair_extras/nodes_flux.py +69 -0
  22. bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
  23. bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
  24. bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
  25. bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
  26. bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
  27. bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
  28. bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
  29. bizyengine/bizyair_extras/nodes_sd3.py +52 -0
  30. bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
  31. bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
  32. bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
  33. bizyengine/bizyair_extras/nodes_trellis.py +199 -0
  34. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
  35. bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
  36. bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
  37. bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
  38. bizyengine/core/__init__.py +8 -0
  39. bizyengine/core/commands/__init__.py +1 -0
  40. bizyengine/core/commands/base.py +27 -0
  41. bizyengine/core/commands/invoker.py +4 -0
  42. bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
  43. bizyengine/core/commands/processors/prompt_processor.py +123 -0
  44. bizyengine/core/commands/servers/model_server.py +0 -0
  45. bizyengine/core/commands/servers/prompt_server.py +234 -0
  46. bizyengine/core/common/__init__.py +8 -0
  47. bizyengine/core/common/caching.py +198 -0
  48. bizyengine/core/common/client.py +262 -0
  49. bizyengine/core/common/env_var.py +101 -0
  50. bizyengine/core/common/utils.py +93 -0
  51. bizyengine/core/configs/conf.py +112 -0
  52. bizyengine/core/configs/models.json +101 -0
  53. bizyengine/core/configs/models.yaml +329 -0
  54. bizyengine/core/data_types.py +20 -0
  55. bizyengine/core/image_utils.py +288 -0
  56. bizyengine/core/nodes_base.py +159 -0
  57. bizyengine/core/nodes_io.py +97 -0
  58. bizyengine/core/path_utils/__init__.py +9 -0
  59. bizyengine/core/path_utils/path_manager.py +276 -0
  60. bizyengine/core/path_utils/utils.py +34 -0
  61. bizyengine/misc/__init__.py +0 -0
  62. bizyengine/misc/auth.py +83 -0
  63. bizyengine/misc/llm.py +431 -0
  64. bizyengine/misc/mzkolors.py +93 -0
  65. bizyengine/misc/nodes.py +1208 -0
  66. bizyengine/misc/nodes_controlnet_aux.py +491 -0
  67. bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
  68. bizyengine/misc/route_sam.py +60 -0
  69. bizyengine/misc/segment_anything.py +276 -0
  70. bizyengine/misc/supernode.py +182 -0
  71. bizyengine/misc/utils.py +218 -0
  72. bizyengine/version.txt +1 -0
  73. bizyengine-0.4.2.dist-info/METADATA +12 -0
  74. bizyengine-0.4.2.dist-info/RECORD +76 -0
  75. bizyengine-0.4.2.dist-info/WHEEL +5 -0
  76. bizyengine-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,159 @@
1
+ import importlib
2
+ import logging
3
+ import warnings
4
+ from functools import wraps
5
+ from typing import List
6
+
7
+ from .data_types import is_send_request_datatype
8
+ from .nodes_io import BizyAirNodeIO, create_node_data
9
+
10
try:
    comfy_nodes = importlib.import_module("nodes")
except ModuleNotFoundError:
    # No top-level ``nodes`` module is importable (e.g. not running inside
    # ComfyUI). Substitute a bare namespace class exposing the one attribute
    # this module reads, so display-name lookups in ``register_node`` still work.
    warnings.warn("Importing comfyui.nodes failed!")
    comfy_nodes = type(
        "nodes",
        (object,),
        dict(NODE_DISPLAY_NAME_MAPPINGS={}),
    )
15
+
16
# NOTE(review): configuring the root logger from a library module also affects
# the host application's logging; kept as-is for compatibility.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Branding prepended to node categories and display names.
LOGO = "☁️"
PREFIX = "BizyAir"

# Registries populated by ``register_node``: class name -> class / display name.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
23
+
24
+
25
def to_camel_case(s):
    """Convert ``snake_case`` text to ``CamelCase``.

    Each underscore-separated piece goes through ``str.capitalize``, so every
    letter after the first one in a piece is lower-cased (``"aB_c"`` -> ``"AbC"``).
    """
    return "".join(map(str.capitalize, s.split("_")))
27
+
28
+
29
def validate_category(cls, prefix):
    """Assert that ``cls.CATEGORY`` sits under the ``LOGO + prefix`` root.

    Uses the module-level ``LOGO`` constant instead of repeating the emoji
    literal. This keeps the check in step with the categories that
    ``BizyAirBaseNode.__init_subclass__`` builds.

    Raises:
        AssertionError: if ``cls.CATEGORY`` does not start with ``LOGO + prefix``.
    """
    expected_root = f"{LOGO}{prefix}"
    assert cls.CATEGORY.startswith(expected_root), (
        f"Category {cls.CATEGORY!r} of {cls.__name__} must start with "
        f"{expected_root!r}"
    )
31
+
32
+
33
def register_node(cls, prefix):
    """Record ``cls`` in ``NODE_CLASS_MAPPINGS`` and ``NODE_DISPLAY_NAME_MAPPINGS``.

    Registry key: ``cls.__name__`` if it already starts with ``prefix``,
    otherwise ``f"{prefix}_{cls.__name__}"``.

    Display name, always of the form ``f"{LOGO}{prefix} ..."``, is chosen from:
      1. ``cls.NODE_DISPLAY_NAME`` when defined (not re-prefixed if it already
         starts with ``LOGO + prefix``);
      2. otherwise ``comfy_nodes.NODE_DISPLAY_NAME_MAPPINGS`` looked up by the
         un-prefixed node name;
      3. otherwise the un-prefixed name itself, with a warning.
    """
    class_name = (
        cls.__name__
        if cls.__name__.startswith(prefix)
        else f"{prefix}_{cls.__name__}"
    )
    logger.debug(
        f"Class: {cls}, Name: {class_name}, Has DISPLAY_NAME: {hasattr(cls, 'NODE_DISPLAY_NAME')}"
    )

    if hasattr(cls, "NODE_DISPLAY_NAME"):
        display_name = cls.NODE_DISPLAY_NAME
        if not display_name.startswith(f"{LOGO}{prefix}"):
            display_name = f"{LOGO}{prefix} {display_name}"
    else:
        # Strip the prefix plus at most one separating underscore. Both
        # "BizyAir_Foo" and "BizyAirFoo" then yield "Foo". Blindly slicing
        # len(prefix) + 1 characters would cut the "F" from the latter.
        base_name = class_name[len(prefix):]
        if base_name.startswith("_"):
            base_name = base_name[1:]
        native_names = comfy_nodes.NODE_DISPLAY_NAME_MAPPINGS
        if base_name in native_names:
            display_name = f"{LOGO}{prefix} {native_names[base_name]}"
        else:
            display_name = f"{LOGO}{prefix} {base_name}"
            logger.warning(
                f"Display name '{display_name}' might differ from the native display name."
            )

    NODE_CLASS_MAPPINGS[class_name] = cls
    NODE_DISPLAY_NAME_MAPPINGS[class_name] = display_name
61
+
62
+
63
def ensure_unique_id(org_func, original_has_unique_id=False):
    """Wrap a node entry point so it records the ``unique_id`` kwarg on ``self``.

    Args:
        org_func: function called as ``org_func(self, **kwargs)``.
        original_has_unique_id: True when the node itself declared a hidden
            ``unique_id`` input. In that case the value is copied to
            ``self._assigned_id`` (defaulting to the literal ``"UNIQUE_ID"``)
            and still forwarded to ``org_func``. Otherwise it is removed from
            the kwargs before forwarding, and ``_assigned_id`` is set only when
            a ``unique_id`` was actually passed.

    Returns:
        The wrapper, carrying ``org_func``'s metadata via ``functools.wraps``.
    """

    @wraps(org_func)
    def _with_unique_id(self, **kwargs):
        if not original_has_unique_id:
            if "unique_id" in kwargs:
                self._assigned_id = kwargs.pop("unique_id")
        else:
            self._assigned_id = kwargs.get("unique_id", "UNIQUE_ID")
        return org_func(self, **kwargs)

    return _with_unique_id
73
+
74
+
75
def ensure_hidden_unique_id(org_input_types_func):
    """Wrap an ``INPUT_TYPES`` callable so its spec always has a hidden ``unique_id``.

    Each call of the wrapper calls ``org_input_types_func()``. If needed, it
    adds ``"unique_id": "UNIQUE_ID"`` under the ``"hidden"`` key of the returned
    dict, mutating that dict in place, and returns it.

    The wrapper is invoked once before returning, so the caller learns whether
    the original spec already declared ``unique_id``.

    Returns:
        ``(wrapper, original_has_unique_id)``.
    """
    flags = {"declared": False}

    @wraps(org_input_types_func)
    def _patched_input_types():
        spec = org_input_types_func()
        if "hidden" not in spec:
            spec["hidden"] = {"unique_id": "UNIQUE_ID"}
        elif "unique_id" in spec["hidden"]:
            flags["declared"] = True
        else:
            spec["hidden"]["unique_id"] = "UNIQUE_ID"
        return spec

    # Probe once so ``flags["declared"]`` reflects the original spec.
    _patched_input_types()
    return _patched_input_types, flags["declared"]
94
+
95
+
96
class BizyAirBaseNode:
    """Base class for BizyAir nodes.

    Defining a subclass triggers ``__init_subclass__``, which:

    * prefixes ``CATEGORY`` with ``f"{LOGO}{PREFIX}/"`` unless it already starts
      with ``f"{LOGO}{PREFIX}"``;
    * registers the class with ``register_node``;
    * wraps ``INPUT_TYPES`` so it always declares a hidden ``unique_id`` input;
    * wraps the ``FUNCTION`` method so the received id is stored on the
      instance (read back through ``assigned_id``).

    Subclasses must define ``CATEGORY``, ``INPUT_TYPES`` and ``RETURN_TYPES``.
    ``CLASS_TYPE_NAME`` optionally overrides the class type recorded in the graph.
    """

    FUNCTION = "default_function"

    def __init_subclass__(cls, **kwargs):
        # Chain to the next __init_subclass__ in the MRO so cooperative bases
        # and mixins still run, instead of silently swallowing class kwargs.
        super().__init_subclass__(**kwargs)
        if not cls.CATEGORY.startswith(f"{LOGO}{PREFIX}"):
            cls.CATEGORY = f"{LOGO}{PREFIX}/{cls.CATEGORY}"
        register_node(cls, PREFIX)
        cls.setup_input_types()

    @classmethod
    def setup_input_types(cls):
        """Install the hidden ``unique_id`` input and wrap the ``FUNCTION`` method."""
        # https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
        new_input_types_func, original_has_unique_id = ensure_hidden_unique_id(
            cls.INPUT_TYPES
        )
        cls.INPUT_TYPES = new_input_types_func
        setattr(
            cls,
            cls.FUNCTION,
            ensure_unique_id(getattr(cls, cls.FUNCTION), original_has_unique_id),
        )

    @property
    def assigned_id(self):
        """The node id captured by the wrapped ``FUNCTION`` method, as a ``str``.

        Raises:
            AssertionError: if no id has been captured yet, or it is not a
                ``str``. An unset id now fails this assertion instead of
                escaping as an ``AttributeError`` from the property.
        """
        node_id = getattr(self, "_assigned_id", None)
        assert node_id is not None and isinstance(node_id, str), (
            "unique_id has not been assigned to this node (or is not a str)"
        )
        return str(node_id)

    def default_function(self, **kwargs):
        """Build one graph handle per output slot from ``kwargs``.

        Returns:
            The tuple of ``BizyAirNodeIO`` handles (one per ``RETURN_TYPES``
            entry). If every return type is a send-request datatype (per
            ``is_send_request_datatype``), the graph is executed instead and
            the result of ``send_request`` is returned.
        """
        class_type = self._determine_class_type()

        node_ios = self._process_non_send_request_types(class_type, kwargs)
        # TODO: add processing for send_request_types
        send_request_datatype_list = self._get_send_request_datatypes()
        # NOTE(review): an empty RETURN_TYPES also satisfies this equality, and
        # _process_all_send_request_types would then index an empty tuple.
        if len(send_request_datatype_list) == len(self.RETURN_TYPES):
            return self._process_all_send_request_types(node_ios)
        return node_ios

    def _get_send_request_datatypes(self):
        """Return the entries of ``RETURN_TYPES`` that are send-request datatypes."""
        return [
            return_type
            for return_type in self.RETURN_TYPES
            if is_send_request_datatype(return_type)
        ]

    def _determine_class_type(self):
        """Return ``CLASS_TYPE_NAME`` (or the class name) without a ``PREFIX_`` prefix."""
        class_type = getattr(self, "CLASS_TYPE_NAME", type(self).__name__)
        if class_type.startswith(f"{PREFIX}_"):
            class_type = class_type[len(PREFIX) + 1 :]
        return class_type

    def _process_non_send_request_types(self, class_type, kwargs):
        """Create one ``BizyAirNodeIO`` per output slot, each recording this node."""
        outs = []
        for slot_index, _ in enumerate(self.RETURN_TYPES):
            node = BizyAirNodeIO(node_id=self.assigned_id, nodes={})
            node.add_node_data(
                class_type=class_type, inputs=kwargs, outputs={"slot_index": slot_index}
            )
            outs.append(node)
        return tuple(outs)

    def _process_all_send_request_types(self, node_ios: List[BizyAirNodeIO]):
        """Execute the graph of the first handle; expects one result per return type."""
        out = node_ios[0].send_request()
        assert len(out) == len(self.RETURN_TYPES)
        return out
@@ -0,0 +1,97 @@
1
+ import warnings
2
+ from typing import Any, Dict
3
+
4
+ from .commands import invoker
5
+ from .image_utils import encode_data
6
+
7
+
8
+ def create_node_data(class_type: str, inputs: dict, outputs: dict):
9
+ assert (
10
+ outputs.get("slot_index", None) is not None
11
+ ), "outputs must contain 'slot_index'"
12
+ assert isinstance(outputs["slot_index"], int), "'slot_index' must be an integer"
13
+ assert isinstance(class_type, str)
14
+
15
+ out = {
16
+ "class_type": class_type,
17
+ "inputs": inputs,
18
+ "outputs": outputs,
19
+ }
20
+
21
+ return out
22
+
23
+
24
+ class BizyAirNodeIO:
25
+ def __init__(
26
+ self,
27
+ node_id: int = "0", # Unique identifier for the current node
28
+ nodes: Dict[str, Dict[str, any]] = {},
29
+ *args,
30
+ **kwargs,
31
+ ):
32
+ self._validate_node_id(node_id=node_id)
33
+ self.node_id = node_id
34
+ self.nodes = nodes
35
+
36
+ def _validate_node_id(self, node_id) -> bool:
37
+ if node_id is None:
38
+ raise ValueError("Node ID cannot be None.")
39
+ if not isinstance(node_id, str):
40
+ raise ValueError("Node ID must be a string.")
41
+ if not node_id.isdigit():
42
+ raise ValueError(
43
+ "Node ID must be a string that can be converted to an integer."
44
+ )
45
+ return True
46
+
47
+ def copy(self, new_node_id: str = None):
48
+ self._validate_node_id(new_node_id)
49
+ if new_node_id in self.nodes:
50
+ raise ValueError(f"Node ID '{new_node_id}' already exists.")
51
+
52
+ return BizyAirNodeIO(
53
+ nodes=self.nodes.copy(),
54
+ node_id=new_node_id,
55
+ )
56
+
57
+ def add_node_data(
58
+ self,
59
+ class_type: str,
60
+ inputs: Dict[str, Any],
61
+ outputs: Dict[str, Any] = {"slot_index": 0},
62
+ ):
63
+ node_data = create_node_data(
64
+ class_type=class_type,
65
+ inputs=inputs,
66
+ outputs=outputs,
67
+ )
68
+
69
+ self.update_nodes_from_others(*inputs.values())
70
+
71
+ if self.node_id in self.nodes:
72
+ warnings.warn(
73
+ f"Node ID {self.node_id} already exists. Data will be overwritten.",
74
+ RuntimeWarning,
75
+ )
76
+
77
+ self.nodes[self.node_id] = node_data
78
+
79
+ def update_nodes_from_others(self, *others):
80
+ for other in others:
81
+ if isinstance(other, BizyAirNodeIO):
82
+ self.nodes.update(other.nodes)
83
+
84
+ def send_request(
85
+ self, url=None, headers=None, *, progress_callback=None, stream=False
86
+ ) -> any:
87
+ out = invoker.prompt_server.execute(
88
+ prompt=self.nodes, last_node_ids=[self.node_id]
89
+ )
90
+ return out
91
+
92
+
93
@encode_data.register(BizyAirNodeIO)
def _(output: BizyAirNodeIO, **kwargs):
    """Encode a ``BizyAirNodeIO`` as the pair ``[node_id, slot_index]``.

    ``slot_index`` is read from the handle's own node data. Extra keyword
    arguments are accepted and ignored.
    """
    slot_spec = output.nodes[output.node_id]["outputs"]
    return [output.node_id, slot_spec["slot_index"]]
@@ -0,0 +1,9 @@
1
+ from .path_manager import (
2
+ convert_prompt_label_path_to_real_path,
3
+ disable_refresh_options,
4
+ enable_refresh_options,
5
+ get_filename_list,
6
+ guess_config,
7
+ guess_url_from_node,
8
+ )
9
+ from .utils import filter_files_extensions
@@ -0,0 +1,276 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ import pprint
5
+ import re
6
+ import warnings
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ from typing import Any, Collection, Dict, List, Union
11
+
12
+ from bizyengine.core.common import fetch_models_by_type
13
+ from bizyengine.core.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
14
+ from bizyengine.core.configs.conf import ModelRule, config_manager
15
+ from bizyengine.core.path_utils.utils import (
16
+ filter_files_extensions,
17
+ get_service_route,
18
+ load_yaml_config,
19
+ )
20
+
21
# File suffixes treated as model weights when listing model files.
supported_pt_extensions: set[str] = set(
    ".ckpt .pt .bin .pth .safetensors .pkl .sft".split()
)
# Alias for a list of path strings.
ScanPathType = list[str]
# Model folder name -> registered file entries (appended to by ``init_config``).
folder_names_and_paths: dict[str, ScanPathType] = defaultdict(list)
# Model folder name -> {label_path: real_path}, filled by ``cached_filename_list``.
filename_path_mapping: dict[str, dict[str, str]] = {}
+
34
+
35
+ @dataclass
36
+ class RefreshSettings:
37
+ loras: bool = True
38
+ controlnet: bool = True
39
+
40
+ def get(self, folder_name: str, default: bool = True):
41
+ return getattr(self, folder_name, default)
42
+
43
+ def set(self, folder_name: str, value: bool):
44
+ setattr(self, folder_name, value)
45
+
46
+
47
+ refresh_settings = RefreshSettings()
48
+
49
+
50
+ def enable_refresh_options(folder_names: Union[str, list[str]]):
51
+ if isinstance(folder_names, str):
52
+ folder_names = [folder_names]
53
+ for folder_name in folder_names:
54
+ refresh_settings.set(folder_name, True)
55
+
56
+
57
+ def disable_refresh_options(folder_names: Union[str, list[str]]):
58
+ if isinstance(folder_names, str):
59
+ folder_names = [folder_names]
60
+ for folder_name in folder_names:
61
+ refresh_settings.set(folder_name, False)
62
+
63
+
64
+ def _get_config_path():
65
+ src_bizyair_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
66
+ configs_path = os.path.join(src_bizyair_path, "configs")
67
+ return configs_path
68
+
69
+
70
# Directory holding the packaged configuration files.
configs_path = _get_config_path()

# Parsed contents of configs/models.yaml, read once at import time; this module
# reads its "model_types" and "model_hub" sections.
models_config: Dict[str, Dict[str, Any]] = load_yaml_config(
    os.path.join(configs_path, "models.yaml")
)
75
+
76
+
77
def guess_url_from_node(
    node: Dict[str, Dict[str, Any]], class_type_table: Dict[str, bool]
) -> Union[List[ModelRule], None]:
    """Return the model rules for ``node["class_type"]`` that match its inputs.

    A rule matches when it declares no input constraints, or when, for every
    constrained input key, at least one of the rule's regex patterns is found
    (``re.search``) in the node's value for that key.

    An input that is missing, or is not a string (e.g. a link to another node's
    output), never satisfies a constraint. Such inputs used to make
    ``re.search`` raise ``KeyError``/``TypeError`` and abort the lookup.

    Args:
        node: prompt entry with ``class_type`` and (usually) ``inputs`` keys.
        class_type_table: unused; kept for interface compatibility.

    Returns:
        The matching rules (possibly an empty list).
    """

    def _value_matches(value: Any, patterns) -> bool:
        # Only text can be matched against the rule's regex patterns.
        return isinstance(value, str) and any(
            re.search(p, value) is not None for p in patterns
        )

    rules: List[ModelRule] = config_manager.get_rules(node["class_type"])
    inputs = node.get("inputs") or {}
    return [
        rule
        for rule in rules
        if len(rule.inputs) == 0
        or all(
            _value_matches(inputs.get(key), patterns)
            for key, patterns in rule.inputs.items()
        )
    ]
91
+
92
+
93
def guess_config(
    *,
    ckpt_name: str = None,
    unet_name: str = None,
    vae_name: str = None,
    clip_name: str = None,
) -> None:
    """Deprecated no-op kept for backward compatibility.

    Always emits a ``DeprecationWarning`` and returns ``None``; the keyword
    arguments are ignored. ``stacklevel=2`` attributes the warning to the
    caller's line rather than to this function, so users can find the call.
    """
    warnings.warn(
        "The interface has changed, please do not use it",
        DeprecationWarning,
        stacklevel=2,
    )
101
+
102
+
103
def get_config_file_list(base_path=None) -> list:
    """Return paths of all ``.yaml`` files found recursively under a directory.

    Args:
        base_path: directory to search. ``None`` (the default) searches the
            package ``configs`` directory (``configs_path``), as before.
            Previously a caller-supplied ``base_path`` was silently ignored.

    Returns:
        Joined file paths, in ``os.walk`` order.
    """
    search_root = configs_path if base_path is None else base_path
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(search_root)
        for name in files
        if name.endswith(".yaml")
    ]
115
+
116
+
117
def cached_filename_list(
    folder_name: str, *, share_id: str = None, verbose=False, refresh=False
) -> list[str]:
    """Return model label paths for ``folder_name``, fetching them when needed.

    A fetch happens when ``refresh`` is true or nothing is cached yet for
    ``folder_name``. With a truthy ``share_id`` the endpoint
    ``{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files`` is queried; otherwise
    the ``model_hub.find_model`` route from ``models_config`` is used.

    The response's ``data.files`` entries replace
    ``filename_path_mapping[folder_name]`` as ``{label_path: real_path}``;
    entries with an empty label are skipped. After any fetch attempt made with
    ``share_id is None``, the folder's refresh flag is cleared, whether or not
    the fetch succeeded.

    Returns:
        Sorted label paths with a suffix in ``supported_pt_extensions``. Returns
        ``[]`` when the response is empty or lacks ``data``, and also when
        processing it raises (a warning is emitted in that case).

    Raises:
        KeyError: during a fetch, if ``folder_name`` is not listed under
            ``models_config["model_types"]``.
    """
    global filename_path_mapping
    if refresh or folder_name not in filename_path_mapping:
        type_by_folder: Dict[str, str] = models_config["model_types"]
        if share_id:
            endpoint = f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"
        else:
            endpoint = get_service_route(models_config["model_hub"]["find_model"])
        response = fetch_models_by_type(
            url=endpoint, method="GET", model_type=type_by_folder[folder_name]
        )
        if verbose:
            pprint.pprint({"cached_filename_list": response})

        try:
            if not response or "data" not in response or response["data"] is None:
                return []

            label_to_real = {}
            for entry in response["data"]["files"]:
                if entry["label_path"]:
                    label_to_real[entry["label_path"]] = entry["real_path"]
            filename_path_mapping[folder_name] = label_to_real
        except Exception as e:
            warnings.warn(f"Failed to get filename list: {e}")
            return []
        finally:
            # TODO fix share_id valid refresh settings
            if share_id is None:
                disable_refresh_options(folder_name)

    cached_labels = filename_path_mapping[folder_name].keys()
    return list(
        filter_files_extensions(cached_labels, extensions=supported_pt_extensions)
    )
156
+
157
+
158
def convert_prompt_label_path_to_real_path(prompt: dict[str, dict[str, Any]]) -> dict:
    """Return a copy of ``prompt`` with LoRA / ControlNet label paths resolved.

    For every node, ``lora_name`` (folder ``"loras"``) and ``control_net_name``
    (folder ``"controlnet"``) are looked up in ``filename_path_mapping`` and
    replaced by the mapped real path when one exists. Otherwise the value is
    kept, but it must appear in ``get_filename_list(folder)``.

    Values that are not strings, such as ``[node_id, slot]`` links to another
    node's output, are left untouched. They previously crashed the lookup
    with ``TypeError: unhashable type``.

    ``prompt`` is not mutated: each node dict and its ``inputs`` dict are
    shallow-copied.

    Raises:
        ValueError: if an unmapped string name is not in its folder's file list.
    """
    # TODO: the (input key, model folder) pairs are temporarily hard-coded.
    resolvable_inputs = (
        ("lora_name", "loras"),
        ("control_net_name", "controlnet"),
    )
    new_prompt = {}
    for unique_id, node in prompt.items():
        new_node = copy.copy(node)
        inputs = copy.copy(node["inputs"])

        for key, folder_name in resolvable_inputs:
            if key not in inputs:
                continue
            value = inputs[key]
            if not isinstance(value, str):
                # Linked input: nothing to resolve here.
                continue
            new_value = filename_path_mapping.get(folder_name, {}).get(value, None)
            if new_value:
                inputs[key] = new_value
            else:
                file_list = get_filename_list(folder_name)
                if value not in file_list:
                    raise ValueError(
                        f"{key} '{value}' not found in file list. Available {key} names: {', '.join(file_list)}"
                    )

        new_node["inputs"] = inputs
        new_prompt[unique_id] = new_node
    return new_prompt
183
+
184
+
185
def get_share_filename_list(folder_name, share_id, *, verbose=BIZYAIR_DEBUG):
    """Return the file list of ``folder_name`` for ``share_id``, always re-fetching.

    Raises:
        AssertionError: if ``share_id`` is not a ``str``.
    """
    assert share_id is not None and isinstance(share_id, str)
    # TODO fix share_id valid refresh settings
    return cached_filename_list(
        folder_name,
        share_id=share_id,
        verbose=verbose,
        refresh=True,
    )
191
+
192
+
193
def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):
    """Return the model file entries registered for ``folder_name``.

    Entries come from ``folder_names_and_paths`` (populated by ``init_config``);
    unknown folders yield ``[]``. A new list is returned so callers cannot
    mutate the module-level registry by accident. Previously the registry's
    own list object was handed out.

    Args:
        folder_name: model folder key, e.g. ``"loras"``.
        verbose: currently unused.
    """
    # TODO: original note (translated): "remove after community nodes go live".
    # Two lookups that used to follow here are disabled: the remote listing via
    # ``cached_filename_list`` (gated by ``refresh_settings``), and, in debug
    # mode, ComfyUI's local ``folder_paths.get_filename_list``.
    return list(folder_names_and_paths.get(folder_name, []))
213
+
214
+
215
def filter_files_extensions(
    files: Collection[str], extensions: Collection[str]
) -> list[str]:
    """Return ``files`` sorted, keeping names whose lower-cased suffix is in ``extensions``.

    An empty ``extensions`` keeps every file.

    NOTE(review): this definition rebinds the ``filter_files_extensions``
    imported from ``bizyengine.core.path_utils.utils`` at the top of this
    module; the two implementations are equivalent.
    """

    def _wanted(name):
        suffix = os.path.splitext(name)[-1].lower()
        return suffix in extensions or len(extensions) == 0

    return sorted(name for name in files if _wanted(name))
227
+
228
+
229
def recursive_extract_models(data: Any, prefix_path: str = "") -> List[str]:
    """Collect model file paths from a nested dict/list structure.

    String dict keys and string list items are joined onto ``prefix_path`` with
    ``"/"``; non-string keys/items leave the path unchanged. A string leaf
    contributes its accumulated path, filtered by ``supported_pt_extensions``,
    only when that path ends with the leaf itself. This is always the case
    for string list items, but not for string values reached via a dict key.
    """

    def _join(base: str, part: Any) -> str:
        if not isinstance(part, str):
            return base
        return f"{base}/{part}" if base else part

    if isinstance(data, dict):
        children = [(value, _join(prefix_path, key)) for key, value in data.items()]
    elif isinstance(data, list):
        children = [(item, _join(prefix_path, item)) for item in data]
    elif isinstance(data, str) and prefix_path.endswith(data):
        return filter_files_extensions([prefix_path], supported_pt_extensions)
    else:
        return []

    found: List[str] = []
    for child, child_prefix in children:
        found.extend(recursive_extract_models(child, child_prefix))
    return found
248
+
249
+
250
def load_json(file_path: str) -> dict:
    """Read and parse the JSON file at ``file_path``.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for JSON
    exchanged between systems, instead of the platform's locale default.
    The locale default (a legacy code page on some Windows setups) can fail
    on or mis-decode non-ASCII content.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
254
+
255
+
256
def init_config():
    """Append the model paths known to ``config_manager`` to ``folder_names_and_paths``.

    Entries are appended with ``extend``, so calling this twice duplicates
    them. When ``BIZYAIR_DEBUG`` is set, the resulting mapping is pretty-printed.
    """
    model_paths = config_manager.model_path_manager.model_paths
    for folder, names in model_paths.items():
        folder_names_and_paths[folder].extend(names)
    if BIZYAIR_DEBUG:
        banner = "=" * 20 + "init_config: " + "=" * 20
        pprint.pprint(banner)
        pprint.pprint(folder_names_and_paths)
263
+
264
+
265
# Populate ``folder_names_and_paths`` once, at import time.
init_config()


if __name__ == "__main__":
    # Manual smoke test: print the LoRA entries registered in this module.
    lora_names = get_filename_list("loras", verbose=True)
    print(lora_names)
@@ -0,0 +1,34 @@
1
+ import os
2
+ from collections.abc import Collection
3
+ from typing import Dict, Union
4
+
5
+ import yaml
6
+ from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
7
+
8
+
9
def filter_files_extensions(
    files: Collection[str], extensions: Collection[str]
) -> list[str]:
    """Return ``files`` sorted, keeping only names whose suffix is in ``extensions``.

    The suffix (``os.path.splitext``) is lower-cased before the membership
    test. An empty ``extensions`` keeps every file.
    """
    kept = []
    for name in files:
        suffix = os.path.splitext(name)[-1].lower()
        if suffix in extensions or len(extensions) == 0:
            kept.append(name)
    kept.sort()
    return kept
21
+
22
+
23
def load_yaml_config(file_path):
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load``.

    The file is decoded explicitly as UTF-8, so configs containing non-ASCII
    text load the same on every platform. Without it, decoding would depend
    on the locale's default encoding.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
27
+
28
+
29
def get_service_route(service_config: Dict[str, str]) -> Union[str, None]:
    """Build a service URL as ``service_address + route`` from ``service_config``.

    ``service_address`` defaults to ``BIZYAIR_SERVER_ADDRESS``. Returns ``None``
    when the config has no ``"route"`` key.
    """
    if "route" not in service_config:
        return None
    base = str(service_config.get("service_address", BIZYAIR_SERVER_ADDRESS))
    return base + service_config.get("route")
File without changes
@@ -0,0 +1,83 @@
1
+ import os
2
+ import uuid
3
+ from pathlib import Path
4
+
5
+ import bizyengine.core
6
+ import server
7
+ from aiohttp import web
8
+ from bizyengine.core.common import create_api_key_file, load_api_key, validate_api_key
9
+
10
# Module-level copy of the API key. It is seeded from ``load_api_key`` and
# stays ``None`` when no stored key is found; the route handlers below update it.
# (A ``set_api_key.html`` settings page used to be loaded here; it is disabled.)
has_key, api_key = load_api_key()
API_KEY = api_key if has_key else None
19
+
20
+
21
+ # @server.PromptServer.instance.routes.get("/bizyair/set-api-key")
22
+ # async def set_api_key_page(request):
23
+ # return web.Response(text=set_api_key_html, content_type="text/html")
24
+
25
+
26
@server.PromptServer.instance.routes.post("/bizyair/set_api_key")
async def set_api_key(request):
    """Validate, persist and activate the ``api_key`` field of a form POST.

    Responds 400 when the key is missing or fails ``validate_api_key``,
    ``"ok"`` after storing it (``create_api_key_file``) and installing it via
    ``bizyengine.core.set_api_key(..., override=True)``, and 500 on any other error.
    """
    global API_KEY
    try:
        form = await request.post()
        candidate = form.get("api_key")
        if not candidate:
            error_msg = "No API key provided, please refer to cloud.siliconflow.cn to get the key"
            print("set_api_key:", error_msg)
            return web.Response(text=error_msg, status=400)

        if not validate_api_key(candidate):
            error_msg = "Wrong API key provided, please refer to cloud.siliconflow.cn to get the key"
            print("set_api_key:", error_msg)
            return web.Response(text=error_msg, status=400)

        create_api_key_file(candidate)
        API_KEY = candidate
        bizyengine.core.set_api_key(API_KEY, override=True)
        print("Set the key sucessfully.")
        return web.Response(text="ok")
    except Exception as e:
        # NOTE(review): the exception text is echoed back to the HTTP client.
        print(f"set api key error: {str(e)}")
        return web.Response(text=str(e), status=500)
55
+
56
+
57
@server.PromptServer.instance.routes.get("/bizyair/get_api_key")
async def get_api_key(request):
    """Install an API key into ``bizyengine.core`` and report where it came from.

    The key from ``load_api_key`` is preferred; otherwise the ``api_key`` request
    cookie is used. Responds 404 when neither is available and 500 on any error.
    """
    global API_KEY
    try:
        has_key, stored_key = load_api_key()
        if has_key:
            source_msg = "Key has been loaded from the api_key.ini file"
            API_KEY = stored_key
        else:
            cookie_key = request.cookies.get("api_key")
            if not cookie_key:
                print("No api key found in cookies")
                return web.Response(
                    text="No api key found in cookies, please refer to cloud.siliconflow.cn to get the key",
                    status=404,
                )
            # NOTE(review): with no stored key, any request carrying an
            # ``api_key`` cookie sets the process-wide key.
            source_msg = "Key has been loaded from the cookies"
            API_KEY = cookie_key
        bizyengine.core.set_api_key(API_KEY)
        return web.Response(text=source_msg)
    except Exception as e:
        return web.Response(text=str(e), status=500)
80
+
81
+
82
# This module registers HTTP routes only; it contributes no node classes.
NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = {}, {}