bizydraft 0.1.29__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bizydraft might be problematic. Click here for more details.

@@ -0,0 +1,35 @@
1
+ import os
2
+
3
+ from loguru import logger
4
+
5
BIZYDRAFT_BLACKLIST_NODES = os.getenv(
    "BIZYDRAFT_BLACKLIST_NODES", "blacklist_nodes.json"
)
logger.info(f"Using blacklist nodes file: {BIZYDRAFT_BLACKLIST_NODES=}")


def _load_blacklist(path):
    """Return the blacklisted node names stored as JSON at *path*.

    The file is expected to hold a JSON list of node class names (a JSON
    object is also accepted; iterating it yields its keys). A missing,
    unreadable, malformed, or otherwise-shaped file is logged and yields an
    empty list. A bad file therefore disables the blacklist instead of
    aborting the import of this module.
    """
    import json

    if not os.path.exists(path):
        logger.error(f"Blacklisted nodes file {path} does not exist.")
        return []

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        logger.error(f"Failed to read blacklisted nodes file {path}: {e}")
        return []

    if isinstance(data, (list, dict)):
        return data

    # A bare string or number would be iterated char-by-char or crash later.
    logger.error(
        f"Blacklisted nodes file {path} must contain a JSON list, "
        f"got {type(data).__name__}."
    )
    return []


BLACKLIST_NODE_CLASS = _load_blacklist(BIZYDRAFT_BLACKLIST_NODES)
18
+
19
+
20
def remove_blacklisted_nodes():
    """Unregister every node named in ``BLACKLIST_NODE_CLASS`` from ComfyUI.

    ComfyUI's ``nodes`` module is imported lazily. If that import fails, an
    error with a PYTHONPATH hint is logged and nothing is removed. Names that
    are not currently registered are ignored.
    """
    try:
        import nodes
    except ImportError:
        logger.error(
            "Failed to import NODE_CLASS_MAPPINGS, ensure PYTHONPATH is set correctly. (export PYTHONPATH=$PYTHONPATH:/path/to/ComfyUI)"
        )
        return

    registry = nodes.NODE_CLASS_MAPPINGS
    for blocked in BLACKLIST_NODE_CLASS:
        if blocked not in registry:
            continue
        del registry[blocked]
        logger.info(f"Removed blacklisted node: {blocked}")
bizydraft/env.py ADDED
@@ -0,0 +1,11 @@
1
+ import os
2
+
3
+ from loguru import logger
4
+
5
# Trailing slashes are stripped so BIZYDRAFT_SERVER never becomes "...//x/v1".
BIZYDRAFT_DOMAIN = os.getenv("BIZYDRAFT_DOMAIN", "https://api.bizyair.cn").rstrip("/")
BIZYDRAFT_SERVER = f"{BIZYDRAFT_DOMAIN}/x/v1"

logger.info(f"{BIZYDRAFT_DOMAIN=} {BIZYDRAFT_SERVER=}")

BIZYAIR_API_KEY = os.getenv("BIZYAIR_API_KEY")
# Never write the secret itself to the logs; report only whether it is set.
logger.info(f"BIZYAIR_API_KEY is {'set' if BIZYAIR_API_KEY else 'not set'}")
bizydraft/hijack_nodes.py CHANGED
@@ -1,6 +1,9 @@
1
+ import re
2
+
1
3
  from loguru import logger
2
4
 
3
5
  try:
6
+ from comfy_extras.nodes_video import LoadVideo
4
7
  from nodes import NODE_CLASS_MAPPINGS, LoadImage
5
8
  except ImportError:
