bizyengine 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/__init__.py +35 -0
- bizyengine/bizy_server/__init__.py +7 -0
- bizyengine/bizy_server/api_client.py +763 -0
- bizyengine/bizy_server/errno.py +122 -0
- bizyengine/bizy_server/error_handler.py +3 -0
- bizyengine/bizy_server/execution.py +55 -0
- bizyengine/bizy_server/resp.py +24 -0
- bizyengine/bizy_server/server.py +898 -0
- bizyengine/bizy_server/utils.py +93 -0
- bizyengine/bizyair_extras/__init__.py +24 -0
- bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
- bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
- bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
- bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
- bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
- bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
- bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
- bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
- bizyengine/bizyair_extras/nodes_dataset.py +99 -0
- bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
- bizyengine/bizyair_extras/nodes_flux.py +69 -0
- bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
- bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
- bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
- bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
- bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
- bizyengine/bizyair_extras/nodes_sd3.py +52 -0
- bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
- bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
- bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
- bizyengine/bizyair_extras/nodes_trellis.py +199 -0
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
- bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
- bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
- bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
- bizyengine/core/__init__.py +8 -0
- bizyengine/core/commands/__init__.py +1 -0
- bizyengine/core/commands/base.py +27 -0
- bizyengine/core/commands/invoker.py +4 -0
- bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
- bizyengine/core/commands/processors/prompt_processor.py +123 -0
- bizyengine/core/commands/servers/model_server.py +0 -0
- bizyengine/core/commands/servers/prompt_server.py +234 -0
- bizyengine/core/common/__init__.py +8 -0
- bizyengine/core/common/caching.py +198 -0
- bizyengine/core/common/client.py +262 -0
- bizyengine/core/common/env_var.py +101 -0
- bizyengine/core/common/utils.py +93 -0
- bizyengine/core/configs/conf.py +112 -0
- bizyengine/core/configs/models.json +101 -0
- bizyengine/core/configs/models.yaml +329 -0
- bizyengine/core/data_types.py +20 -0
- bizyengine/core/image_utils.py +288 -0
- bizyengine/core/nodes_base.py +159 -0
- bizyengine/core/nodes_io.py +97 -0
- bizyengine/core/path_utils/__init__.py +9 -0
- bizyengine/core/path_utils/path_manager.py +276 -0
- bizyengine/core/path_utils/utils.py +34 -0
- bizyengine/misc/__init__.py +0 -0
- bizyengine/misc/auth.py +83 -0
- bizyengine/misc/llm.py +431 -0
- bizyengine/misc/mzkolors.py +93 -0
- bizyengine/misc/nodes.py +1208 -0
- bizyengine/misc/nodes_controlnet_aux.py +491 -0
- bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
- bizyengine/misc/route_sam.py +60 -0
- bizyengine/misc/segment_anything.py +276 -0
- bizyengine/misc/supernode.py +182 -0
- bizyengine/misc/utils.py +218 -0
- bizyengine/version.txt +1 -0
- bizyengine-0.4.2.dist-info/METADATA +12 -0
- bizyengine-0.4.2.dist-info/RECORD +76 -0
- bizyengine-0.4.2.dist-info/WHEEL +5 -0
- bizyengine-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import logging
|
|
3
|
+
import warnings
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
from .data_types import is_send_request_datatype
|
|
8
|
+
from .nodes_io import BizyAirNodeIO, create_node_data
|
|
9
|
+
|
|
10
|
+
# Import ComfyUI's top-level ``nodes`` module by name via importlib. If it is
# not importable (assumption: i.e. not running inside a ComfyUI process), warn
# and substitute a stub class exposing an empty NODE_DISPLAY_NAME_MAPPINGS,
# which is the attribute register_node() reads from ``comfy_nodes``.
try:
    comfy_nodes = importlib.import_module("nodes")
except ModuleNotFoundError:
    warnings.warn("Importing comfyui.nodes failed!")
    comfy_nodes = type("nodes", (object,), {"NODE_DISPLAY_NAME_MAPPINGS": {}})
