bizydraft 0.2.49__py3-none-any.whl → 0.2.87__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bizydraft might be problematic. Click here for more details.
- bizydraft/env.py +3 -0
- bizydraft/hijack_nodes.py +60 -42
- bizydraft/hijack_routes.py +36 -3
- bizydraft/oss_utils.py +231 -2
- bizydraft/patch_handlers.py +197 -8
- bizydraft/static/js/aiAppHandler.js +460 -425
- bizydraft/static/js/clipspaceToOss.js +386 -0
- bizydraft/static/js/disableComfyWebSocket.js +64 -0
- bizydraft/static/js/freezeModeHandler.js +425 -404
- bizydraft/static/js/handleStyle.js +128 -36
- bizydraft/static/js/hookLoad/configLoader.js +74 -0
- bizydraft/static/js/hookLoad/media.js +684 -0
- bizydraft/static/js/hookLoad/model.js +322 -0
- bizydraft/static/js/hookLoadMedia.js +196 -0
- bizydraft/static/js/hookLoadModel.js +207 -256
- bizydraft/static/js/main.js +2 -0
- bizydraft/static/js/nodeFocusHandler.js +118 -106
- bizydraft/static/js/nodeParamsFilter.js +91 -89
- bizydraft/static/js/postEvent.js +1207 -967
- bizydraft/static/js/socket.js +55 -50
- bizydraft/static/js/tool.js +71 -63
- bizydraft/static/js/uploadFile.js +49 -41
- bizydraft/static/js/workflow_io.js +193 -0
- {bizydraft-0.2.49.dist-info → bizydraft-0.2.87.dist-info}/METADATA +1 -1
- bizydraft-0.2.87.dist-info/RECORD +34 -0
- bizydraft/static/js/hookLoadImage.js +0 -177
- bizydraft-0.2.49.dist-info/RECORD +0 -28
- {bizydraft-0.2.49.dist-info → bizydraft-0.2.87.dist-info}/WHEEL +0 -0
- {bizydraft-0.2.49.dist-info → bizydraft-0.2.87.dist-info}/top_level.txt +0 -0
bizydraft/env.py
CHANGED
bizydraft/hijack_nodes.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import re
|
|
2
|
+
from datetime import datetime
|
|
2
3
|
|
|
3
4
|
from loguru import logger
|
|
4
5
|
|
|
6
|
+
from bizydraft.env import COMFYAGENT_NODE_CONFIG
|
|
7
|
+
|
|
5
8
|
try:
|
|
6
9
|
from comfy_extras.nodes_video import LoadVideo
|
|
7
10
|
from nodes import NODE_CLASS_MAPPINGS, LoadImage
|
|
@@ -17,36 +20,70 @@ class BizyDraftLoadVideo(LoadVideo):
|
|
|
17
20
|
super().__init__(*args, **kwargs)
|
|
18
21
|
|
|
19
22
|
@classmethod
|
|
20
|
-
def INPUT_TYPES(cls):
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
}
|
|
23
|
+
def INPUT_TYPES(cls, **kwargs):
|
|
24
|
+
# 调用父类方法,保持兼容性
|
|
25
|
+
return super().INPUT_TYPES(**kwargs)
|
|
24
26
|
|
|
25
27
|
@classmethod
|
|
26
28
|
def VALIDATE_INPUTS(s, *args, **kwargs):
|
|
27
29
|
return True
|
|
28
30
|
|
|
31
|
+
@classmethod
|
|
32
|
+
def validate_inputs(s, *args, **kwargs):
|
|
33
|
+
# V3 API 使用小写的 validate_inputs
|
|
34
|
+
return True
|
|
35
|
+
|
|
29
36
|
|
|
30
37
|
class BizyDraftLoadImage(LoadImage):
|
|
31
38
|
def __init__(self, *args, **kwargs):
|
|
32
39
|
super().__init__(*args, **kwargs)
|
|
33
40
|
|
|
34
41
|
@classmethod
|
|
35
|
-
def INPUT_TYPES(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
}
|
|
42
|
+
def INPUT_TYPES(cls, **kwargs):
|
|
43
|
+
# 调用父类方法,保持兼容性
|
|
44
|
+
return super().INPUT_TYPES(**kwargs)
|
|
39
45
|
|
|
40
46
|
@classmethod
|
|
41
47
|
def VALIDATE_INPUTS(s, *args, **kwargs):
|
|
42
48
|
return True
|
|
43
49
|
|
|
50
|
+
@classmethod
|
|
51
|
+
def validate_inputs(s, *args, **kwargs):
|
|
52
|
+
# V3 API 使用小写的 validate_inputs
|
|
53
|
+
return True
|
|
54
|
+
|
|
44
55
|
|
|
45
56
|
CLASS_PATCHES = {
|
|
46
|
-
#
|
|
47
|
-
#
|
|
57
|
+
# "LoadImage": BizyDraftLoadImage,
|
|
58
|
+
# "LoadVideo": BizyDraftLoadVideo,
|
|
48
59
|
}
|
|
49
60
|
|
|
61
|
+
|
|
62
|
+
def get_data_load_classes_from_url(config_url):
|
|
63
|
+
import requests
|
|
64
|
+
|
|
65
|
+
# 获取当前时间,精确到分钟
|
|
66
|
+
current_time = datetime.now().strftime("%Y%m%d%H%M")
|
|
67
|
+
|
|
68
|
+
try:
|
|
69
|
+
config_url = config_url + "?t=" + current_time
|
|
70
|
+
response = requests.get(config_url)
|
|
71
|
+
response.raise_for_status()
|
|
72
|
+
data = response.json()
|
|
73
|
+
keys_list = []
|
|
74
|
+
if "weight_load_nodes" in data:
|
|
75
|
+
keys_list.extend(list(data["weight_load_nodes"].keys()))
|
|
76
|
+
if "media_load_nodes" in data:
|
|
77
|
+
keys_list.extend(list(data["media_load_nodes"].keys()))
|
|
78
|
+
|
|
79
|
+
return keys_list
|
|
80
|
+
except Exception as e:
|
|
81
|
+
logger.error(
|
|
82
|
+
f"Failed to fetch or comfyagent node config from {config_url}: {e}"
|
|
83
|
+
)
|
|
84
|
+
return []
|
|
85
|
+
|
|
86
|
+
|
|
50
87
|
DATA_LOAD_CLASSES = [
|
|
51
88
|
"LoadImage",
|
|
52
89
|
"LoadVideo",
|
|
@@ -55,39 +92,16 @@ DATA_LOAD_CLASSES = [
|
|
|
55
92
|
"Load3D",
|
|
56
93
|
"VHS_LoadAudioUpload",
|
|
57
94
|
"VHS_LoadVideo",
|
|
58
|
-
"LayerMask: YoloV8Detect",
|
|
59
|
-
"Lora Loader Stack (rgthree)",
|
|
60
|
-
"easy loraNames",
|
|
61
|
-
"easy loraStack",
|
|
62
|
-
"Load Lora",
|
|
63
|
-
"Intrinsic_lora_sampling",
|
|
64
|
-
"ADE_LoadAnimateDiffModel",
|
|
65
|
-
"ADE_AnimateDiffLoRALoader",
|
|
66
|
-
"easy ultralyticsDetectorPipe",
|
|
67
|
-
"UltralyticsDetectorProvider",
|
|
68
|
-
"ONNXDetectorProvider",
|
|
69
|
-
"SAMLoader",
|
|
70
|
-
"easy samLoaderPipe",
|
|
71
|
-
"WanVideoModelLoader",
|
|
72
|
-
"LoadWanVideoT5TextEncoder",
|
|
73
|
-
"WanVideoLoraSelect",
|
|
74
|
-
"LoadFramePackModel",
|
|
75
|
-
"ReActorLoadFaceModel",
|
|
76
|
-
"ReActorMaskHelper",
|
|
77
|
-
"LoadAndApplyICLightUnet",
|
|
78
|
-
"SeedVR2",
|
|
79
|
-
"LoadLaMaModel",
|
|
80
|
-
"Upscale Model Loader",
|
|
81
|
-
"CR Upscale Image",
|
|
82
|
-
"SUPIR_Upscale",
|
|
83
|
-
"CR Multi Upscale Stack",
|
|
84
|
-
"QuadrupleCLIPLoader",
|
|
85
|
-
"LoadWanVideoClipTextEncoder",
|
|
86
|
-
"SUPIR_model_loader_v2_clip",
|
|
87
|
-
"LayerMask: LoadSAM2Model",
|
|
88
|
-
"LayerMask: SegmentAnythingUltra V2",
|
|
89
95
|
]
|
|
90
96
|
|
|
97
|
+
if COMFYAGENT_NODE_CONFIG.startswith("http"):
|
|
98
|
+
fetched_classes = get_data_load_classes_from_url(COMFYAGENT_NODE_CONFIG)
|
|
99
|
+
if fetched_classes:
|
|
100
|
+
DATA_LOAD_CLASSES.extend(fetched_classes)
|
|
101
|
+
logger.info(f"Fetched additional data load classes: {fetched_classes}")
|
|
102
|
+
else:
|
|
103
|
+
logger.warning("No additional data load classes fetched from the URL.")
|
|
104
|
+
|
|
91
105
|
|
|
92
106
|
def hijack_nodes():
|
|
93
107
|
def _hijack_node(node_name, new_class):
|
|
@@ -103,6 +117,7 @@ def hijack_nodes():
|
|
|
103
117
|
|
|
104
118
|
# 通用情况,正则匹配后,打通用patch、替换
|
|
105
119
|
for node_name, base_class in NODE_CLASS_MAPPINGS.items():
|
|
120
|
+
|
|
106
121
|
regex = r"^(?!BizyAir_)\w+.*Loader.*"
|
|
107
122
|
match = re.match(regex, node_name, re.IGNORECASE)
|
|
108
123
|
if (match and (node_name not in CLASS_PATCHES)) or (
|
|
@@ -115,7 +130,10 @@ def hijack_nodes():
|
|
|
115
130
|
|
|
116
131
|
def create_patched_class(base_class, validate_inputs_func=None):
|
|
117
132
|
class PatchedClass(base_class):
|
|
118
|
-
|
|
133
|
+
@classmethod
|
|
134
|
+
def validate_inputs(cls, *args, **kwargs):
|
|
135
|
+
# V3 API
|
|
136
|
+
return True
|
|
119
137
|
|
|
120
138
|
if validate_inputs_func:
|
|
121
139
|
PatchedClass.VALIDATE_INPUTS = classmethod(validate_inputs_func)
|
bizydraft/hijack_routes.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
1
3
|
from aiohttp import web
|
|
2
4
|
from loguru import logger
|
|
3
5
|
|
|
4
|
-
from bizydraft.oss_utils import upload_image
|
|
5
|
-
from bizydraft.patch_handlers import post_prompt, view_image
|
|
6
|
+
from bizydraft.oss_utils import upload_image, upload_mask
|
|
7
|
+
from bizydraft.patch_handlers import post_prompt, view_image, view_video
|
|
6
8
|
|
|
7
9
|
try:
|
|
8
10
|
from server import PromptServer
|
|
@@ -23,10 +25,14 @@ def hijack_routes_pre_add_routes():
|
|
|
23
25
|
("/view", "GET"): view_image,
|
|
24
26
|
("/prompt", "POST"): post_prompt,
|
|
25
27
|
("/upload/image", "POST"): upload_image,
|
|
28
|
+
("/upload/mask", "POST"): upload_mask,
|
|
26
29
|
# /api alias
|
|
27
30
|
("/api/view", "GET"): view_image,
|
|
28
31
|
("/api/prompt", "POST"): post_prompt,
|
|
29
32
|
("/api/upload/image", "POST"): upload_image,
|
|
33
|
+
("/api/upload/mask", "POST"): upload_mask,
|
|
34
|
+
# VHS plugin support
|
|
35
|
+
("/api/vhs/viewvideo", "GET"): view_video,
|
|
30
36
|
}
|
|
31
37
|
|
|
32
38
|
async def middleware_handler(request):
|
|
@@ -47,11 +53,38 @@ def hijack_routes_pre_add_routes():
|
|
|
47
53
|
|
|
48
54
|
return middleware_handler
|
|
49
55
|
|
|
56
|
+
# 覆盖 /settings 响应,修改设置 NodeIdBadgeMode 为 ShowAll显示id,将 LinkRenderMode 为(Spline=2),显示连线
|
|
57
|
+
async def settings_response_middleware(app, handler):
|
|
58
|
+
async def middleware_handler(request):
|
|
59
|
+
if request.method == "GET":
|
|
60
|
+
p = request.path
|
|
61
|
+
if p == "/settings" or p == "/api/settings":
|
|
62
|
+
resp = await handler(request)
|
|
63
|
+
if getattr(resp, "content_type", None) == "application/json":
|
|
64
|
+
body_bytes = getattr(resp, "body", b"")
|
|
65
|
+
charset = getattr(resp, "charset", None) or "utf-8"
|
|
66
|
+
payload = json.loads(body_bytes.decode(charset))
|
|
67
|
+
if payload.get("Comfy.NodeBadge.NodeIdBadgeMode") != "ShowAll":
|
|
68
|
+
payload["Comfy.NodeBadge.NodeIdBadgeMode"] = "ShowAll"
|
|
69
|
+
if payload.get("Comfy.LinkRenderMode") != 2:
|
|
70
|
+
payload["Comfy.LinkRenderMode"] = 2
|
|
71
|
+
if payload.get("Comfy.VueNodes.Enabled") != False:
|
|
72
|
+
payload["Comfy.VueNodes.Enabled"] = False
|
|
73
|
+
return web.json_response(
|
|
74
|
+
payload, status=getattr(resp, "status", 200)
|
|
75
|
+
)
|
|
76
|
+
return resp
|
|
77
|
+
return await handler(request)
|
|
78
|
+
|
|
79
|
+
return middleware_handler
|
|
80
|
+
|
|
50
81
|
async def access_control_middleware(app, handler):
|
|
51
82
|
base_white_list = [
|
|
52
83
|
"/prompt",
|
|
53
84
|
"/view",
|
|
54
85
|
"/upload/image",
|
|
86
|
+
"/upload/mask",
|
|
87
|
+
"/vhs/viewvideo",
|
|
55
88
|
"/",
|
|
56
89
|
"/ws",
|
|
57
90
|
"/extensions",
|
|
@@ -92,7 +125,7 @@ def hijack_routes_pre_add_routes():
|
|
|
92
125
|
|
|
93
126
|
return middleware_handler
|
|
94
127
|
|
|
95
|
-
app.middlewares.extend([custom_business_middleware])
|
|
128
|
+
app.middlewares.extend([custom_business_middleware, settings_response_middleware])
|
|
96
129
|
|
|
97
130
|
logger.info("Optimized middleware setup complete.")
|
|
98
131
|
|
bizydraft/oss_utils.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import json
|
|
3
3
|
import os
|
|
4
|
+
import re
|
|
5
|
+
import uuid
|
|
4
6
|
from http.cookies import SimpleCookie
|
|
7
|
+
from pathlib import Path
|
|
5
8
|
from time import time
|
|
6
9
|
from typing import Any, Dict
|
|
7
10
|
|
|
@@ -16,6 +19,8 @@ from werkzeug.utils import secure_filename
|
|
|
16
19
|
|
|
17
20
|
from bizydraft.env import BIZYAIR_API_KEY, BIZYDRAFT_SERVER
|
|
18
21
|
|
|
22
|
+
CLIPSPACE_TO_OSS_MAPPING = {}
|
|
23
|
+
|
|
19
24
|
private_key_pem = os.getenv(
|
|
20
25
|
"RSA_PRIVATE_KEY",
|
|
21
26
|
"""-----BEGIN RSA PRIVATE KEY-----
|
|
@@ -59,6 +64,9 @@ def decrypt(encrypted_message):
|
|
|
59
64
|
if not encrypted_message or not isinstance(encrypted_message, str):
|
|
60
65
|
raise ValueError("无效的加密消息")
|
|
61
66
|
|
|
67
|
+
if "v4.public" in encrypted_message:
|
|
68
|
+
return encrypted_message
|
|
69
|
+
|
|
62
70
|
private_key = serialization.load_pem_private_key(
|
|
63
71
|
private_key_pem.encode(), password=None, backend=default_backend()
|
|
64
72
|
)
|
|
@@ -90,7 +98,13 @@ async def get_upload_token(
|
|
|
90
98
|
filename: str,
|
|
91
99
|
api_key: str,
|
|
92
100
|
) -> Dict[str, Any]:
|
|
93
|
-
|
|
101
|
+
from urllib.parse import quote
|
|
102
|
+
|
|
103
|
+
# 对文件名进行URL编码,避免特殊字符导致问题
|
|
104
|
+
encoded_filename = quote(filename, safe="")
|
|
105
|
+
url = (
|
|
106
|
+
f"{BIZYDRAFT_SERVER}/upload/token?file_name={encoded_filename}&file_type=inputs"
|
|
107
|
+
)
|
|
94
108
|
|
|
95
109
|
headers = {
|
|
96
110
|
"Content-Type": "application/json",
|
|
@@ -173,10 +187,15 @@ async def upload_to_oss(post, api_key: str):
|
|
|
173
187
|
if not (image and image.file):
|
|
174
188
|
return web.Response(status=400)
|
|
175
189
|
|
|
190
|
+
original_frontend_filename = image.filename # 保存前端发送的原始文件名
|
|
176
191
|
filename = image.filename
|
|
177
192
|
if not filename:
|
|
178
193
|
return web.Response(status=400)
|
|
179
194
|
|
|
195
|
+
should_clean, filename = clean_filename(filename)
|
|
196
|
+
if should_clean:
|
|
197
|
+
filename = f"{uuid.uuid4()}.{filename}"
|
|
198
|
+
|
|
180
199
|
oss_token = await get_upload_token(filename, api_key)
|
|
181
200
|
result = await upload_filefield_to_oss(image, oss_token)
|
|
182
201
|
if result["status"] != 200:
|
|
@@ -189,8 +208,23 @@ async def upload_to_oss(post, api_key: str):
|
|
|
189
208
|
except Exception as e:
|
|
190
209
|
logger.error(f"Commit file failed: {e}")
|
|
191
210
|
return web.Response(status=500, text=str(e))
|
|
211
|
+
|
|
212
|
+
# 将 OSS URL 拆分成 filename 和 subfolder,以便前端正确构建 /api/view 请求
|
|
213
|
+
# 例如: https://bizyair-prod.oss-cn-shanghai.aliyuncs.com/inputs/20250930/file.png
|
|
214
|
+
oss_url = result["url"]
|
|
215
|
+
oss_filename = oss_url.split("/")[-1] # 获取最后一部分作为文件名
|
|
216
|
+
oss_subfolder = "/".join(
|
|
217
|
+
oss_url.split("/")[:-1]
|
|
218
|
+
) # 获取除文件名外的部分作为 subfolder
|
|
219
|
+
|
|
220
|
+
if original_frontend_filename:
|
|
221
|
+
CLIPSPACE_TO_OSS_MAPPING[original_frontend_filename] = oss_url
|
|
222
|
+
logger.info(
|
|
223
|
+
f"[OSS_MAPPING] Cached mapping: {original_frontend_filename} -> {oss_url}"
|
|
224
|
+
)
|
|
225
|
+
|
|
192
226
|
return web.json_response(
|
|
193
|
-
{"name":
|
|
227
|
+
{"name": oss_filename, "subfolder": oss_subfolder, "type": image_upload_type}
|
|
194
228
|
)
|
|
195
229
|
|
|
196
230
|
|
|
@@ -225,3 +259,198 @@ async def upload_image(request):
|
|
|
225
259
|
return web.Response(status=403, text="No validated key found")
|
|
226
260
|
post = await request.post()
|
|
227
261
|
return await upload_to_oss(post, api_key)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
async def upload_mask(request):
|
|
265
|
+
"""
|
|
266
|
+
处理 mask editor 上传,将带 alpha 通道的图片上传到 OSS
|
|
267
|
+
"""
|
|
268
|
+
import io
|
|
269
|
+
import json
|
|
270
|
+
import tempfile
|
|
271
|
+
|
|
272
|
+
from PIL import Image
|
|
273
|
+
from PIL.PngImagePlugin import PngInfo
|
|
274
|
+
|
|
275
|
+
api_key = get_api_key(request)
|
|
276
|
+
if not api_key:
|
|
277
|
+
logger.error("[UPLOAD_MASK] No API key found")
|
|
278
|
+
return web.Response(status=403, text="No validated key found")
|
|
279
|
+
|
|
280
|
+
post = await request.post()
|
|
281
|
+
|
|
282
|
+
# 获取上传的 mask 图片
|
|
283
|
+
mask_image = post.get("image")
|
|
284
|
+
if not (mask_image and mask_image.file):
|
|
285
|
+
logger.error("[UPLOAD_MASK] No image provided in request")
|
|
286
|
+
return web.Response(status=400, text="No image provided")
|
|
287
|
+
|
|
288
|
+
# 保存前端发送的原始文件名,用于后续缓存映射
|
|
289
|
+
original_frontend_filename = mask_image.filename
|
|
290
|
+
|
|
291
|
+
# 获取原始图片引用
|
|
292
|
+
original_ref_str = post.get("original_ref")
|
|
293
|
+
|
|
294
|
+
if not original_ref_str:
|
|
295
|
+
# 如果没有 original_ref,直接上传 mask
|
|
296
|
+
return await upload_to_oss(post, api_key)
|
|
297
|
+
|
|
298
|
+
try:
|
|
299
|
+
from urllib.parse import unquote
|
|
300
|
+
|
|
301
|
+
original_ref = json.loads(original_ref_str)
|
|
302
|
+
original_filename = original_ref.get("filename")
|
|
303
|
+
original_subfolder = original_ref.get("subfolder", "")
|
|
304
|
+
|
|
305
|
+
if not original_filename:
|
|
306
|
+
logger.error("[UPLOAD_MASK] No filename in original_ref")
|
|
307
|
+
return web.Response(status=400, text="No filename in original_ref")
|
|
308
|
+
|
|
309
|
+
# 构建完整的 OSS URL(类似 view_image 的逻辑)
|
|
310
|
+
http_prefix_options = ("http:", "https:")
|
|
311
|
+
|
|
312
|
+
if "http" in original_subfolder:
|
|
313
|
+
# subfolder 中包含 URL 基础路径
|
|
314
|
+
original_subfolder = original_subfolder[original_subfolder.find("http") :]
|
|
315
|
+
original_subfolder = unquote(original_subfolder)
|
|
316
|
+
if "https:/" in original_subfolder and not original_subfolder.startswith(
|
|
317
|
+
"https://"
|
|
318
|
+
):
|
|
319
|
+
original_subfolder = original_subfolder.replace(
|
|
320
|
+
"https:/", "https://", 1
|
|
321
|
+
)
|
|
322
|
+
if "http:/" in original_subfolder and not original_subfolder.startswith(
|
|
323
|
+
"http://"
|
|
324
|
+
):
|
|
325
|
+
original_subfolder = original_subfolder.replace("http:/", "http://", 1)
|
|
326
|
+
original_url = f"{original_subfolder}/{original_filename}"
|
|
327
|
+
elif original_filename.startswith(http_prefix_options):
|
|
328
|
+
# filename 本身就是完整 URL
|
|
329
|
+
original_url = original_filename
|
|
330
|
+
elif (
|
|
331
|
+
original_subfolder == "clipspace"
|
|
332
|
+
and original_filename in CLIPSPACE_TO_OSS_MAPPING
|
|
333
|
+
):
|
|
334
|
+
# 检查缓存:如果是 clipspace 文件且在缓存中,使用缓存的 OSS URL
|
|
335
|
+
original_url = CLIPSPACE_TO_OSS_MAPPING[original_filename]
|
|
336
|
+
else:
|
|
337
|
+
# 不是 OSS URL 格式且不在缓存中,直接上传 mask 图片
|
|
338
|
+
return await upload_to_oss(post, api_key)
|
|
339
|
+
|
|
340
|
+
async with aiohttp.ClientSession() as session:
|
|
341
|
+
async with session.get(original_url) as resp:
|
|
342
|
+
if resp.status != 200:
|
|
343
|
+
logger.error(
|
|
344
|
+
f"[UPLOAD_MASK] Failed to download original image: {resp.status}"
|
|
345
|
+
)
|
|
346
|
+
return web.Response(
|
|
347
|
+
status=502,
|
|
348
|
+
text=f"Failed to download original image: {resp.status}",
|
|
349
|
+
)
|
|
350
|
+
original_image_data = await resp.read()
|
|
351
|
+
|
|
352
|
+
# 处理图片:应用 alpha 通道
|
|
353
|
+
with Image.open(io.BytesIO(original_image_data)) as original_pil:
|
|
354
|
+
# 保存元数据
|
|
355
|
+
metadata = PngInfo()
|
|
356
|
+
if hasattr(original_pil, "text"):
|
|
357
|
+
for key in original_pil.text:
|
|
358
|
+
metadata.add_text(key, original_pil.text[key])
|
|
359
|
+
|
|
360
|
+
# 转换为 RGBA
|
|
361
|
+
original_pil = original_pil.convert("RGBA")
|
|
362
|
+
|
|
363
|
+
# 读取上传的 mask
|
|
364
|
+
mask_pil = Image.open(mask_image.file).convert("RGBA")
|
|
365
|
+
|
|
366
|
+
# alpha copy - 从 mask 提取 alpha 通道并应用到原图
|
|
367
|
+
new_alpha = mask_pil.getchannel("A")
|
|
368
|
+
original_pil.putalpha(new_alpha)
|
|
369
|
+
|
|
370
|
+
# 保存到临时文件
|
|
371
|
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
|
|
372
|
+
tmp_filepath = tmp_file.name
|
|
373
|
+
original_pil.save(tmp_filepath, compress_level=4, pnginfo=metadata)
|
|
374
|
+
|
|
375
|
+
# 准备上传到 OSS
|
|
376
|
+
filename = f"clipspace-mask-{uuid.uuid4().hex[:8]}.png"
|
|
377
|
+
# subfolder = post.get("subfolder", "clipspace")
|
|
378
|
+
image_upload_type = post.get("type", "input")
|
|
379
|
+
|
|
380
|
+
try:
|
|
381
|
+
# 获取上传 token
|
|
382
|
+
oss_token = await get_upload_token(filename, api_key)
|
|
383
|
+
|
|
384
|
+
# 读取临时文件并上传
|
|
385
|
+
with open(tmp_filepath, "rb") as f:
|
|
386
|
+
# 创建一个类似 FileField 的对象
|
|
387
|
+
class FileFieldLike:
|
|
388
|
+
def __init__(self, file_obj, filename, content_type):
|
|
389
|
+
self.file = file_obj
|
|
390
|
+
self.filename = filename
|
|
391
|
+
self.content_type = content_type
|
|
392
|
+
|
|
393
|
+
file_field = FileFieldLike(f, filename, "image/png")
|
|
394
|
+
result = await upload_filefield_to_oss(file_field, oss_token)
|
|
395
|
+
|
|
396
|
+
if result["status"] != 200:
|
|
397
|
+
logger.error(f"[UPLOAD_MASK] Upload failed: {result.get('reason', '')}")
|
|
398
|
+
return web.Response(
|
|
399
|
+
status=result["status"], text=result.get("reason", "")
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Commit file
|
|
403
|
+
object_key = oss_token["data"]["file"]["object_key"]
|
|
404
|
+
await commit_file(object_key, filename, api_key)
|
|
405
|
+
|
|
406
|
+
# 将 OSS URL 拆分成 filename 和 subfolder,以便前端正确构建 /api/view 请求
|
|
407
|
+
oss_url = result["url"]
|
|
408
|
+
oss_filename = oss_url.split("/")[-1]
|
|
409
|
+
oss_subfolder = "/".join(oss_url.split("/")[:-1])
|
|
410
|
+
|
|
411
|
+
if original_frontend_filename:
|
|
412
|
+
CLIPSPACE_TO_OSS_MAPPING[original_frontend_filename] = oss_url
|
|
413
|
+
|
|
414
|
+
response_data = {
|
|
415
|
+
"name": oss_filename,
|
|
416
|
+
"subfolder": oss_subfolder,
|
|
417
|
+
"type": image_upload_type,
|
|
418
|
+
}
|
|
419
|
+
return web.json_response(response_data)
|
|
420
|
+
|
|
421
|
+
finally:
|
|
422
|
+
# 清理临时文件
|
|
423
|
+
if os.path.exists(tmp_filepath):
|
|
424
|
+
os.remove(tmp_filepath)
|
|
425
|
+
|
|
426
|
+
except Exception as e:
|
|
427
|
+
logger.error(f"[UPLOAD_MASK] ERROR processing mask upload: {e}", exc_info=True)
|
|
428
|
+
return web.Response(status=500, text=f"Error processing mask: {str(e)}")
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _should_clean(name: str) -> bool:
|
|
432
|
+
"""True -> 包含非白名单字符;False -> 正常
|
|
433
|
+
|
|
434
|
+
使用白名单机制:只允许安全字符(中英文、数字、下划线、连字符、点、空格)
|
|
435
|
+
如果文件名包含白名单之外的字符,则需要清理
|
|
436
|
+
"""
|
|
437
|
+
if not name:
|
|
438
|
+
return False
|
|
439
|
+
|
|
440
|
+
# 分离文件名和扩展名
|
|
441
|
+
if "." not in name:
|
|
442
|
+
return False
|
|
443
|
+
|
|
444
|
+
# 白名单:允许中英文、数字、下划线、连字符、点、空格、圆括号
|
|
445
|
+
safe_pattern = r"^[\w\u4e00-\u9fa5\s\-().]+$"
|
|
446
|
+
|
|
447
|
+
return not bool(re.match(safe_pattern, name))
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def clean_filename(bad: str) -> (bool, str):
|
|
451
|
+
"""对乱码串提取最后扩展名;正常串直接返回原值"""
|
|
452
|
+
if not _should_clean(bad):
|
|
453
|
+
return False, bad
|
|
454
|
+
# 提取最后扩展名(含点)
|
|
455
|
+
ext = re.search(r"(\.[\w]+)$", bad)
|
|
456
|
+
return True, ext.group(1) if ext else bad
|