6
9
  logger.error(
@@ -9,6 +12,21 @@ except ImportError:
9
12
  exit(1)
10
13
 
11
14
 
15
class BizyDraftLoadVideo(LoadVideo):
    """``LoadVideo`` variant whose ``file`` input offers one placeholder choice.

    ``VALIDATE_INPUTS`` always returns ``True``, so values outside that choice
    list (presumably remote URLs from the upload handler; TODO confirm) are
    accepted.
    """

    def __init__(self, *args, **kwargs):
        super(BizyDraftLoadVideo, self).__init__(*args, **kwargs)

    @classmethod
    def INPUT_TYPES(cls):
        file_spec = (["choose your file"], {"video_upload": True})
        return {"required": {"file": file_spec}}

    @classmethod
    def VALIDATE_INPUTS(cls, *args, **kwargs):
        # Unconditionally valid.
        return True
28
+
29
+
12
30
  class BizyDraftLoadImage(LoadImage):
13
31
  def __init__(self, *args, **kwargs):
14
32
  super().__init__(*args, **kwargs)
@@ -16,17 +34,59 @@ class BizyDraftLoadImage(LoadImage):
16
34
  @classmethod
17
35
  def INPUT_TYPES(s):
18
36
  return {
19
- "required": {"image": ([], {"image_upload": True})},
37
+ "required": {"image": (["choose your file"], {"image_upload": True})},
20
38
  }
21
39
 
22
40
  @classmethod
23
- def VALIDATE_INPUTS(s, image, *args, **kwargs):
41
+ def VALIDATE_INPUTS(s, *args, **kwargs):
24
42
  return True
25
43
 
26
44
 
45
# Hand-written replacement classes keyed by node name. hijack_nodes() installs
# these before running the generic regex-based patch. Both known entries are
# currently disabled:
#   "LoadImage": BizyDraftLoadImage,
#   "LoadVideo": BizyDraftLoadVideo,
CLASS_PATCHES = dict()

# Core data-loading nodes that hijack_nodes() patches generically even when
# their names do not match its "Loader" regex.
DATA_LOAD_CLASSES = list(
    ("LoadImage", "LoadVideo", "LoadImageMask", "LoadAudio", "Load3D")
)
57
+
58
+
def hijack_nodes():
    """Replace ComfyUI loader nodes in ``NODE_CLASS_MAPPINGS`` with patched classes.

    This makes two passes:

    1. Every entry of ``CLASS_PATCHES`` is installed as-is (hand-written
       replacement classes).
    2. Every other node whose name matches the loader regex, or is listed in
       ``DATA_LOAD_CLASSES``, is replaced by ``create_patched_class(base)``.

    Nodes installed in pass 1 are skipped in pass 2, so the generic patch can
    never overwrite a hand-written one.
    """

    def _hijack_node(node_name, new_class):
        # Install new_class under node_name, warning when it replaces a node.
        if node_name in NODE_CLASS_MAPPINGS:
            logger.warning(
                f"Node {node_name} already exists, replacing with {new_class.__name__}"
            )
        NODE_CLASS_MAPPINGS[node_name] = new_class

    # Special cases: replace with the hand-written class.
    for node_name, new_class in CLASS_PATCHES.items():
        _hijack_node(node_name, new_class)

    # General case: names containing "Loader" (case-insensitive, with at least
    # one word character before it), excluding BizyAir's own "BizyAir_" nodes.
    # Compiled once instead of on every loop iteration.
    loader_name_re = re.compile(r"^(?!BizyAir_)\w+.*Loader.*", re.IGNORECASE)

    # Iterate over a snapshot because the loop body writes back into the mapping.
    for node_name, base_class in list(NODE_CLASS_MAPPINGS.items()):
        if node_name in CLASS_PATCHES:
            # Already replaced by a hand-written class above.
            continue
        if loader_name_re.match(node_name) or node_name in DATA_LOAD_CLASSES:
            logger.debug(f"Creating patched class for {node_name}")
            NODE_CLASS_MAPPINGS[node_name] = create_patched_class(base_class)
81
+
82
+
83
def create_patched_class(base_class, validate_inputs_func=None):
    """Return a subclass of *base_class* with ``VALIDATE_INPUTS`` overridden.

    Args:
        base_class: The node class to wrap. It is not modified.
        validate_inputs_func: Optional plain function ``(cls, *args, **kwargs)``
            that is installed as a classmethod. When ``None``, the patched
            ``VALIDATE_INPUTS`` accepts every input and returns ``True``.

    Returns:
        The new subclass, named ``"Patched" + base_class.__name__`` so logs
        and tracebacks show which node it wraps.
    """

    class PatchedClass(base_class):
        pass

    # Without this, every patched node would be reported as "PatchedClass".
    patched_name = f"Patched{base_class.__name__}"
    PatchedClass.__name__ = patched_name
    PatchedClass.__qualname__ = patched_name
    PatchedClass.__doc__ = base_class.__doc__

    if validate_inputs_func is not None:
        PatchedClass.VALIDATE_INPUTS = classmethod(validate_inputs_func)
    else:
        # Accept-everything validator; **k lets it take any input name.
        PatchedClass.VALIDATE_INPUTS = classmethod(lambda cls, *a, **k: True)

    return PatchedClass
@@ -1,15 +1,11 @@
1
- import math
2
- import os
3
- import asyncio
4
- import mimetypes
5
- import uuid
6
-
7
- from aiohttp import web, ClientSession, ClientTimeout
1
+ from aiohttp import web
8
2
  from loguru import logger
9
3
 
4
+ from bizydraft.oss_utils import upload_image
5
+ from bizydraft.patch_handlers import post_prompt, view_image
6
+
10
7
  try:
11
8
  from server import PromptServer
12
- import execution
13
9
 
14
10
  comfy_server = PromptServer.instance
15
11
  except ImportError:
@@ -19,181 +15,88 @@ except ImportError:
19
15
  exit(1)
20
16
 
21
17
 
22
- BIZYDRAFT_MAX_FILE_SIZE = int(
23
- os.getenv("BIZYDRAFT_MAX_FILE_SIZE", 100 * 1024 * 1024)
24
- ) # 100MB
25
- BIZYDRAFT_REQUEST_TIMEOUT = int(
26
- os.getenv("BIZYDRAFT_REQUEST_TIMEOUT", 20 * 60)
27
- ) # 20分钟
28
- BIZYDRAFT_CHUNK_SIZE = int(os.getenv("BIZYDRAFT_CHUNK_SIZE", 1024 * 16)) # 16KB
29
-
30
-
31
- async def view_image(request, old_handler):
32
- logger.debug(f"Received request for /view with query: {request.rel_url.query}")
33
- if "filename" not in request.rel_url.query:
34
- logger.warning("'filename' not provided in query string, returning 404")
35
- return web.Response(status=404, text="'filename' not provided in query string")
36
-
37
- filename = request.rel_url.query["filename"]
38
- subfolder = request.rel_url.query.get("subfolder", "")
39
-
40
- if not filename.startswith(("http://", "https://")) and not subfolder.startswith(
41
- ("http://", "https://")
42
- ):
43
- logger.warning(f"Invalid filename format: {filename}, only URLs are supported")
44
- return web.Response(
45
- status=400, text="Invalid filename format(only url supported)"
46
- )
47
-
48
- try:
49
- filename = (
50
- f"{subfolder}/{filename}"
51
- if not filename.startswith(("http://", "https://"))
52
- else filename
53
- ) # preview 3d request: https://host:port/api/view?filename=filename.glb&type=output&subfolder=https://bizyair-dev.oss-cn-shanghai.aliyuncs.com/outputs&rand=0.5763957215362988
54
-
55
- content_type, _ = mimetypes.guess_type(filename)
56
- if content_type and any(x in content_type for x in ("image", "video")):
57
- return web.HTTPFound(filename)
58
-
59
- timeout = ClientTimeout(total=BIZYDRAFT_REQUEST_TIMEOUT)
60
- async with ClientSession(timeout=timeout) as session:
61
- async with session.get(filename) as resp:
62
- resp.raise_for_status()
63
- content_length = int(resp.headers.get("Content-Length", 0))
64
- if content_length > BIZYDRAFT_MAX_FILE_SIZE:
65
- logger.warning(
66
- f"File size {human_readable_size(content_length)} exceeds limit {human_readable_size(BIZYDRAFT_MAX_FILE_SIZE)}"
67
- )
68
- return web.Response(
69
- status=413,
70
- text=f"File size exceeds limit ({human_readable_size(BIZYDRAFT_MAX_FILE_SIZE)})",
71
- )
72
-
73
- headers = {
74
- "Content-Disposition": f'attachment; filename="{uuid.uuid4()}"',
75
- "Content-Type": "application/octet-stream",
76
- }
77
-
78
- proxy_response = web.StreamResponse(headers=headers)
79
- await proxy_response.prepare(request)
80
-
81
- total_bytes = 0
82
- async for chunk in resp.content.iter_chunked(BIZYDRAFT_CHUNK_SIZE):
83
- total_bytes += len(chunk)
84
- if total_bytes > BIZYDRAFT_MAX_FILE_SIZE:
85
- await proxy_response.write(b"")
86
- return web.Response(
87
- status=413,
88
- text=f"File size exceeds limit during streaming ({human_readable_size(BIZYDRAFT_MAX_FILE_SIZE)})",
89
- )
90
- await proxy_response.write(chunk)
91
-
92
- return proxy_response
93
-
94
- except asyncio.TimeoutError:
95
- return web.Response(
96
- status=504,
97
- text=f"Request timed out (max {BIZYDRAFT_REQUEST_TIMEOUT//60} minutes)",
98
- )
99
- except Exception as e:
100
- return web.Response(
101
- status=502, text=f"Failed to fetch remote resource: {str(e)}"
102
- )
103
-
104
-
105
- def human_readable_size(size_bytes):
106
- if size_bytes == 0:
107
- return "0B"
108
- size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
109
- i = int(math.floor(math.log(size_bytes, 1024)))
110
- p = math.pow(1024, i)
111
- s = round(size_bytes / p, 2)
112
- return f"{s} {size_name[i]}"
113
-
114
-
115
- async def post_prompt(request):
116
- json_data = await request.json()
117
- logger.debug(f"Received POST request to /prompt with data")
118
- json_data = comfy_server.trigger_on_prompt(json_data)
119
-
120
- if "prompt" in json_data:
121
- prompt = json_data["prompt"]
122
- valid = execution.validate_prompt(prompt)
123
- if valid[0]:
124
- response = {
125
- "prompt_id": None,
126
- "number": None,
127
- "node_errors": valid[3],
128
- }
129
- return web.json_response(response)
130
- else:
131
- return web.json_response(
132
- {"error": valid[1], "node_errors": valid[3]}, status=400
133
- )
134
- else:
135
- error = {
136
- "type": "no_prompt",
137
- "message": "No prompt provided",
138
- "details": "No prompt provided",
139
- "extra_info": {},
18
def hijack_routes_pre_add_routes():
    """Install BizyDraft's routing middleware on ComfyUI's aiohttp app.

    Requests for ``GET /view``, ``POST /prompt`` and ``POST /upload/image``,
    plus their ``/api``-prefixed aliases, are answered by ``view_image``,
    ``post_prompt`` and ``upload_image``. All other requests go to ComfyUI's
    own handler unchanged.

    Presumably this must run before the server starts, since aiohttp freezes
    ``app.middlewares`` at startup (TODO confirm the call site honours this).
    """
    app = comfy_server.app

    # (path, method) -> replacement handler; built once, not per request.
    routes_patch = {
        ("/view", "GET"): view_image,
        ("/prompt", "POST"): post_prompt,
        ("/upload/image", "POST"): upload_image,
        # /api alias
        ("/api/view", "GET"): view_image,
        ("/api/prompt", "POST"): post_prompt,
        ("/api/upload/image", "POST"): upload_image,
    }

    # New-style middleware: aiohttp deprecates the old ``(app, handler)``
    # factory form.
    @web.middleware
    async def custom_business_middleware(request, handler):
        # Try the exact key first, then the "/api"-prefixed one. The handler
        # always comes from whichever key actually matched.
        new_handler = routes_patch.get(
            (request.path, request.method)
        ) or routes_patch.get(("/api" + request.path, request.method))
        if new_handler is None:
            return await handler(request)

        logger.debug(
            f"Custom handler for {request.path} with method {request.method}"
        )
        return await new_handler(request)

    base_white_list = [
        "/prompt",
        "/view",
        "/upload/image",
        "/",
        "/ws",
        "/extensions",
        "/object_info",
        "/object_info/{node_class}",
        "/assets",
        "/users",
        "/settings",
        "/i18n",
        "/userdata",
    ]

    white_list = [
        *base_white_list,
        *(f"/api{path}" for path in base_white_list if path not in ("/", "/ws")),
    ]

    # NOTE(review): this middleware is defined but never installed (only
    # custom_business_middleware is added below), so no access control is
    # enforced. Before enabling it, note its gaps: the prefix match also
    # admits e.g. "/wsX" or "/promptfoo", and any path containing ".css" or
    # "/bizyair" anywhere is allowed.
    @web.middleware
    async def access_control_middleware(request, handler):
        is_allowed = any(
            request.path == path
            or (
                request.path.startswith(path.replace("{node_class}", ""))
                and path != "/"
                and path != "/api/"
            )
            or ".css" in request.path
            for path in white_list
        )
        logger.debug(
            f"Access control check for {request.path}: {'allowed' if is_allowed else 'blocked'}"
        )
        if not is_allowed and "/bizyair" not in request.path:
            logger.info(f"Blocked access to: {request.path}")
            return web.Response(status=403, text="Access Forbidden")

        return await handler(request)

    app.middlewares.extend([custom_business_middleware])

    logger.info("Optimized middleware setup complete.")
98
+
99
+
100
def hijack_routes_post_add_routes():
    """Hook for work that must happen after every route is registered.

    It currently only logs that this stage has completed.
    """
    message = "Post-add routes hijack complete, all routes are set up."
    logger.info(message)
bizydraft/oss_utils.py ADDED
@@ -0,0 +1,227 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ from http.cookies import SimpleCookie
5
+ from time import time
6
+ from typing import Any, Dict
7
+
8
+ import aiohttp
9
+ import oss2
10
+ from aiohttp import web
11
+ from cryptography.hazmat.backends import default_backend
12
+ from cryptography.hazmat.primitives import serialization
13
+ from cryptography.hazmat.primitives.asymmetric import padding
14
+ from loguru import logger
15
+ from werkzeug.utils import secure_filename
16
+
17
+ from bizydraft.env import BIZYAIR_API_KEY, BIZYDRAFT_SERVER
18
+
19
# SECURITY NOTE(review): a default RSA private key is embedded below and ships
# with every public copy of this package. Anyone can therefore decrypt tokens
# produced for it (``decrypt`` uses this key for the ``bizy_token`` cookie).
# Set RSA_PRIVATE_KEY in the environment and rotate the key pair. The embedded
# default is kept only so existing deployments keep working.
_EMBEDDED_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAuROqSPqhJlpv5R1wDl2sGuyA59Hf1y+VLR0w3cCyM6/WEQ4b
+TBFfM5HeCLc2YVDybc0ZJxsEqCXKpTweMlQg063ECK4961icF3xL8DRfXkwpUFJ
CfG24tLdXwWK3CJDb4RqGSyZm2F0mE/kqMpidsoJrXy24B4iSJrk5DGRSL1dChiL
vuvNNWPtdDHylormBxz2f8ePvvO8v/qsN+Xpxt7YirqWe5P2VavqMv66H7tItcZj
LMIFF2kV8rYF94tk6/jL/Hb7gG7ujG2p5ikG+sNhrzn0TsWdh97S6F9kTC5D1IkM
TXEhedXN1CQ4Z35TvIHxU1DBiax8t8mq/lF3rwIDAQABAoIBAQCvR8SaYWOF41jd
8MdTk7uPtDVRWB9auSHbHC5PllQvR3TBqk8r7V+iF+rwCHSJPgE5ZV0lfE+ORLFm
DrDAdEjgUwhlK71qNLdqHE50H3VIFCLSH8aAuH+wymwFtkYQvhKH5yxksyy3T9EQ
/3lbsnEWd7o6qEa6c0+c27WzuI4UCEdQpeSG+5UYHykC/Rdfc25wXTjeK8QSUcw4
Xlbt1O7omKAdrbSwbTValfqoUpKlAZ55nvJGqHnBWE5cvx9UHPooGWMUpq8004xb
sU42q2mDSEkRNE+irvc1FInxJ+gDk51Qem1r4Uy4pUnzyngXBFrp2XQazE/aVZSr
JG9fxfmBAoGBAN66SwUJg5LsRBFlPZTXzTLTzwXqm8e9ipKfe9dkuXX5Mx9mEbTd
mjZL1pHX0+YZAQu2V6dekvABFwEOnlvm0l0TopR1yyzA7PZK5ZUF0Tb9binLobO1
8G01Cp2jmrlarRGbwRdr9YXQ4ZKbvKUMevzYMIvPUFIkKQxHY/+x2IkRAoGBANS5
gDHwJ/voZTqqcJpn916MwhjsQvOmlNdDzqKe5FYd/DKb1X+tDAXK/sAFMOMj5Row
qCWu5G1T4f7DRY/BDXEU4u6YqcdokXTeZ45Z+fAZotcSit50T9gGoCTx8MMdeTUb
y4uY6cvCnd6x5PYOoBRL9QQX/ML7LX0S1Q2xL/S/AoGAfOQ/nuJ32hIMFSkNAAKG
eOLWan3kvnslUhSF8AD2EhYbuZaVhTLh/2JFPmCk3JjWwkeMHTjl8hjaWmhlGilz
emfBObhXpo/EEFNtK0QozcoMVPlvggMaf1JH0p9j6l3TQFVzT/vkoBXB92DGxlIa
QN/FURB9/KF0NwNtKnsCbdECgYARgUZUVa/koeYaosXrXtzTUf/y7xY/WJjs8e6C
IVMm5wbG3139SK8xltfJ02OHfX+v3QspNrAjcwCo50bFIpzJjm9yNOvbtfYqSNb6
ttrDcEifLC5zSdz8KOdqwuIOHFHKFgR081th4hz9o2P0/5UatnluIc8x+Ftw7GjN
3KPWnwKBgQCrt3Zs5eqDvFRmuB6d1uMFhAPqjrxnvdl3xhONnIopM4A62FLW4AoI
jpIg9K5YWK3nrROMWINH286CewjHXu2fhkhk1VPKo6Mz8bTqUoFZkI8cap/wfyqv
BMb5TNmgx+tp12pH2VNc/kC5c+GKi8VnNYx8K6gRzpZIIDfSUR10RQ==
-----END RSA PRIVATE KEY-----"""

private_key_pem = os.getenv("RSA_PRIVATE_KEY", _EMBEDDED_RSA_PRIVATE_KEY)
if "RSA_PRIVATE_KEY" not in os.environ:
    logger.warning(
        "RSA_PRIVATE_KEY is not set; falling back to the RSA private key "
        "embedded in bizydraft.oss_utils, which is public. Set RSA_PRIVATE_KEY."
    )
+ )
49
+
50
+
51
class TokenExpiredError(Exception):
    """Signals that a decrypted token's ``timestamp`` is older than its ``expiresIn`` window."""
55
+
56
+
57
def decrypt(encrypted_message):
    """Decrypt an encrypted token string and return its ``data`` payload.

    The input is base64-decoded, RSA-decrypted with ``private_key_pem``
    (PKCS#1 v1.5 padding) and parsed as JSON. The JSON must contain
    ``timestamp`` and ``expiresIn``, compared in milliseconds against the
    current time, plus ``data``.

    Args:
        encrypted_message: Base64 ciphertext as a non-empty ``str``.

    Returns:
        The ``data`` field, or ``None`` if anything fails: invalid input, bad
        key/ciphertext/JSON, missing fields, or an expired token. The failure
        is logged; no exception escapes.
    """
    try:
        if not encrypted_message or not isinstance(encrypted_message, str):
            raise ValueError("无效的加密消息")  # "invalid encrypted message"

        private_key = serialization.load_pem_private_key(
            private_key_pem.encode(), password=None, backend=default_backend()
        )

        encrypted_bytes = base64.b64decode(encrypted_message)
        decrypted_bytes = private_key.decrypt(encrypted_bytes, padding.PKCS1v15())
        parsed_data = json.loads(decrypted_bytes.decode("utf-8"))

        # Milliseconds, presumably to match a JavaScript Date.now() issuer.
        now = int(time() * 1000)
        if now - parsed_data["timestamp"] > parsed_data["expiresIn"]:
            raise TokenExpiredError("Token已过期")  # "token has expired"

        return parsed_data["data"]

    except Exception as error:
        # Preview only non-empty str input. Slicing e.g. an int here would
        # raise a second exception out of this handler.
        preview = (
            encrypted_message[:100] + "..."
            if isinstance(encrypted_message, str) and encrypted_message
            else None
        )
        # loguru substitutes positional args into "{}" placeholders; without
        # one the details dict would be silently dropped from the log.
        logger.error("解密失败: {}", {"message": str(error), "input": preview})
        return None
87
+
88
+
89
async def get_upload_token(
    filename: str,
    api_key: str,
) -> Dict[str, Any]:
    """Request temporary OSS upload credentials for *filename*.

    This calls ``GET {BIZYDRAFT_SERVER}/upload/token`` with ``file_type=inputs``.

    Returns:
        The decoded JSON body. ``upload_filefield_to_oss`` reads
        ``data.file`` and ``data.storage`` from it.

    Raises:
        aiohttp.ClientResponseError: the server answered with a 4xx/5xx status.
    """
    url = f"{BIZYDRAFT_SERVER}/upload/token"
    # Passing the query via ``params`` lets aiohttp percent-encode the file
    # name. Raw interpolation breaks on "&", "#", "+", etc.
    params = {"file_name": filename, "file_type": "inputs"}

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers, params=params) as response:
            # Raise on errors first, so every non-error status returns the
            # body instead of an implicit None.
            response.raise_for_status()
            return await response.json()
106
+
107
+
108
async def upload_filefield_to_oss(file_field, token_data):
    """Upload a multipart file field to OSS with STS credentials from *token_data*.

    Args:
        file_field: Uploaded form field; ``.file``, ``.filename`` and
            ``.content_type`` are read.
        token_data: Result of ``get_upload_token``. ``data.file`` supplies the
            STS credentials and object key; ``data.storage`` supplies the
            endpoint and bucket.

    Returns:
        ``{"status": 200, "url": ...}`` on success. Otherwise it returns
        ``{"status": <code>, "reason": <text>}``. Exceptions, including a
        malformed *token_data*, are reported as status 500 rather than raised.
    """
    import asyncio
    import functools

    try:
        file_info = token_data["data"]["file"]
        storage_info = token_data["data"]["storage"]

        auth = oss2.StsAuth(
            file_info["access_key_id"],
            file_info["access_key_secret"],
            file_info["security_token"],
        )
        # NOTE(review): the upload goes to a plain http:// endpoint while the
        # URL returned below is https://; confirm whether https can be used.
        bucket = oss2.Bucket(
            auth, f"http://{storage_info['endpoint']}", storage_info["bucket"]
        )

        put_object = functools.partial(
            bucket.put_object,
            file_info["object_key"],  # OSS object path
            file_field.file,  # stream the uploaded file object directly
            headers={
                "Content-Type": file_field.content_type,  # keep the original MIME type
                "Content-Disposition": f"attachment; filename={secure_filename(file_field.filename)}",
            },
        )
        # put_object is a synchronous call. Running it in the default executor
        # keeps a large upload from blocking the aiohttp event loop.
        result = await asyncio.get_running_loop().run_in_executor(None, put_object)

        if result.status == 200:
            return {
                "status": result.status,
                "url": f"https://{storage_info['bucket']}.{storage_info['endpoint']}/{file_info['object_key']}",
            }
        return {
            "status": result.status,
            "reason": f"OSS返回状态码: {result.status}",
        }

    except Exception as e:
        return {"status": 500, "reason": str(e)}
144
+
145
+
146
async def commit_file(object_key: str, filename: str, api_key: str):
    """Register an uploaded OSS object as an input resource.

    POSTs ``{"object_key": ..., "name": ...}`` to
    ``{BIZYDRAFT_SERVER}/input_resource/commit``. Returns the decoded JSON on
    HTTP 200 and raises ``ClientResponseError`` on 4xx/5xx. Any other status
    yields ``None``.
    """
    endpoint = f"{BIZYDRAFT_SERVER}/input_resource/commit"
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    body = {"object_key": object_key, "name": filename}

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=request_headers, json=body) as resp:
            if resp.status != 200:
                resp.raise_for_status()
                return None
            return await resp.json()