|
|
15
|
+
|
|
16
|
+
# NOTE(review): basicConfig configures the host process's root logger at import
# time (it is a no-op if the root logger already has handlers); kept as-is for
# compatibility.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prepended to node categories and display names by this module.
LOGO = "☁️"
# Name prefix for registered node classes. Plain literal: the former f-string
# had no placeholders (ruff F541).
PREFIX = "BizyAir"
# Registries filled by register_node(): registry key -> class / display name.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def to_camel_case(s):
    """Convert ``snake_case`` text to ``CamelCase``.

    Each ``"_"``-separated part goes through ``str.capitalize``, which also
    lower-cases the rest of the part (``"HTTP_code"`` -> ``"HttpCode"``).
    Empty parts from repeated underscores contribute nothing.
    """
    return "".join(map(str.capitalize, s.split("_")))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def validate_category(cls, prefix):
    """Assert that ``cls.CATEGORY`` begins with the cloud logo followed by ``prefix``.

    Raises:
        AssertionError: if the category does not have that prefix (and
            assertions are enabled).
    """
    required_start = f"☁️{prefix}"
    assert cls.CATEGORY.startswith(required_start)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def register_node(cls, prefix):
    """Add ``cls`` to ``NODE_CLASS_MAPPINGS`` and ``NODE_DISPLAY_NAME_MAPPINGS``.

    Registry key: ``cls.__name__`` if it already starts with ``prefix``,
    otherwise ``f"{prefix}_{cls.__name__}"``.

    Display name:
      * ``cls.NODE_DISPLAY_NAME`` when defined, prefixed with
        ``f"{LOGO}{prefix} "`` unless it already starts with ``LOGO + prefix``;
      * otherwise the entry for the un-prefixed base name in
        ``comfy_nodes.NODE_DISPLAY_NAME_MAPPINGS`` (with the same prefix), or
        the base name itself, with a warning, when there is no such entry.

    Args:
        cls: The node class to register.
        prefix: Name prefix, e.g. ``PREFIX``.
    """
    class_name = (
        cls.__name__
        if cls.__name__.startswith(prefix)
        else f"{prefix}_{cls.__name__}"
    )
    logger.debug(
        "Class: %s, Name: %s, Has DISPLAY_NAME: %s",
        cls,
        class_name,
        hasattr(cls, "NODE_DISPLAY_NAME"),
    )

    if hasattr(cls, "NODE_DISPLAY_NAME"):
        display_name = cls.NODE_DISPLAY_NAME
        if not display_name.startswith(f"{LOGO}{prefix}"):
            display_name = f"{LOGO}{prefix} {display_name}"
    else:
        # Strip the prefix and at most one "_" separator. The former fixed
        # slice ``class_name[len(prefix) + 1:]`` assumed the separator was
        # always present and cut a real character off names such as
        # "BizyAirFoo" (yielding "oo" rather than "Foo").
        base_name = class_name[len(prefix):]
        if base_name.startswith("_"):
            base_name = base_name[1:]
        native_names = comfy_nodes.NODE_DISPLAY_NAME_MAPPINGS
        if base_name in native_names:
            display_name = f"{LOGO}{prefix} {native_names[base_name]}"
        else:
            display_name = f"{LOGO}{prefix} {base_name}"
            logger.warning(
                "Display name '%s' might differ from the native display name.",
                display_name,
            )

    NODE_CLASS_MAPPINGS[class_name] = cls
    NODE_DISPLAY_NAME_MAPPINGS[class_name] = display_name
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def ensure_unique_id(org_func, original_has_unique_id=False):
    """Wrap a node's execute method so it records the hidden ``unique_id``.

    The wrapper stores the id on ``self._assigned_id``:
      * if ``original_has_unique_id`` is true, the node declared ``unique_id``
        itself, so the kwarg is left in place for ``org_func``. A missing
        value falls back to the literal ``"UNIQUE_ID"``;
      * otherwise a ``unique_id`` kwarg (if any) is popped so ``org_func``
        never sees an argument it did not declare.

    Returns:
        The ``functools.wraps``-decorated wrapper.
    """
    keep_for_callee = original_has_unique_id

    @wraps(org_func)
    def _with_assigned_id(self, **kwargs):
        if not keep_for_callee:
            if "unique_id" in kwargs:
                self._assigned_id = kwargs.pop("unique_id")
        else:
            self._assigned_id = kwargs.get("unique_id", "UNIQUE_ID")
        return org_func(self, **kwargs)

    return _with_assigned_id
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def ensure_hidden_unique_id(org_input_types_func):
    """Wrap an ``INPUT_TYPES`` callable so its spec always has a hidden ``unique_id``.

    The wrapper calls ``org_input_types_func()``. Unless the returned dict
    already declares ``"unique_id"`` under ``"hidden"``, it adds
    ``"unique_id": "UNIQUE_ID"`` there, creating the ``"hidden"`` section if
    needed. The dict returned by the original callable is modified in place.

    The wrapper is invoked once here so the flag below can be computed.

    Returns:
        ``(wrapper, original_has_unique_id)``. The flag is True when the
        original spec already declared ``unique_id`` on that first call.
    """
    seen = {"original_has_unique_id": False}

    @wraps(org_input_types_func)
    def patched_input_types():
        spec = org_input_types_func()
        hidden = spec.setdefault("hidden", {})
        if "unique_id" in hidden:
            seen["original_has_unique_id"] = True
        else:
            hidden["unique_id"] = "UNIQUE_ID"
        return spec

    patched_input_types()
    return patched_input_types, seen["original_has_unique_id"]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class BizyAirBaseNode:
    """Base class for BizyAir nodes that build a prompt graph for remote execution.

    Subclassing does the registration work (see ``__init_subclass__``):
      * ``CATEGORY`` is rooted under ``f"{LOGO}{PREFIX}/"``;
      * the class is added to the module-level node mappings;
      * ``INPUT_TYPES`` gains a hidden ``unique_id`` input;
      * the ``FUNCTION`` method is wrapped to capture that id.

    Subclasses are expected to define ``CATEGORY``, ``INPUT_TYPES`` and
    ``RETURN_TYPES``. ``CLASS_TYPE_NAME`` optionally overrides the class type
    recorded in the graph.
    """

    FUNCTION = "default_function"

    def __init_subclass__(cls, **kwargs):
        # Chain up so other bases/mixins defining __init_subclass__ still run
        # and unexpected class keywords are reported rather than silently
        # dropped (the previous version never called super()).
        super().__init_subclass__(**kwargs)
        if not cls.CATEGORY.startswith(f"{LOGO}{PREFIX}"):
            cls.CATEGORY = f"{LOGO}{PREFIX}/{cls.CATEGORY}"
        register_node(cls, PREFIX)
        cls.setup_input_types()

    @classmethod
    def setup_input_types(cls):
        """Inject the hidden ``unique_id`` input and wrap ``FUNCTION`` to consume it."""
        # https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
        new_input_types_func, original_has_unique_id = ensure_hidden_unique_id(
            cls.INPUT_TYPES
        )
        cls.INPUT_TYPES = new_input_types_func
        setattr(
            cls,
            cls.FUNCTION,
            ensure_unique_id(getattr(cls, cls.FUNCTION), original_has_unique_id),
        )

    @property
    def assigned_id(self):
        """The node id captured from the hidden ``unique_id`` input.

        Raises:
            AssertionError: if no string id has been captured yet. This
                includes the case where ``_assigned_id`` was never set, which
                previously surfaced as a bare AttributeError.
        """
        node_id = getattr(self, "_assigned_id", None)
        assert isinstance(node_id, str), (
            f"{type(self).__name__} has no assigned unique_id (got {node_id!r})"
        )
        return str(node_id)

    def default_function(self, **kwargs):
        """Record this node in a graph, one ``BizyAirNodeIO`` per output slot.

        When every entry of ``RETURN_TYPES`` is a send-request datatype, the
        graph is executed immediately and its result returned. Otherwise the
        tuple of ``BizyAirNodeIO`` handles is returned.

        NOTE(review): with an empty ``RETURN_TYPES`` both lengths compared
        below are 0, so ``_process_all_send_request_types`` indexes an empty
        tuple. Confirm no subclass relies on that path.
        """
        class_type = self._determine_class_type()

        node_ios = self._process_non_send_request_types(class_type, kwargs)
        # TODO: add processing for send_request_types
        send_request_datatype_list = self._get_send_request_datatypes()
        if len(send_request_datatype_list) == len(self.RETURN_TYPES):
            return self._process_all_send_request_types(node_ios)
        return node_ios

    def _get_send_request_datatypes(self):
        """Return the entries of ``RETURN_TYPES`` flagged by ``is_send_request_datatype``."""
        return [
            return_type
            for return_type in self.RETURN_TYPES
            if is_send_request_datatype(return_type)
        ]

    def _determine_class_type(self):
        """Return ``CLASS_TYPE_NAME`` (or the class name) without a leading ``f"{PREFIX}_"``."""
        class_type = getattr(self, "CLASS_TYPE_NAME", type(self).__name__)
        if class_type.startswith(f"{PREFIX}_"):
            class_type = class_type[len(PREFIX) + 1 :]
        return class_type

    def _process_non_send_request_types(self, class_type, kwargs):
        """Build one ``BizyAirNodeIO`` per output slot, each holding this node's data."""
        outs = []
        for slot_index, _ in enumerate(self.RETURN_TYPES):
            node = BizyAirNodeIO(node_id=self.assigned_id, nodes={})
            node.add_node_data(
                class_type=class_type, inputs=kwargs, outputs={"slot_index": slot_index}
            )
            outs.append(node)
        return tuple(outs)

    def _process_all_send_request_types(self, node_ios: List[BizyAirNodeIO]):
        """Execute the graph via the first handle and check one result per output."""
        out = node_ios[0].send_request()
        assert len(out) == len(self.RETURN_TYPES)
        return out
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
from .commands import invoker
|
|
5
|
+
from .image_utils import encode_data
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def create_node_data(class_type: str, inputs: dict, outputs: dict):
    """Build the ``{"class_type", "inputs", "outputs"}`` record for one graph node.

    The given ``inputs`` and ``outputs`` dicts are stored as-is (not copied).

    Raises:
        AssertionError: if ``outputs`` lacks a non-None integer
            ``"slot_index"`` or ``class_type`` is not a string.
    """
    slot = outputs.get("slot_index", None)
    assert slot is not None, "outputs must contain 'slot_index'"
    assert isinstance(slot, int), "'slot_index' must be an integer"
    assert isinstance(class_type, str)
    return dict(class_type=class_type, inputs=inputs, outputs=outputs)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BizyAirNodeIO:
    """Handle to one node inside a prompt graph that is built up incrementally.

    ``nodes`` maps node ids to records produced by ``create_node_data``.
    ``node_id`` identifies the node this handle refers to.
    """

    def __init__(
        self,
        node_id: str = "0",  # Unique identifier for the current node
        nodes: "Dict[str, Dict[str, Any]] | None" = None,
        *args,
        **kwargs,
    ):
        """Create a handle for ``node_id`` over the graph ``nodes``.

        Args:
            node_id: A string for which ``str.isdigit()`` is true.
            nodes: Graph to share. When omitted, a new empty dict is created
                per instance. Extra ``*args``/``**kwargs`` are accepted and
                ignored.

        Raises:
            ValueError: if ``node_id`` is None, not a str, or not all digits.
        """
        self._validate_node_id(node_id=node_id)
        self.node_id = node_id
        # ``None`` sentinel: the former ``{}`` default was a single dict
        # shared by every instance constructed without ``nodes``.
        self.nodes = {} if nodes is None else nodes

    def _validate_node_id(self, node_id) -> bool:
        """Return True if ``node_id`` is a digit-only string, else raise ValueError."""
        if node_id is None:
            raise ValueError("Node ID cannot be None.")
        if not isinstance(node_id, str):
            raise ValueError("Node ID must be a string.")
        if not node_id.isdigit():
            raise ValueError(
                "Node ID must be a string that can be converted to an integer."
            )
        return True

    def copy(self, new_node_id: str = None):
        """Return a new handle with id ``new_node_id`` over a shallow copy of ``nodes``.

        Raises:
            ValueError: if ``new_node_id`` is invalid (the ``None`` default
                included) or already present in ``nodes``.
        """
        self._validate_node_id(new_node_id)
        if new_node_id in self.nodes:
            raise ValueError(f"Node ID '{new_node_id}' already exists.")

        return BizyAirNodeIO(
            nodes=self.nodes.copy(),
            node_id=new_node_id,
        )

    def add_node_data(
        self,
        class_type: str,
        inputs: Dict[str, Any],
        outputs: "Dict[str, Any] | None" = None,
    ):
        """Store this handle's node record and merge graphs from upstream handles.

        Every ``BizyAirNodeIO`` among ``inputs.values()`` has its ``nodes``
        merged into this graph first. Replacing an existing entry for
        ``node_id`` emits a ``RuntimeWarning``.

        Args:
            class_type: Node class type to record.
            inputs: Node inputs; stored as-is.
            outputs: Defaults to a new ``{"slot_index": 0}`` per call. The
                former literal default was one dict shared by, and stored in,
                every node created without ``outputs``.
        """
        if outputs is None:
            outputs = {"slot_index": 0}
        node_data = create_node_data(
            class_type=class_type,
            inputs=inputs,
            outputs=outputs,
        )

        self.update_nodes_from_others(*inputs.values())

        if self.node_id in self.nodes:
            warnings.warn(
                f"Node ID {self.node_id} already exists. Data will be overwritten.",
                RuntimeWarning,
            )

        self.nodes[self.node_id] = node_data

    def update_nodes_from_others(self, *others):
        """Merge ``nodes`` from each argument that is a ``BizyAirNodeIO``; ignore the rest."""
        for other in others:
            if isinstance(other, BizyAirNodeIO):
                self.nodes.update(other.nodes)

    def send_request(
        self, url=None, headers=None, *, progress_callback=None, stream=False
    ) -> Any:
        """Execute the graph through ``invoker.prompt_server`` with this node as the last node.

        ``url``, ``headers``, ``progress_callback`` and ``stream`` are
        currently ignored.
        """
        out = invoker.prompt_server.execute(
            prompt=self.nodes, last_node_ids=[self.node_id]
        )
        return out
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@encode_data.register(BizyAirNodeIO)
def _(output: BizyAirNodeIO, **kwargs):
    """Encode a node handle as ``[node_id, slot_index]`` for a prompt input.

    The slot index is read from the handle's own record in ``output.nodes``.
    Extra keyword arguments are ignored.
    """
    node_id = output.node_id
    node_entry = output.nodes[node_id]
    return [node_id, node_entry["outputs"]["slot_index"]]
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import pprint
|
|
5
|
+
import re
|
|
6
|
+
import warnings
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from functools import lru_cache
|
|
10
|
+
from typing import Any, Collection, Dict, List, Union
|
|
11
|
+
|
|
12
|
+
from bizyengine.core.common import fetch_models_by_type
|
|
13
|
+
from bizyengine.core.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
|
|
14
|
+
from bizyengine.core.configs.conf import ModelRule, config_manager
|
|
15
|
+
from bizyengine.core.path_utils.utils import (
|
|
16
|
+
filter_files_extensions,
|
|
17
|
+
get_service_route,
|
|
18
|
+
load_yaml_config,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# File extensions (lower-case, leading dot) accepted as model weight files.
supported_pt_extensions: set[str] = {
    ".ckpt",
    ".pt",
    ".bin",
    ".pth",
    ".safetensors",
    ".pkl",
    ".sft",
}
# The list of entries registered for a single folder name.
ScanPathType = list[str]
# folder name -> registered entries. Filled by init_config() from
# config_manager.model_path_manager.model_paths; read by get_filename_list().
folder_names_and_paths: dict[str, ScanPathType] = defaultdict(list)
# folder name -> {label_path: real_path}. Filled by cached_filename_list();
# used by convert_prompt_label_path_to_real_path() to translate labels.
filename_path_mapping: dict[str, dict[str, str]] = {}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class RefreshSettings:
    """Per-folder "re-fetch the remote listing" flags, addressed by folder name.

    Only ``loras`` and ``controlnet`` are declared fields. ``get`` returns
    ``default`` for a name that is not an attribute, and ``set`` creates the
    attribute if needed.
    """

    loras: bool = True
    controlnet: bool = True

    def set(self, folder_name: str, value: bool):
        """Store ``value`` as the flag for ``folder_name``."""
        setattr(self, folder_name, value)

    def get(self, folder_name: str, default: bool = True):
        """Return the flag for ``folder_name``, or ``default`` if there is none."""
        return getattr(self, folder_name, default)


# Module-wide instance toggled by enable_refresh_options/disable_refresh_options.
refresh_settings = RefreshSettings()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def enable_refresh_options(folder_names: Union[str, list[str]]):
    """Set the refresh flag to True for one folder name or each name in a list."""
    names = [folder_names] if isinstance(folder_names, str) else folder_names
    for name in names:
        refresh_settings.set(name, True)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def disable_refresh_options(folder_names: Union[str, list[str]]):
    """Set the refresh flag to False for one folder name or each name in a list."""
    names = [folder_names] if isinstance(folder_names, str) else folder_names
    for name in names:
        refresh_settings.set(name, False)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_config_path():
|
|
65
|
+
src_bizyair_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
66
|
+
configs_path = os.path.join(src_bizyair_path, "configs")
|
|
67
|
+
return configs_path
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Absolute path of the bundled ``configs`` directory, resolved at import time.
configs_path = _get_config_path()

# Parsed ``configs/models.yaml``, loaded at import time. cached_filename_list()
# reads its "model_types" and "model_hub" -> "find_model" entries.
models_config: Dict[str, Dict[str, Any]] = load_yaml_config(
    os.path.join(configs_path, "models.yaml")
)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def guess_url_from_node(
    node: Dict[str, Dict[str, Any]], class_type_table: Dict[str, bool]
) -> Union[List[ModelRule], None]:
    """Return the rules for ``node["class_type"]`` whose input patterns match the node.

    A rule matches when, for every ``key -> patterns`` entry in
    ``rule.inputs``, ``node["inputs"][key]`` is a string in which at least one
    regex from ``patterns`` is found (``re.search``). A rule with no input
    constraints always matches.

    Inputs that are missing or not strings (e.g. a linked input in
    ``[node_id, slot]`` form) make that rule a non-match. Previously they
    raised KeyError/TypeError.

    Args:
        node: Prompt node with ``"class_type"`` and ``"inputs"`` entries.
        class_type_table: Unused; kept for backward compatibility.

    Returns:
        The matching rules, possibly empty.
    """
    node_inputs = node.get("inputs") or {}

    def _matches(rule) -> bool:
        for key, patterns in rule.inputs.items():
            value = node_inputs.get(key)
            if not isinstance(value, str):
                return False
            if not any(re.search(p, value) is not None for p in patterns):
                return False
        return True

    rules: List[ModelRule] = config_manager.get_rules(node["class_type"])
    return [rule for rule in rules if _matches(rule)]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def guess_config(
    *,
    ckpt_name: str = None,
    unet_name: str = None,
    vae_name: str = None,
    clip_name: str = None,
) -> None:
    """Deprecated no-op kept for backward compatibility.

    All arguments are ignored. Emits a ``DeprecationWarning`` attributed to
    the caller (``stacklevel=2``) rather than to this function, and returns
    ``None``. The old ``-> str`` annotation never matched the behavior.
    """
    warnings.warn(
        "The interface has changed, please do not use it",
        DeprecationWarning,
        stacklevel=2,
    )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_config_file_list(base_path=None) -> list:
    """Return the paths of all ``*.yaml`` files under a directory, recursively.

    Args:
        base_path: Directory to scan. When omitted, the module-level
            ``configs_path`` is scanned, as before. Previously an explicit
            ``base_path`` was computed and then ignored.

    Returns:
        Matching file paths joined onto their walk root, in ``os.walk`` order.
    """
    root_dir = configs_path if base_path is None else base_path
    config_files = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith(".yaml"):
                config_files.append(os.path.join(root, file))
    return config_files
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def cached_filename_list(
    folder_name: str, *, share_id: str = None, verbose=False, refresh=False
) -> list[str]:
    """Return the remote model files listed for ``folder_name``.

    The listing is fetched with ``fetch_models_by_type`` and cached in
    ``filename_path_mapping[folder_name]`` as ``{label_path: real_path}``.

    Args:
        folder_name: Key into ``models_config["model_types"]``; selects the
            ``model_type`` sent with the request.
        share_id: When truthy, query
            ``f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"`` instead of
            the configured ``model_hub.find_model`` route.
        verbose: Pretty-print the raw response.
        refresh: Re-fetch even when ``folder_name`` is already cached.

    Returns:
        The cached label paths whose extension is in
        ``supported_pt_extensions``, sorted. Returns ``[]`` when the response
        carries no data, or when parsing it fails (a warning is emitted).

    Raises:
        KeyError: if ``folder_name`` is not in ``models_config["model_types"]``.
            That lookup happens outside the ``try`` below.
    """
    global filename_path_mapping
    if refresh or folder_name not in filename_path_mapping:
        model_types: Dict[str, str] = models_config["model_types"]
        if share_id:
            url = f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"
        else:
            url = get_service_route(models_config["model_hub"]["find_model"])
        msg = fetch_models_by_type(
            url=url, method="GET", model_type=model_types[folder_name]
        )
        if verbose:
            pprint.pprint({"cached_filename_list": msg})

        try:
            # No payload: return an empty list and leave the cache untouched.
            if not msg or "data" not in msg or msg["data"] is None:
                return []

            # Entries with an empty label_path are skipped.
            filename_path_mapping[folder_name] = {
                x["label_path"]: x["real_path"]
                for x in msg["data"]["files"]
                if x["label_path"]
            }
        except Exception as e:
            warnings.warn(f"Failed to get filename list: {e}")
            return []
        finally:
            # Runs on every path above, including both early returns.
            # NOTE(review): share-specific results are cached under the same
            # folder_name key as the global listing, so a later non-refresh
            # call may return a share's files. Confirm this is intended.
            # TODO: fix share_id-aware refresh settings.
            if share_id is None:
                disable_refresh_options(folder_name)

    return list(
        filter_files_extensions(
            filename_path_mapping[folder_name].keys(),
            extensions=supported_pt_extensions,
        )
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def convert_prompt_label_path_to_real_path(prompt: dict[str, dict[str, Any]]) -> dict:
    """Return a copy of ``prompt`` with known model labels replaced by real paths.

    For each node:
      * ``inputs["lora_name"]`` is looked up in ``filename_path_mapping["loras"]``;
      * ``inputs["control_net_name"]`` is looked up in
        ``filename_path_mapping["controlnet"]``.

    A mapped value is substituted. An unmapped value is kept only if
    ``get_filename_list(folder_name)`` contains it. Non-string values (e.g.
    linked inputs in ``[node_id, slot]`` form) are left untouched; previously
    they failed with an unhashable-type TypeError.

    Only the node dicts and their ``inputs`` dicts are shallow-copied;
    ``prompt`` itself is not modified.

    Raises:
        ValueError: if an unmapped string value is not in the folder's list.
    """
    # TODO: the input-key -> model-folder table is hard-coded for now.
    label_inputs = (
        ("lora_name", "loras"),
        ("control_net_name", "controlnet"),
    )
    new_prompt = {}
    for unique_id, node in prompt.items():
        new_node = copy.copy(node)
        inputs = copy.copy(node["inputs"])
        for key, folder_name in label_inputs:
            if key not in inputs:
                continue
            value = inputs[key]
            if not isinstance(value, str):
                continue
            new_value = filename_path_mapping.get(folder_name, {}).get(value, None)
            if new_value:
                inputs[key] = new_value
                continue
            file_list = get_filename_list(folder_name)
            if value not in file_list:
                raise ValueError(
                    f"{key} '{value}' not found in file list. Available {key} names: {', '.join(file_list)}"
                )
        new_node["inputs"] = inputs
        new_prompt[unique_id] = new_node
    return new_prompt
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def get_share_filename_list(folder_name, share_id, *, verbose=BIZYAIR_DEBUG):
    """Fetch, always bypassing the cache, the files of ``folder_name`` for a share.

    Raises:
        AssertionError: if ``share_id`` is not a string.
    """
    assert isinstance(share_id, str)
    # TODO: fix share_id-aware refresh settings.
    listing = cached_filename_list(
        folder_name, share_id=share_id, verbose=verbose, refresh=True
    )
    return listing
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):
    """Return the entries registered for ``folder_name`` (an empty list if none).

    Args:
        folder_name: Key into ``folder_names_and_paths``.
        verbose: Currently unused; kept for API compatibility.

    Returns:
        A new list each call.
    """
    # Return a copy: handing out the registry's own list let any caller that
    # mutated the result (e.g. inserting a "None" option) silently change the
    # shared registry.
    results = list(folder_names_and_paths.get(folder_name, []))
    # Remove after the community nodes go live.
    # if folder_name in models_config["model_types"]:
    #     refresh = refresh_settings.get(folder_name, True)
    #     results.extend(
    #         cached_filename_list(folder_name, verbose=verbose, refresh=refresh)
    #     )
    # if folder_name in folder_names_and_paths:
    #     results.extend(folder_names_and_paths[folder_name])
    # if BIZYAIR_DEBUG:
    #     try:
    #         import folder_paths
    #
    #         results.extend(folder_paths.get_filename_list(folder_name))
    #     except:
    #         pass
    return results
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def filter_files_extensions(
    files: Collection[str], extensions: Collection[str]
) -> list[str]:
    """Return ``files`` sorted, keeping names whose extension is in ``extensions``.

    Each file's suffix is lower-cased before the membership test, so only
    lower-case entries in ``extensions`` can match. An empty ``extensions``
    collection keeps every file.
    """
    keep_all = len(extensions) == 0
    matching = [
        name
        for name in files
        if os.path.splitext(name)[-1].lower() in extensions or keep_all
    ]
    return sorted(matching)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def recursive_extract_models(data: Any, prefix_path: str = "") -> List[str]:
    """Collect model file paths from a nested dict/list structure.

    Walking down, dict keys and string list items are appended to
    ``prefix_path`` with ``"/"``; non-string keys/items leave the prefix
    unchanged. A string leaf whose accumulated path ends with that string
    contributes the path, provided its extension is in
    ``supported_pt_extensions``. Any other leaf contributes nothing.
    """

    def _join(base: str, part: Any) -> str:
        if isinstance(part, str):
            return f"{base}/{part}" if base else part
        return base

    if isinstance(data, dict):
        children = [(value, _join(prefix_path, key)) for key, value in data.items()]
    elif isinstance(data, list):
        children = [(item, _join(prefix_path, item)) for item in data]
    elif isinstance(data, str) and prefix_path.endswith(data):
        return filter_files_extensions([prefix_path], supported_pt_extensions)
    else:
        children = []

    found: List[str] = []
    for child, child_prefix in children:
        found.extend(recursive_extract_models(child, child_prefix))
    return found
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def load_json(file_path: str) -> dict:
    """Read and parse the JSON document at ``file_path``.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for
    interchanged JSON, instead of the locale default. The locale default
    differs on Windows and breaks non-ASCII content.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def init_config():
    """Append every folder's entries from ``config_manager`` to ``folder_names_and_paths``.

    Entries are appended, not replaced, so calling this twice duplicates
    them. When ``BIZYAIR_DEBUG`` is set, the resulting registry is
    pretty-printed.
    """
    model_paths = config_manager.model_path_manager.model_paths
    for folder, paths in model_paths.items():
        folder_names_and_paths[folder].extend(paths)
    if BIZYAIR_DEBUG:
        banner = "=" * 20 + "init_config: " + "=" * 20
        pprint.pprint(banner)
        pprint.pprint(folder_names_and_paths)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# Populate folder_names_and_paths at import time.
init_config()
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
if __name__ == "__main__":
    # Manual smoke test: print the LoRA entries known to the local registry.
    # The unused BIZYAIR_API_KEY lookup and the dead commented-out experiments
    # were removed.
    lora_files = get_filename_list("loras", verbose=True)
    print(lora_files)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from collections.abc import Collection
|
|
3
|
+
from typing import Dict, Union
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from bizyengine.core.common.env_var import BIZYAIR_SERVER_ADDRESS
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def filter_files_extensions(
    files: Collection[str], extensions: Collection[str]
) -> list[str]:
    """Return ``files`` sorted, keeping names whose extension is in ``extensions``.

    Each file's suffix is lower-cased before the membership test, so only
    lower-case entries in ``extensions`` can match. An empty ``extensions``
    collection keeps every file.
    """
    keep_all = len(extensions) == 0
    matching = [
        name
        for name in files
        if os.path.splitext(name)[-1].lower() in extensions or keep_all
    ]
    return sorted(matching)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_yaml_config(file_path):
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load`` and return the result.

    The file is read as UTF-8 explicitly. Otherwise ``open`` uses the locale
    encoding (e.g. cp1252/cp936 on Windows), and decoding fails on non-ASCII
    content.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_service_route(service_config: Dict[str, str]) -> Union[str, None]:
    """Build a full URL from a service config entry.

    Returns ``str(service_address) + route``, where ``service_address``
    defaults to ``BIZYAIR_SERVER_ADDRESS``. Returns ``None`` when the config
    has no ``"route"`` key.
    """
    if "route" not in service_config:
        return None
    base_address = service_config.get("service_address", BIZYAIR_SERVER_ADDRESS)
    return str(base_address) + service_config.get("route")
|
|
File without changes
|
bizyengine/misc/auth.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import bizyengine.core
|
|
6
|
+
import server
|
|
7
|
+
from aiohttp import web
|
|
8
|
+
from bizyengine.core.common import create_api_key_file, load_api_key, validate_api_key
|
|
9
|
+
|
|
10
|
+
# Module-level copy of the active API key; None until one has been loaded.
API_KEY = None
# html_file_path = Path(os.path.dirname(os.path.abspath(__file__))) / "set_api_key.html"
# with open(html_file_path, "r", encoding="utf-8") as htmlfile:
#     set_api_key_html = htmlfile.read()


# At import time, take whatever key load_api_key() reports (the get_api_key
# handler's response text suggests an api_key.ini file; not verified here).
# Note: unlike the route handlers, this does not call bizyengine.core.set_api_key.
has_key, api_key = load_api_key()
if has_key:
    API_KEY = api_key
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# @server.PromptServer.instance.routes.get("/bizyair/set-api-key")
|
|
22
|
+
# async def set_api_key_page(request):
|
|
23
|
+
# return web.Response(text=set_api_key_html, content_type="text/html")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@server.PromptServer.instance.routes.post("/bizyair/set_api_key")
async def set_api_key(request):
    """POST /bizyair/set_api_key: validate, persist and activate a form-posted key.

    The key is read from the form field ``api_key``. On success it is written
    with ``create_api_key_file``, stored in ``API_KEY`` and passed to
    ``bizyengine.core.set_api_key(..., override=True)``.

    Responses: 200 ``"ok"`` on success; 400 when the key is missing or
    rejected by ``validate_api_key``; 500 with the exception text otherwise.

    NOTE(review): if ``validate_api_key`` does blocking network I/O it stalls
    the event loop while running here. Confirm.
    """
    global API_KEY
    try:
        form = await request.post()
        candidate = form.get("api_key")
        if not candidate:
            error_msg = "No API key provided, please refer to cloud.siliconflow.cn to get the key"
            print("set_api_key:", error_msg)
            return web.Response(text=error_msg, status=400)
        if not validate_api_key(candidate):
            error_msg = "Wrong API key provided, please refer to cloud.siliconflow.cn to get the key"
            print("set_api_key:", error_msg)
            return web.Response(text=error_msg, status=400)
        create_api_key_file(candidate)
        API_KEY = candidate
        bizyengine.core.set_api_key(API_KEY, override=True)
        print("Set the key sucessfully.")
        return web.Response(text="ok")
    except Exception as e:
        print(f"set api key error: {str(e)}")
        return web.Response(text=str(e), status=500)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@server.PromptServer.instance.routes.get("/bizyair/get_api_key")
async def get_api_key(request):
    """GET /bizyair/get_api_key: activate a key from storage, else from the request cookie.

    The stored key (``load_api_key``) takes precedence over the ``api_key``
    cookie. Whichever is found is stored in ``API_KEY`` and passed to
    ``bizyengine.core.set_api_key``.

    Responses: 200 naming the source; 404 when neither source has a key; 500
    with the exception text on error.
    """
    global API_KEY
    try:
        has_key, stored_key = load_api_key()
        if has_key:
            API_KEY = stored_key
            bizyengine.core.set_api_key(API_KEY)
            return web.Response(text="Key has been loaded from the api_key.ini file")

        cookie_key = request.cookies.get("api_key")
        if not cookie_key:
            print("No api key found in cookies")
            return web.Response(
                text="No api key found in cookies, please refer to cloud.siliconflow.cn to get the key",
                status=404,
            )
        API_KEY = cookie_key
        bizyengine.core.set_api_key(API_KEY)
        return web.Response(text="Key has been loaded from the cookies")
    except Exception as e:
        return web.Response(text=str(e), status=500)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# This module registers HTTP routes only and defines no nodes. It still exposes
# empty mappings (presumably so it can be loaded like a node module; confirm
# with the package's loader).
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
|