bizydraft 0.2.31__py3-none-any.whl → 0.2.78.dev20251117024007__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bizydraft/env.py CHANGED
@@ -9,3 +9,6 @@ logger.info(f"{BIZYDRAFT_DOMAIN=} {BIZYDRAFT_SERVER=}")
9
9
 
10
10
BIZYAIR_API_KEY = os.getenv("BIZYAIR_API_KEY")
# Log only whether a key is configured; writing the secret itself to the
# logs would expose it to anyone with access to them.
logger.info(f"BIZYAIR_API_KEY is {'set' if BIZYAIR_API_KEY else 'not set'}")

# Remote node config consumed by bizydraft.hijack_nodes (fetched there when
# it starts with "http"); empty string when the variable is unset.
COMFYAGENT_NODE_CONFIG = os.getenv("COMFYAGENT_NODE_CONFIG", "")
logger.info(f"{COMFYAGENT_NODE_CONFIG=}")
bizydraft/hijack_nodes.py CHANGED
@@ -1,7 +1,10 @@
1
1
  import re
2
+ from datetime import datetime
2
3
 
3
4
  from loguru import logger
4
5
 
6
+ from bizydraft.env import COMFYAGENT_NODE_CONFIG
7
+
5
8
  try:
6
9
  from comfy_extras.nodes_video import LoadVideo
7
10
  from nodes import NODE_CLASS_MAPPINGS, LoadImage
@@ -17,56 +20,88 @@ class BizyDraftLoadVideo(LoadVideo):
17
20
  super().__init__(*args, **kwargs)
18
21
 
19
22
  @classmethod
20
- def INPUT_TYPES(cls):
21
- return {
22
- "required": {"file": (["choose your file"], {"video_upload": True})},
23
- }
23
+ def INPUT_TYPES(cls, **kwargs):
24
+ # 调用父类方法,保持兼容性
25
+ return super().INPUT_TYPES(**kwargs)
24
26
 
25
27
  @classmethod
26
28
  def VALIDATE_INPUTS(s, *args, **kwargs):
27
29
  return True
28
30
 
31
+ @classmethod
32
+ def validate_inputs(s, *args, **kwargs):
33
+ # V3 API 使用小写的 validate_inputs
34
+ return True
35
+
29
36
 