162
+
163
+
164
async def upload_to_oss(post, api_key: str):
    """Store an image posted in ComfyUI's upload format on OSS.

    Args:
        post: Parsed multipart form from ``request.post()``. Reads ``image``,
            ``overwrite``, ``type`` and ``subfolder``.
        api_key: API key used for the token and commit requests.

    Returns:
        One of the following aiohttp responses:

        - 400 if no usable file was posted.
        - 500 if the upload token or the commit fails.
        - The OSS status/reason if the upload fails.
        - Otherwise JSON ``{"name": <OSS URL>, "subfolder", "type"}``.
    """
    image = post.get("image")
    overwrite = post.get("overwrite")
    image_upload_type = post.get("type")
    subfolder = post.get("subfolder", "")
    logger.debug(f"{image=}, {overwrite=}, {image_upload_type=}, {subfolder=}")

    # A plain text field named "image" has no ``.file``. Treat it as a missing
    # file (400) instead of raising AttributeError.
    if not (image and getattr(image, "file", None)):
        return web.Response(status=400)

    filename = image.filename
    if not filename:
        return web.Response(status=400)

    try:
        oss_token = await get_upload_token(filename, api_key)
    except Exception as e:
        # Reported the same way as a commit failure below.
        logger.error(f"Get upload token failed: {e}")
        return web.Response(status=500, text=str(e))

    result = await upload_filefield_to_oss(image, oss_token)
    if result["status"] != 200:
        return web.Response(status=result["status"], text=result.get("reason", ""))
    logger.debug(f"upload file: {result['url']}")

    try:
        object_key = oss_token["data"]["file"]["object_key"]
        await commit_file(object_key, filename, api_key)
        logger.debug(f"sucess: commit {filename=}")
    except Exception as e:
        logger.error(f"Commit file failed: {e}")
        return web.Response(status=500, text=str(e))

    return web.json_response(
        {"name": result["url"], "subfolder": subfolder, "type": image_upload_type}
    )
