bizyengine-0.4.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bizyengine/__init__.py +35 -0
- bizyengine/bizy_server/__init__.py +7 -0
- bizyengine/bizy_server/api_client.py +763 -0
- bizyengine/bizy_server/errno.py +122 -0
- bizyengine/bizy_server/error_handler.py +3 -0
- bizyengine/bizy_server/execution.py +55 -0
- bizyengine/bizy_server/resp.py +24 -0
- bizyengine/bizy_server/server.py +898 -0
- bizyengine/bizy_server/utils.py +93 -0
- bizyengine/bizyair_extras/__init__.py +24 -0
- bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
- bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
- bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
- bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
- bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
- bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
- bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
- bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
- bizyengine/bizyair_extras/nodes_dataset.py +99 -0
- bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
- bizyengine/bizyair_extras/nodes_flux.py +69 -0
- bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
- bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
- bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
- bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
- bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
- bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
- bizyengine/bizyair_extras/nodes_sd3.py +52 -0
- bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
- bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
- bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
- bizyengine/bizyair_extras/nodes_trellis.py +199 -0
- bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
- bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
- bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
- bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
- bizyengine/core/__init__.py +8 -0
- bizyengine/core/commands/__init__.py +1 -0
- bizyengine/core/commands/base.py +27 -0
- bizyengine/core/commands/invoker.py +4 -0
- bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
- bizyengine/core/commands/processors/prompt_processor.py +123 -0
- bizyengine/core/commands/servers/model_server.py +0 -0
- bizyengine/core/commands/servers/prompt_server.py +234 -0
- bizyengine/core/common/__init__.py +8 -0
- bizyengine/core/common/caching.py +198 -0
- bizyengine/core/common/client.py +262 -0
- bizyengine/core/common/env_var.py +101 -0
- bizyengine/core/common/utils.py +93 -0
- bizyengine/core/configs/conf.py +112 -0
- bizyengine/core/configs/models.json +101 -0
- bizyengine/core/configs/models.yaml +329 -0
- bizyengine/core/data_types.py +20 -0
- bizyengine/core/image_utils.py +288 -0
- bizyengine/core/nodes_base.py +159 -0
- bizyengine/core/nodes_io.py +97 -0
- bizyengine/core/path_utils/__init__.py +9 -0
- bizyengine/core/path_utils/path_manager.py +276 -0
- bizyengine/core/path_utils/utils.py +34 -0
- bizyengine/misc/__init__.py +0 -0
- bizyengine/misc/auth.py +83 -0
- bizyengine/misc/llm.py +431 -0
- bizyengine/misc/mzkolors.py +93 -0
- bizyengine/misc/nodes.py +1208 -0
- bizyengine/misc/nodes_controlnet_aux.py +491 -0
- bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
- bizyengine/misc/route_sam.py +60 -0
- bizyengine/misc/segment_anything.py +276 -0
- bizyengine/misc/supernode.py +182 -0
- bizyengine/misc/utils.py +218 -0
- bizyengine/version.txt +1 -0
- bizyengine-0.4.2.dist-info/METADATA +12 -0
- bizyengine-0.4.2.dist-info/RECORD +76 -0
- bizyengine-0.4.2.dist-info/WHEEL +5 -0
- bizyengine-0.4.2.dist-info/top_level.txt +1 -0

bizyengine/core/commands/base.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import Any, Dict


class Command(ABC):
    @abstractmethod
    def execute(self):
        raise NotImplementedError("Subclasses should implement this!")


class Processor(ABC):
    @abstractmethod
    def process(
        self,
        url: str,
        prompt,
    ) -> Any:
        pass

    @abstractmethod
    def validate_input(self, input: Any) -> bool:
        pass

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        if not self.validate_input(*args, **kwds):
            raise ValueError(f"Invalid input {args=} {kwds=}")
        return self.process(*args, **kwds)
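
Concrete processors are invoked through __call__, which runs validate_input before delegating to process, so a subclass only has to supply those two methods. A minimal sketch of a conforming subclass (the EchoProcessor name and the example URL are invented for illustration and are not part of the package):

from typing import Any

from bizyengine.core.commands.base import Processor


class EchoProcessor(Processor):
    def validate_input(self, url: str, prompt) -> bool:
        # Reject obviously malformed inputs before any work is done.
        return isinstance(url, str) and prompt is not None

    def process(self, url: str, prompt) -> Any:
        # A real processor would forward the prompt to a remote service here.
        return {"url": url, "prompt": prompt}


result = EchoProcessor()("https://example.invalid/route", {"1": {"class_type": "KSampler"}})
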
File without changes

bizyengine/core/commands/processors/prompt_processor.py
@@ -0,0 +1,123 @@
import json
from collections import deque
from typing import Any, Dict, List

from bizyengine.core.commands.base import Processor  # type: ignore
from bizyengine.core.common import client, get_api_key
from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
from bizyengine.core.common.env_var import (
    BIZYAIR_DEBUG,
    BIZYAIR_DEV_REQUEST_URL,
    BIZYAIR_SERVER_ADDRESS,
)
from bizyengine.core.configs.conf import ModelRule, config_manager
from bizyengine.core.path_utils import (
    convert_prompt_label_path_to_real_path,
    guess_url_from_node,
)


def is_link(obj):
    if not isinstance(obj, list):
        return False
    if len(obj) != 2:
        return False
    if not isinstance(obj[0], str):
        return False
    if not isinstance(obj[1], int) and not isinstance(obj[1], float):
        return False
    return True


from dataclasses import dataclass


class SearchServiceRouter(Processor):
    def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]):
        if BIZYAIR_DEV_REQUEST_URL:
            return BIZYAIR_DEV_REQUEST_URL

        # TODO Improve distribution logic
        queue = deque(last_node_ids)
        visited = {key: True for key in last_node_ids}
        results: List[ModelRule] = []
        class_type_table = {
            node_data["class_type"]: True for node_data in prompt.values()
        }

        while queue:
            vertex = queue.popleft()
            if BIZYAIR_DEBUG:
                print(vertex, end="->")

            rules = guess_url_from_node(prompt[vertex], class_type_table)
            if rules:
                results.extend(rules)
            for _, in_data in prompt[vertex].get("inputs", {}).items():
                if is_link(in_data):
                    neighbor = in_data[0]
                    if neighbor not in visited:
                        visited[neighbor] = True
                        queue.append(neighbor)

        base_model, out_route, out_score = None, None, 0
        for rule in results[::-1]:
            # TODO add to config models.yaml
            if rule.mode_type in {"unet", "vae", "checkpoint", "upscale_models"}:
                base_model = rule.base_model
                out_route = rule.route
                out_score = rule.score
                break

        for rule in results:
            if base_model is None:
                if rule.score > out_score:
                    out_route, out_score = rule.route, rule.score
            if rule.base_model == base_model:
                if rule.score > out_score:
                    out_route, out_score = rule.route, rule.score
        assert (
            out_route is not None
        ), "Failed to find out_route, please check your prompt"
        return f"{BIZYAIR_SERVER_ADDRESS}{out_route}"

    def validate_input(
        self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
    ):
        assert len(last_node_ids) == 1
        return True


class PromptProcessor(Processor):
    def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
        exec_info = {
            "model_version_ids": [],
            "api_key": get_api_key(),
        }

        model_version_id_prefix = config_manager.get_model_version_id_prefix()
        for node_id, node_data in prompt.items():
            for k, v in node_data.get("inputs", {}).items():
                if isinstance(v, str) and v.startswith(model_version_id_prefix):
                    model_version_id = int(v[len(model_version_id_prefix) :])
                    exec_info["model_version_ids"].append(model_version_id)
        return exec_info

    def process(
        self, url: str, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
    ):
        return client.send_request(
            url=url,
            data=json.dumps(
                {
                    "prompt": prompt,
                    "last_node_id": last_node_ids[0],
                    "exec_info": self._exec_info(prompt),
                }
            ).encode("utf-8"),
        )

    def validate_input(
        self, url: str, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
    ):
        return True
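
SearchServiceRouter walks the workflow graph backwards from last_node_ids: any input value of the form [node_id, slot_index] is treated as a link by is_link, so the BFS can follow it to the upstream node and collect routing rules along the way. A small illustration of the expected prompt shape (the node ids and class types below are invented for the example):

from bizyengine.core.commands.processors.prompt_processor import is_link

prompt = {
    "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "model.safetensors"}},
    "9": {"class_type": "KSampler", "inputs": {"model": ["4", 0], "seed": 3}},
}

assert is_link(prompt["9"]["inputs"]["model"])  # ["4", 0] points at node "4", output slot 0
assert not is_link(prompt["9"]["inputs"]["seed"])  # a plain scalar is not a link
assert not is_link(prompt["4"]["inputs"]["ckpt_name"])  # neither is a bare string value
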
File without changes

bizyengine/core/commands/servers/prompt_server.py
@@ -0,0 +1,234 @@
import hashlib
import json
import pprint
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, List

import comfy
from bizyengine.core.commands.base import Command, Processor  # type: ignore
from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig
from bizyengine.core.common.client import send_request
from bizyengine.core.common.env_var import (
    BIZYAIR_DEBUG,
    BIZYAIR_DEV_GET_TASK_RESULT_SERVER,
    BIZYAIR_SERVER_ADDRESS,
)
from bizyengine.core.common.utils import truncate_long_strings
from bizyengine.core.configs.conf import config_manager
from bizyengine.core.image_utils import decode_data, encode_data


def get_task_result(task_id: str, offset: int = 0) -> dict:
    """
    Get the result of a task.
    """
    import requests

    task_api = config_manager.get_task_api()
    if BIZYAIR_DEV_GET_TASK_RESULT_SERVER:
        url = f"{BIZYAIR_DEV_GET_TASK_RESULT_SERVER}{task_api.task_result_endpoint}/{task_id}"
    else:
        url = f"{BIZYAIR_SERVER_ADDRESS}{task_api.task_result_endpoint}/{task_id}"

    if BIZYAIR_DEBUG:
        print(f"Debug: get task result url: {url}")
    response_json = send_request(
        method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8")
    )
    out = response_json
    events = out.get("data", {}).get("events", [])
    new_events = []
    for event in events:
        if (
            "data" in event
            and isinstance(event["data"], str)
            and event["data"].startswith("https://")
        ):
            # event["data"] = requests.get(event["data"]).json()
            event["data"] = send_request(method="GET", url=event["data"])
        new_events.append(event)
    out["data"]["events"] = new_events
    return out


@dataclass
class BizyAirTask:
    TASK_DATA_STATUS = ["PENDING", "PROCESSING", "COMPLETED"]
    task_id: str
    data_pool: list[dict] = field(default_factory=list)
    data_status: str = None

    @staticmethod
    def check_inputs(inputs: dict) -> bool:
        return (
            inputs.get("code") == 20000
            and inputs.get("status", False)
            and "task_id" in inputs.get("data", {})
        )

    @classmethod
    def from_data(cls, inputs: dict, check_inputs: bool = True) -> "BizyAirTask":
        if check_inputs and not cls.check_inputs(inputs):
            raise ValueError(f"Invalid inputs: {inputs}")
        data = inputs.get("data", {})
        task_id = data.get("task_id", "")
        return cls(task_id=task_id, data_pool=[], data_status="started")

    def is_finished(self) -> bool:
        if not self.data_pool:
            return False
        if self.data_pool[-1].get("data_status") == self.TASK_DATA_STATUS[-1]:
            return True
        return False

    def send_request(self, offset: int = 0) -> dict:
        if offset >= len(self.data_pool):
            return get_task_result(self.task_id, offset)
        else:
            return self.data_pool[offset]

    def get_data(self, offset: int = 0) -> dict:
        if offset >= len(self.data_pool):
            return {}
        return self.data_pool[offset]

    @staticmethod
    def _fetch_remote_data(url: str) -> dict:
        import requests

        return requests.get(url).json()

    def get_last_data(self) -> dict:
        return self.get_data(len(self.data_pool) - 1)

    def do_task_until_completed(
        self, *, timeout: int = 600, poll_interval: float = 1
    ) -> list[dict]:
        offset = 0
        start_time = time.time()
        pbar = None
        while not self.is_finished():
            try:
                data = self.send_request(offset)
                data_lst = self._extract_data_list(data)
                self.data_pool.extend(data_lst)
                offset += len(data_lst)
                for data in data_lst:
                    message = data.get("data", {}).get("message", {})
                    if (
                        isinstance(message, dict)
                        and message.get("event", None) == "progress"
                    ):
                        value = message["data"]["value"]
                        total = message["data"]["max"]
                        if pbar is None:
                            pbar = comfy.utils.ProgressBar(total)
                        pbar.update_absolute(value + 1, total, None)
            except Exception as e:
                print(f"Exception: {e}")

            if time.time() - start_time > timeout:
                raise TimeoutError(f"Timeout waiting for task {self.task_id} to finish")

            time.sleep(poll_interval)

        return self.data_pool

    def _extract_data_list(self, data):
        data_lst = data.get("data", {}).get("events", [])
        if not data_lst:
            raise ValueError(f"No data found in task {self.task_id}")
        return data_lst


class PromptServer(Command):
    cache_manager: BizyAirTaskCache = BizyAirTaskCache(
        config=CacheConfig.from_config(config_manager.get_cache_config())
    )

    def __init__(self, router: Processor, processor: Processor):
        self.router = router
        self.processor = processor

    def get_task_id(self, result: Dict[str, Any]) -> str:
        return result.get("data", {}).get("task_id", "")

    def is_async_task(self, result: Dict[str, Any]) -> str:
        """Determine if the result indicates an asynchronous task."""
        return (
            result.get("code") == 20000
            and result.get("status", False)
            and "task_id" in result.get("data", {})
        )

    def _get_result(self, result: Dict[str, Any], *, cache_key: str = None):
        try:
            response_data = result["data"]
            if BizyAirTask.check_inputs(result):
                self.cache_manager.set(cache_key, result)
                bz_task = BizyAirTask.from_data(result, check_inputs=False)
                bz_task.do_task_until_completed(timeout=10 * 60)  # 10 minutes
                last_data = bz_task.get_last_data()
                response_data = last_data.get("data")
            out = response_data["payload"]
            assert out is not None, "Output payload should not be None"
            self.cache_manager.set(cache_key, out, overwrite=True)
            return out
        except Exception as e:
            self.cache_manager.delete(cache_key)
            raise RuntimeError(f"Exception: {e}, response_data: {response_data}") from e

    def execute(
        self,
        prompt: Dict[str, Dict[str, Any]],
        last_node_ids: List[str],
        *args,
        **kwargs,
    ):

        prompt = encode_data(prompt)

        if BIZYAIR_DEBUG:
            debug_info = {
                "prompt": truncate_long_strings(prompt, 50),
                "last_node_ids": last_node_ids,
            }
            pprint.pprint(debug_info, indent=4)

        url = self.router(prompt=prompt, last_node_ids=last_node_ids)

        if BIZYAIR_DEBUG:
            print(f"Generated URL: {url}")

        start_time = time.time()
        sh256 = hashlib.sha256(
            json.dumps({"url": url, "prompt": prompt}).encode("utf-8")
        ).hexdigest()
        end_time = time.time()
        if BIZYAIR_DEBUG:
            print(
                f"Time taken to generate sh256-{sh256}: {end_time - start_time} seconds"
            )

        cached_output = self.cache_manager.get(sh256)
        if cached_output:
            if BIZYAIR_DEBUG:
                print(f"Cache hit for sh256-{sh256}")
            out = cached_output
        else:
            result = self.processor(url, prompt=prompt, last_node_ids=last_node_ids)
            out = self._get_result(result, cache_key=sh256)

        if BIZYAIR_DEBUG:
            pprint.pprint({"out": truncate_long_strings(out, 50)}, indent=4)

        try:
            real_out = decode_data(out)
            return [x[0] for x in real_out]
        except Exception as e:
            print("Exception occurred while decoding data")
            self.cache_manager.delete(sh256)
            traceback.print_exc()
            raise RuntimeError(f"Exception: {e=}") from e
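
PromptServer ties the two processors together: the router resolves the upstream URL, the processor posts the encoded prompt, and the response is cached under a sha256 of the url+prompt pair before being decoded back into node outputs. A hedged sketch of how the pieces compose (it assumes a working ComfyUI environment, a configured API key, and that "9" is the terminal node id; the node code in the package may wire this differently):

from bizyengine.core.commands.processors.prompt_processor import (
    PromptProcessor,
    SearchServiceRouter,
)
from bizyengine.core.commands.servers.prompt_server import PromptServer

server = PromptServer(router=SearchServiceRouter(), processor=PromptProcessor())
# prompt is a ComfyUI-style prompt dict like the one sketched above.
outputs = server.execute(prompt, last_node_ids=["9"])
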
bizyengine/core/common/caching.py
@@ -0,0 +1,198 @@
import glob
import json
import os
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class CacheConfig:
    max_size: int = 100
    expiration: int = 300  # 300 seconds
    cache_dir: str = "./cache"
    file_prefix: str = "bizyair_cache_"
    file_suffix: str = ".json"
    use_cache: bool = True

    @classmethod
    def from_config(cls, config: Dict[str, Any]):
        return cls(
            max_size=config.get("max_size", 100),
            expiration=config.get("expiration", 300),
            cache_dir=config.get("cache_dir", "./cache"),
            file_prefix=config.get("file_prefix", "bizyair_cache_"),
            file_suffix=config.get("file_suffix", ".json"),
            use_cache=config.get("use_cache", True),
        )


class CacheManager(ABC):
    @abstractmethod
    def get(self, key):
        pass

    @abstractmethod
    def set(self, key, value):
        pass

    @abstractmethod
    def clear(self):
        pass

    @abstractmethod
    def disable(self):
        pass


class BizyAirTaskCache(CacheManager):
    def __init__(self, config: CacheConfig):
        self.config = config
        self.cache = OrderedDict()
        self.cache_dir = config.cache_dir
        self.ensure_directory_exists()
        self.cache = self.load_cache() if config.use_cache else self.cache

    def ensure_directory_exists(self):
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

    def load_cache(self):
        cache_v_files = glob.glob(
            os.path.join(
                self.cache_dir, f"{self.config.file_prefix}*{self.config.file_suffix}"
            )
        )
        output = OrderedDict()
        cache_datas = []
        for cache_file in cache_v_files:
            try:
                file_name = os.path.basename(cache_file)[
                    len(self.config.file_prefix) : -len(self.config.file_suffix)
                ]
                cache_key = file_name.split("-")[0]
                cache_timestamp = file_name.split("-")[1]
                if int(time.time()) - int(cache_timestamp) > self.config.expiration:
                    self.delete_file(cache_file)
                    continue
                cache_datas.append(
                    {
                        "key": cache_key,
                        "timestamp": int(cache_timestamp),
                        "file_path": cache_file,
                    }
                )
            except Exception as e:
                print(
                    f"Warning: Error loading cache file {cache_file}: because {e}, will delete it"
                )
        cache_datas = sorted(cache_datas, key=lambda x: x["timestamp"])
        for cache_data in cache_datas:
            output[cache_data["key"]] = (
                cache_data["file_path"],
                cache_data["timestamp"],
            )
        return output

    def delete(self, key):
        if key in self.cache:
            self.delete_file(self.cache[key][0])
            del self.cache[key]

    def get(self, key):
        if key not in self.cache:
            return None

        file_path, timestamp = self.cache[key]
        if time.time() - timestamp >= self.config.expiration:
            self._remove_expired_entry(file_path, key)
            return None

        cache_data = self._read_file(file_path)
        if cache_data["cache_key"] == key:
            return cache_data["result"]
        else:
            self._remove_expired_entry(file_path, key)
            return None

    def _read_file(self, file_path):
        try:
            with open(file_path, "r") as f:
                cache_data = json.load(f)
                return cache_data
        except Exception as e:
            print(f"Error reading file {file_path}: {e}")
            return None

    def _remove_expired_entry(self, file_path, key):
        self.delete_file(file_path)
        del self.cache[key]

    def set(self, key, value, *, overwrite=False):
        if not overwrite and key in self.cache:
            raise ValueError(
                f"Key '{key}' already exists in cache. Use overwrite=True to replace it."
            )
        assert isinstance(key, str), "Key must be a string"

        if len(self.cache) >= self.config.max_size:
            self._evict_oldest()

        timestamp = int(time.time())
        file_path = os.path.join(
            self.cache_dir, f"{self.config.file_prefix}{key}-{timestamp}.json"
        )
        self.write_file(key, value, file_path, timestamp)

    def _evict_oldest(self):
        oldest_key, (oldest_file_path, _) = self.cache.popitem(last=False)
        self.delete_file(oldest_file_path)

    def write_file(self, key: str, value: Any, file_path: str, timestamp: int):
        try:
            with open(file_path, "w") as f:
                json.dump(
                    {"result": value, "cache_key": key, "timestamp": timestamp}, f
                )
            self.cache[key] = (file_path, timestamp)
        except Exception as e:
            print(f"Error writing file for key '{key}': {e}")

    def delete_file(self, file_path):
        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except Exception as e:
                print(f"Error deleting file '{file_path}': {e}")

    def clear(self):
        for file_path, _ in self.cache.values():
            self.delete_file(file_path)
        self.cache.clear()

    def disable(self):
        self.clear()


# Example usage
if __name__ == "__main__":
    cache_config = CacheConfig(max_size=12, expiration=10, cache_dir="./cache")
    cache = BizyAirTaskCache(cache_config)

    # Set some cache values
    cache.set("key1", "This is the value for key1")
    cache.set("key2", "This is the value for key2")

    # Retrieve values from cache
    print(cache.get("key1"))  # Output: This is the value for key1
    print(cache.get("key2"))  # Output: This is the value for key2

    # Wait for expiration
    time.sleep(9)
    print(cache.get("key1"))  # Output: None (expired)

    # Clear cache
    cache.clear()
    print(cache.get("key2"))  # Output: None (cache cleared)
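
Each cache entry is persisted as its own JSON file named {file_prefix}{key}-{timestamp}.json, and load_cache later recovers the key and timestamp by splitting that filename on "-"; this works because PromptServer uses sha256 hex digests (which contain no "-") as cache keys. A rough sketch of what ends up on disk (the key, timestamp, and payload values are illustrative):

from bizyengine.core.common.caching import BizyAirTaskCache, CacheConfig

cache = BizyAirTaskCache(CacheConfig(cache_dir="./cache", expiration=300))
cache.set("3fa2c9", {"payload": "..."})
# Creates something like ./cache/bizyair_cache_3fa2c9-1714000000.json containing:
#   {"result": {"payload": "..."}, "cache_key": "3fa2c9", "timestamp": 1714000000}
cache.get("3fa2c9")  # -> {"payload": "..."} until the entry expires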