30
37
class BizyDraftLoadImage(LoadImage):
    """LoadImage subclass whose input validation always succeeds."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def INPUT_TYPES(cls, **kwargs):
        # Delegate to LoadImage (forwarding any keyword arguments) to stay
        # compatible with the parent's input specification.
        parent_spec = super().INPUT_TYPES(**kwargs)
        return parent_spec

    @classmethod
    def VALIDATE_INPUTS(cls, *args, **kwargs):
        """Accept every input combination (upper-case validation hook)."""
        return True

    @classmethod
    def validate_inputs(cls, *args, **kwargs):
        """Accept every input combination (V3 API lower-case spelling)."""
        return True
54
+
44
55
 
45
56
# Node name -> replacement class. The LoadImage -> BizyDraftLoadImage and
# LoadVideo -> BizyDraftLoadVideo patches are currently disabled, so the
# mapping is intentionally empty.
CLASS_PATCHES = dict()
49
60
 
61
+
62
def get_data_load_classes_from_url(config_url, timeout=10):
    """Fetch extra data-load node class names from a remote JSON config.

    The config is expected to be a JSON object. The keys of its optional
    ``weight_load_nodes`` and ``media_load_nodes`` objects are returned, in
    that order; sections that are missing or not objects are skipped.

    Args:
        config_url: URL of the JSON config.
        timeout: Seconds to wait for the HTTP request. This function is
            called at module import time, so an unbounded wait would block
            startup indefinitely.

    Returns:
        A list of node class names, or ``[]`` if the request or parsing
        fails (the error is logged).
    """
    import requests

    # Minute-granularity timestamp appended to the URL, presumably to bust
    # intermediate caches.
    current_time = datetime.now().strftime("%Y%m%d%H%M")
    # Respect an existing query string instead of producing "...?a=1?t=...".
    separator = "&" if "?" in config_url else "?"
    request_url = f"{config_url}{separator}t={current_time}"

    try:
        response = requests.get(request_url, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        keys_list = []
        for section in ("weight_load_nodes", "media_load_nodes"):
            nodes = data.get(section)
            if isinstance(nodes, dict):
                keys_list.extend(nodes)
        return keys_list
    except Exception as e:
        # Best effort: a broken remote config must not prevent startup.
        logger.error(
            f"Failed to fetch or parse comfyagent node config from {request_url}: {e}"
        )
        return []
85
+
86
+
50
87
# Node class names considered data-load nodes; optionally extended below
# from the remote ComfyAgent config.
DATA_LOAD_CLASSES = [
    "LoadImage",
    "LoadVideo",
    "LoadImageMask",
    "LoadAudio",
    "Load3D",
    "VHS_LoadAudioUpload",
    "VHS_LoadVideo",
]

# NOTE: this performs an HTTP request at import time when configured.
if COMFYAGENT_NODE_CONFIG.startswith("http"):
    fetched_classes = get_data_load_classes_from_url(COMFYAGENT_NODE_CONFIG)
    if not fetched_classes:
        logger.warning("No additional data load classes fetched from the URL.")
    else:
        DATA_LOAD_CLASSES.extend(fetched_classes)
        logger.info(f"Fetched additional data load classes: {fetched_classes}")
104
+
70
105
 
71
106
  def hijack_nodes():
72
107
  def _hijack_node(node_name, new_class):
@@ -82,6 +117,7 @@ def hijack_nodes():
82
117
 
83
118
  # 通用情况,正则匹配后,打通用patch、替换
84
119
  for node_name, base_class in NODE_CLASS_MAPPINGS.items():
120
+
85
121
  regex = r"^(?!BizyAir_)\w+.*Loader.*"
86
122
  match = re.match(regex, node_name, re.IGNORECASE)
87
123
  if (match and (node_name not in CLASS_PATCHES)) or (
@@ -94,7 +130,10 @@ def hijack_nodes():
94
130
 
95
131
  def create_patched_class(base_class, validate_inputs_func=None):
96
132
  class PatchedClass(base_class):
97
- pass
133
+ @classmethod
134
+ def validate_inputs(cls, *args, **kwargs):
135
+ # V3 API
136
+ return True
98
137
 
99
138
  if validate_inputs_func:
100
139
  PatchedClass.VALIDATE_INPUTS = classmethod(validate_inputs_func)
@@ -1,8 +1,8 @@
1
1
  from aiohttp import web
2
2
  from loguru import logger
3
3
 
4
- from bizydraft.oss_utils import upload_image
5
- from bizydraft.patch_handlers import post_prompt, view_image
4
+ from bizydraft.oss_utils import upload_image, upload_mask
5
+ from bizydraft.patch_handlers import post_prompt, view_image, view_video
6
6
 
7
7
  try:
8
8
  from server import PromptServer
@@ -23,10 +23,14 @@ def hijack_routes_pre_add_routes():
23
23
  ("/view", "GET"): view_image,
24
24
  ("/prompt", "POST"): post_prompt,
25
25
  ("/upload/image", "POST"): upload_image,
26
+ ("/upload/mask", "POST"): upload_mask,
26
27
  # /api alias
27
28
  ("/api/view", "GET"): view_image,
28
29
  ("/api/prompt", "POST"): post_prompt,
29
30
  ("/api/upload/image", "POST"): upload_image,
31
+ ("/api/upload/mask", "POST"): upload_mask,
32
+ # VHS plugin support
33
+ ("/api/vhs/viewvideo", "GET"): view_video,
30
34
  }
31
35
 
32
36
  async def middleware_handler(request):
@@ -52,6 +56,8 @@ def hijack_routes_pre_add_routes():
52
56
  "/prompt",
53
57
  "/view",
54
58
  "/upload/image",
59
+ "/upload/mask",
60
+ "/vhs/viewvideo",
55
61
  "/",
56
62
  "/ws",
57
63
  "/extensions",
bizydraft/oss_utils.py CHANGED
@@ -1,7 +1,10 @@
1
1
  import base64
2
2
  import json
3
3
  import os
4
+ import re
5
+ import uuid
4
6
  from http.cookies import SimpleCookie
7
+ from pathlib import Path
5
8
  from time import time
6
9
  from typing import Any, Dict
7
10
 
@@ -16,6 +19,8 @@ from werkzeug.utils import secure_filename
16
19
 
17
20
  from bizydraft.env import BIZYAIR_API_KEY, BIZYDRAFT_SERVER
18
21
 
22
# Frontend upload filename -> OSS URL of the uploaded file. Written by
# upload_to_oss / upload_mask and read by upload_mask to resolve
# "clipspace" references. Process-local; no eviction is visible here.
CLIPSPACE_TO_OSS_MAPPING = dict()
23
+
19
24
  private_key_pem = os.getenv(
20
25
  "RSA_PRIVATE_KEY",
21
26
  """-----BEGIN RSA PRIVATE KEY-----