195
+
196
+
197
def get_api_key(request):
    """Resolve the API key for *request*.

    ``BIZYAIR_API_KEY`` from the environment takes precedence. Otherwise the
    ``bizy_token`` cookie is decrypted with ``decrypt``.

    Returns:
        The key, or ``None`` if there is no cookie or no ``bizy_token``, or if
        parsing or decryption fails.
    """
    if BIZYAIR_API_KEY:
        return BIZYAIR_API_KEY

    cookies = request.headers.get("Cookie")
    if not cookies:
        return None

    try:
        cookie = SimpleCookie()
        cookie.load(cookies)

        morsel = cookie.get("bizy_token")
        if morsel is None:
            # A request without the token cookie is ordinary; returning here
            # avoids decrypt(None) logging it as a decryption failure.
            return None

        return decrypt(morsel.value) or None

    except Exception as e:
        logger.error(f"error happens when get_api_key from cookies: {e}")
        return None
219
+
220
+
221
async def upload_image(request):
    """aiohttp handler for the image-upload routes.

    Responds 403 when no API key can be resolved. Otherwise it parses the
    multipart body and delegates to ``upload_to_oss``.
    """
    logger.debug(f"Received request to upload image: {request.path}")
    api_key = get_api_key(request)
    if api_key:
        form = await request.post()
        return await upload_to_oss(form, api_key)
    return web.Response(status=403, text="No validated key found")