@@ -59,6 +64,9 @@ def decrypt(encrypted_message):
59
64
  if not encrypted_message or not isinstance(encrypted_message, str):
60
65
  raise ValueError("无效的加密消息")
61
66
 
67
+ if "v4.public" in encrypted_message:
68
+ return encrypted_message
69
+
62
70
  private_key = serialization.load_pem_private_key(
63
71
  private_key_pem.encode(), password=None, backend=default_backend()
64
72
  )
@@ -90,7 +98,13 @@ async def get_upload_token(
90
98
  filename: str,
91
99
  api_key: str,
92
100
  ) -> Dict[str, Any]:
93
- url = f"{BIZYDRAFT_SERVER}/upload/token?file_name={filename}&file_type=inputs"
101
+ from urllib.parse import quote
102
+
103
+ # 对文件名进行URL编码,避免特殊字符导致问题
104
+ encoded_filename = quote(filename, safe="")
105
+ url = (
106
+ f"{BIZYDRAFT_SERVER}/upload/token?file_name={encoded_filename}&file_type=inputs"
107
+ )
94
108
 
95
109
  headers = {
96
110
  "Content-Type": "application/json",
@@ -173,10 +187,15 @@ async def upload_to_oss(post, api_key: str):
173
187
  if not (image and image.file):
174
188
  return web.Response(status=400)
175
189
 
190
+ original_frontend_filename = image.filename # 保存前端发送的原始文件名
176
191
  filename = image.filename
177
192
  if not filename:
178
193
  return web.Response(status=400)
179
194
 
195
+ should_clean, filename = clean_filename(filename)
196
+ if should_clean:
197
+ filename = f"{uuid.uuid4()}.{filename}"
198
+
180
199
  oss_token = await get_upload_token(filename, api_key)
181
200
  result = await upload_filefield_to_oss(image, oss_token)
182
201
  if result["status"] != 200:
@@ -189,8 +208,23 @@ async def upload_to_oss(post, api_key: str):
189
208
  except Exception as e:
190
209
  logger.error(f"Commit file failed: {e}")
191
210
  return web.Response(status=500, text=str(e))
211
+
212
+ # 将 OSS URL 拆分成 filename 和 subfolder,以便前端正确构建 /api/view 请求
213
+ # 例如: https://bizyair-prod.oss-cn-shanghai.aliyuncs.com/inputs/20250930/file.png
214
+ oss_url = result["url"]
215
+ oss_filename = oss_url.split("/")[-1] # 获取最后一部分作为文件名
216
+ oss_subfolder = "/".join(
217
+ oss_url.split("/")[:-1]
218
+ ) # 获取除文件名外的部分作为 subfolder
219
+
220
+ if original_frontend_filename:
221
+ CLIPSPACE_TO_OSS_MAPPING[original_frontend_filename] = oss_url
222
+ logger.info(
223
+ f"[OSS_MAPPING] Cached mapping: {original_frontend_filename} -> {oss_url}"
224
+ )
225
+
192
226
  return web.json_response(
193
- {"name": result["url"], "subfolder": subfolder, "type": image_upload_type}
227
+ {"name": oss_filename, "subfolder": oss_subfolder, "type": image_upload_type}
194
228
  )
195
229
 
196
230
 
@@ -225,3 +259,190 @@ async def upload_image(request):
225
259
  return web.Response(status=403, text="No validated key found")
226
260
  post = await request.post()
227
261
  return await upload_to_oss(post, api_key)
262
+
263
+
264
def _resolve_mask_source_url(filename, subfolder):
    """Return the URL of the image a mask was drawn on, or None if unknown.

    Resolution order: an http(s) URL embedded in ``subfolder``; ``filename``
    itself being an http(s) URL; a cached OSS URL for a "clipspace" upload.
    """
    from urllib.parse import unquote

    decoded = unquote(subfolder)
    url_start = re.search(r"https?:", decoded)
    if url_start:
        # The subfolder carries the URL base path, possibly with its "//"
        # collapsed to "/" on the way in. Normalise any number of slashes
        # after the scheme to exactly two; a fixed replace("https:/",
        # "https://") would turn an intact "https://" into "https:///".
        base = re.sub(r"^(https?):/*", r"\1://", decoded[url_start.start() :])
        return f"{base}/{filename}"
    if filename.startswith(("http:", "https:")):
        return filename
    if subfolder == "clipspace":
        return CLIPSPACE_TO_OSS_MAPPING.get(filename)
    return None


def _composite_mask_png(original_bytes, mask_file, dest_path):
    """Save the original image as an RGBA PNG whose alpha comes from the mask.

    Any text metadata of the original image is copied into the PNG. PIL
    raises if the mask and the original have different sizes.
    """
    import io

    from PIL import Image
    from PIL.PngImagePlugin import PngInfo

    metadata = PngInfo()
    with Image.open(io.BytesIO(original_bytes)) as original_pil:
        for key, value in getattr(original_pil, "text", {}).items():
            metadata.add_text(key, value)
        composited = original_pil.convert("RGBA")
    with Image.open(mask_file) as mask_pil:
        new_alpha = mask_pil.convert("RGBA").getchannel("A")
    composited.putalpha(new_alpha)
    composited.save(dest_path, compress_level=4, pnginfo=metadata)


async def upload_mask(request):
    """Handle a mask-editor upload and store the composited image on OSS.

    Form fields: ``image`` is the edited mask (only its alpha channel is
    used); optional ``original_ref`` is a JSON object with ``filename`` and
    ``subfolder`` identifying the source image; optional ``type`` is echoed
    back (default "input").

    If the source image resolves to a URL, it is downloaded, the mask's
    alpha is applied to it, and the resulting PNG is uploaded and committed.
    Otherwise the request falls back to ``upload_to_oss`` (mask uploaded
    as-is).

    Returns JSON ``{"name", "subfolder", "type"}`` where
    ``subfolder + "/" + name`` is the OSS URL, or an error response:
    403 no API key, 400 missing image / invalid ``original_ref``,
    502 original image download failed, 500 any other processing error.
    """
    import tempfile
    from types import SimpleNamespace

    import aiohttp

    api_key = get_api_key(request)
    if not api_key:
        logger.error("[UPLOAD_MASK] No API key found")
        return web.Response(status=403, text="No validated key found")

    post = await request.post()

    mask_image = post.get("image")
    if not (mask_image and mask_image.file):
        logger.error("[UPLOAD_MASK] No image provided in request")
        return web.Response(status=400, text="No image provided")

    # Remember the frontend's name so later references to it can be mapped
    # to the uploaded OSS URL.
    original_frontend_filename = mask_image.filename

    original_ref_str = post.get("original_ref")
    if not original_ref_str:
        # Nothing to composite against: upload the mask as-is.
        return await upload_to_oss(post, api_key)

    tmp_filepath = None
    try:
        try:
            original_ref = json.loads(original_ref_str)
        except (TypeError, ValueError):
            original_ref = None
        if not isinstance(original_ref, dict):
            logger.error("[UPLOAD_MASK] original_ref is not a JSON object")
            return web.Response(status=400, text="Invalid original_ref")

        original_filename = original_ref.get("filename")
        original_subfolder = original_ref.get("subfolder") or ""
        if not original_filename:
            logger.error("[UPLOAD_MASK] No filename in original_ref")
            return web.Response(status=400, text="No filename in original_ref")

        original_url = _resolve_mask_source_url(original_filename, original_subfolder)
        if original_url is None:
            # Not an OSS URL and not cached: upload the mask image directly.
            return await upload_to_oss(post, api_key)

        # NOTE(review): original_url is derived from client-supplied data, so
        # this is a server-side fetch of an arbitrary URL (SSRF risk);
        # consider restricting it to the expected OSS host(s).
        timeout = aiohttp.ClientTimeout(total=60)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(original_url) as resp:
                if resp.status != 200:
                    logger.error(
                        f"[UPLOAD_MASK] Failed to download original image: {resp.status}"
                    )
                    return web.Response(
                        status=502,
                        text=f"Failed to download original image: {resp.status}",
                    )
                original_image_data = await resp.read()

        # Reserve a temp path and close the handle before PIL writes to it:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            tmp_filepath = tmp_file.name
        _composite_mask_png(original_image_data, mask_image.file, tmp_filepath)

        filename = f"clipspace-mask-{uuid.uuid4().hex[:8]}.png"
        image_upload_type = post.get("type", "input")
        oss_token = await get_upload_token(filename, api_key)

        with open(tmp_filepath, "rb") as f:
            # Minimal FileField stand-in exposing file/filename/content_type.
            file_field = SimpleNamespace(
                file=f, filename=filename, content_type="image/png"
            )
            result = await upload_filefield_to_oss(file_field, oss_token)

        if result["status"] != 200:
            reason = result.get("reason", "")
            logger.error(f"[UPLOAD_MASK] Upload failed: {reason}")
            return web.Response(status=result["status"], text=reason)

        object_key = oss_token["data"]["file"]["object_key"]
        await commit_file(object_key, filename, api_key)

        # Split the OSS URL so the frontend can rebuild it for /api/view as
        # subfolder + "/" + name.
        oss_url = result["url"]
        oss_subfolder, _, oss_filename = oss_url.rpartition("/")

        if original_frontend_filename:
            CLIPSPACE_TO_OSS_MAPPING[original_frontend_filename] = oss_url

        return web.json_response(
            {
                "name": oss_filename,
                "subfolder": oss_subfolder,
                "type": image_upload_type,
            }
        )
    except Exception as e:
        # loguru has no stdlib-style ``exc_info`` kwarg (extra kwargs are used
        # to str.format the message); logger.exception records the traceback.
        logger.exception(f"[UPLOAD_MASK] ERROR processing mask upload: {e}")
        return web.Response(status=500, text=f"Error processing mask: {str(e)}")
    finally:
        # Always remove the temp file, including when compositing fails.
        if tmp_filepath and os.path.exists(tmp_filepath):
            os.remove(tmp_filepath)
421
+
422
+
423
# Whitelisted filename characters: Unicode word characters (letters incl.
# Chinese, digits, underscore), plain space, hyphen, parentheses and dot.
_SAFE_FILENAME_RE = re.compile(r"[\w\u4e00-\u9fa5 \-().]+")


def _should_clean(name: str) -> bool:
    """Return True if ``name`` contains characters outside the whitelist.

    Whitelist: word characters (letters incl. Chinese, digits, underscore),
    space, hyphen, parentheses and dot. Empty names and names without a "."
    are never flagged.
    """
    if not name or "." not in name:
        return False
    # fullmatch with a literal space: "\s" plus "$" would also accept tabs,
    # newlines and a trailing "\n", contradicting the space-only whitelist.
    return _SAFE_FILENAME_RE.fullmatch(name) is None


def clean_filename(bad: str) -> "tuple[bool, str]":
    """Decide whether a filename must be replaced and return what to keep.

    Returns ``(False, bad)`` when ``_should_clean`` accepts the name.
    Otherwise returns ``(True, ext)``, where ``ext`` is the last extension
    including its dot (e.g. ".png"), or ``(True, bad)`` unchanged when no
    ``.<word chars>`` suffix exists.

    NOTE(review): upload_to_oss builds f"{uuid}.{filename}" from this value,
    which yields a double dot ("<uuid>..png") because ``ext`` already starts
    with "."; consider dropping the extra dot at that call site.
    """
    if not _should_clean(bad):
        return False, bad
    ext = re.search(r"\.\w+$", bad)
    return True, ext.group(0) if ext else bad