ima-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ima_python_sdk-0.1.0.dist-info/METADATA +93 -0
- ima_python_sdk-0.1.0.dist-info/RECORD +15 -0
- ima_python_sdk-0.1.0.dist-info/WHEEL +5 -0
- ima_python_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- ima_python_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- ima_python_sdk-0.1.0.dist-info/top_level.txt +1 -0
- ima_sdk/__init__.py +90 -0
- ima_sdk/cli.py +447 -0
- ima_sdk/client.py +268 -0
- ima_sdk/cos_uploader.py +158 -0
- ima_sdk/file_checker.py +226 -0
- ima_sdk/knowledge_base.py +591 -0
- ima_sdk/logger.py +89 -0
- ima_sdk/notes.py +343 -0
- ima_sdk/types.py +424 -0
ima_sdk/client.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""IMA OpenAPI 核心客户端。
|
|
2
|
+
|
|
3
|
+
提供 HTTP 请求封装、认证管理和统一错误处理。
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import shlex
|
|
11
|
+
import ssl
|
|
12
|
+
import sys
|
|
13
|
+
import time
|
|
14
|
+
import urllib.error
|
|
15
|
+
import urllib.request
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Callable, Dict, Generator, List, Optional
|
|
18
|
+
|
|
19
|
+
from . import logger
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ImaApiError(Exception):
    """Raised when an IMA API call fails.

    Attributes:
        retcode: Error code supplied by the raiser (an HTTP status or an
            API-level return code).
        errmsg: Error message text.
        data: Optional extra payload attached to the error.
    """

    def __init__(self, retcode: int, errmsg: str, data: Any = None):
        # The exception text is always rendered as "[<retcode>] <errmsg>".
        message = "[{}] {}".format(retcode, errmsg)
        super().__init__(message)
        self.retcode, self.errmsg, self.data = retcode, errmsg, data
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ImaClient:
    """Core client for the IMA OpenAPI.

    Wraps HTTP POST requests, credential resolution and uniform error
    handling.

    Credentials are resolved in this order:
        1. Constructor arguments
        2. Environment variables IMA_OPENAPI_CLIENTID / IMA_OPENAPI_APIKEY
        3. Config files ~/.config/ima/client_id / ~/.config/ima/api_key

    Usage::

        client = ImaClient()
        data = client.post("openapi/wiki/v1/get_knowledge_base", {"ids": ["xxx"]})
    """

    BASE_URL = "https://ima.qq.com"

    # Connection-level retry settings (for transient SSL connection drops).
    _CONNECT_RETRIES = 3
    _CONNECT_RETRY_DELAY = 1  # seconds

    # Per-request timeout in seconds; None keeps urllib's default behaviour.
    timeout: Optional[float] = None

    def __init__(
        self,
        client_id: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        *,
        curl_mode: bool = False,
        timeout: Optional[float] = None,
    ):
        """Create a client.

        Args:
            client_id: Client id; resolved from env/config file when omitted.
            api_key: API key; resolved from env/config file when omitted.
            base_url: Override for ``BASE_URL``; a trailing slash is removed.
            curl_mode: When true, every request is also echoed to stderr as
                an equivalent ``curl`` command.
            timeout: Optional timeout in seconds passed to ``urlopen``.
                ``None`` (the default) leaves the timeout unset.

        Raises:
            ValueError: If either credential cannot be resolved.
        """
        self.base_url = (base_url or self.BASE_URL).rstrip("/")
        self.curl_mode = curl_mode
        self.timeout = timeout
        self._client_id = client_id or self._resolve_credential("client_id")
        self._api_key = api_key or self._resolve_credential("api_key")

        if not self._client_id or not self._api_key:
            raise ValueError(
                "缺少 IMA 凭证。请通过以下方式之一提供:\n"
                " 1. ImaClient(client_id=..., api_key=...)\n"
                " 2. 环境变量 IMA_OPENAPI_CLIENTID / IMA_OPENAPI_APIKEY\n"
                " 3. 配置文件 ~/.config/ima/client_id / ~/.config/ima/api_key"
            )

    @staticmethod
    def _resolve_credential(name: str) -> Optional[str]:
        """Resolve one credential: environment variable first, then config file.

        Args:
            name: ``"client_id"`` or ``"api_key"``.

        Returns:
            The stripped credential, or ``None`` when neither source yields a
            non-empty value.
        """
        # client_id -> IMA_OPENAPI_CLIENTID, api_key -> IMA_OPENAPI_APIKEY
        env_key = f"IMA_OPENAPI_{name.upper()}"
        env_key = env_key.replace("CLIENT_ID", "CLIENTID").replace("API_KEY", "APIKEY")

        # Strip before testing so a whitespace-only variable counts as unset
        # and the config file still gets a chance.
        value = os.environ.get(env_key, "").strip()
        if value:
            return value

        config_path = Path.home() / ".config" / "ima" / name
        if config_path.is_file():
            return config_path.read_text(encoding="utf-8").strip() or None

        return None

    def _build_headers(self) -> Dict[str, str]:
        """Return the authentication and content-type headers for a request."""
        return {
            "ima-openapi-clientid": self._client_id,
            "ima-openapi-apikey": self._api_key,
            "Content-Type": "application/json",
        }

    def _print_curl(self, method: str, url: str, headers: Dict[str, str], payload: bytes) -> None:
        """Write an equivalent ``curl`` command to stderr.

        NOTE: the headers are printed verbatim, so the output contains the
        client id and API key in clear text.
        """
        parts = ["curl", "-X", method, shlex.quote(url)]
        for k, v in headers.items():
            parts.append(f"-H {shlex.quote(f'{k}: {v}')}")
        if payload:
            parts.append(f"-d {shlex.quote(payload.decode('utf-8'))}")
        print(" \\\n ".join(parts), file=sys.stderr)

    def post(self, path: str, body: Optional[dict] = None) -> dict:
        """Send a POST request and return the decoded JSON response.

        Three response shapes are accepted:
            1. ``{"retcode": 0, "errmsg": "ok", "data": {...}}``
            2. ``{"code": 0, "msg": "success", "data": {...}}``
            3. A bare payload without a code wrapper, returned as is.

        For the wrapped shapes the ``data`` member is returned (an empty dict
        when it is missing or null). A request that fails with an
        ``ssl.SSLEOFError`` is retried up to ``_CONNECT_RETRIES`` attempts.

        Args:
            path: API path, e.g. ``openapi/wiki/v1/get_knowledge_base``.
            body: Request body dictionary; an empty object is sent when omitted.

        Returns:
            The response payload as described above.

        Raises:
            ImaApiError: HTTP error status, or a non-zero ``retcode``/``code``.
            urllib.error.URLError: Network failure.
            json.JSONDecodeError: The response body is not valid JSON.
        """
        url = f"{self.base_url}/{path.lstrip('/')}"
        req_body = body or {}
        payload = json.dumps(req_body, ensure_ascii=False).encode("utf-8")
        logger.log_request("POST", url, req_body)
        if self.curl_mode:
            self._print_curl("POST", url, self._build_headers(), payload)

        # Only pass ``timeout`` when one is configured, so an unset value keeps
        # whatever default urlopen would otherwise apply.
        open_kwargs: Dict[str, Any] = {}
        if self.timeout is not None:
            open_kwargs["timeout"] = self.timeout

        raw: str = ""
        status: int = 0
        for attempt in range(self._CONNECT_RETRIES):
            req = urllib.request.Request(
                url,
                data=payload,
                headers=self._build_headers(),
                method="POST",
            )
            try:
                with urllib.request.urlopen(req, **open_kwargs) as resp:
                    raw = resp.read().decode("utf-8")
                    status = resp.status
                break
            except urllib.error.HTTPError as e:
                raw = e.read().decode("utf-8", errors="replace")
                logger.log_error(f"HTTP {e.code}: {raw}")
                raise ImaApiError(
                    retcode=e.code,
                    errmsg=raw,
                ) from e
            except urllib.error.URLError as e:
                if isinstance(e.reason, ssl.SSLEOFError) and attempt < self._CONNECT_RETRIES - 1:
                    logger.log(f" SSL 连接中断,{self._CONNECT_RETRY_DELAY}s 后重试 ({attempt + 1}/{self._CONNECT_RETRIES})")
                    time.sleep(self._CONNECT_RETRY_DELAY)
                    continue
                raise

        result = json.loads(raw)

        # A non-object JSON body has no wrapper to inspect; hand it back as is.
        if not isinstance(result, dict):
            logger.log_response(status, f"type={type(result).__name__}")
            return result

        # "retcode"/"errmsg" takes precedence over "code"/"msg".
        for code_key, msg_key in (("retcode", "errmsg"), ("code", "msg")):
            if code_key not in result:
                continue
            code = result[code_key]
            if code != 0:
                logger.log_error(f"API 错误 [{code}] {result.get(msg_key, '')}: {raw}")
                raise ImaApiError(
                    retcode=code,
                    errmsg=result.get(msg_key, "未知错误"),
                    data=result.get("data"),
                )
            data = result.get("data")
            if data is None:
                # Missing or explicit null: normalise to an empty dict.
                data = {}
            keys = list(data.keys()) if isinstance(data, dict) else type(data).__name__
            logger.log_response(status, f"{code_key}=0, data keys={keys}")
            return data

        logger.log_response(status, f"keys={list(result.keys())}")
        return result

    def post_raw(self, path: str, body: Optional[dict] = None) -> dict:
        """Alias of :meth:`post`.

        This currently delegates straight to ``post`` and therefore has the
        same return value and error handling (HTTP errors are raised as
        ``ImaApiError``).
        """
        return self.post(path, body)

    def paginate(
        self,
        path: str,
        body: dict,
        items_key: str,
        *,
        cursor_field: str = "cursor",
        next_cursor_field: str = "next_cursor",
        is_end_field: str = "is_end",
        limit: int = 50,
        limit_field: str = "limit",
        initial_cursor: str = "",
    ) -> Generator[dict, None, None]:
        """Iterate over a cursor-paginated endpoint, yielding one page at a time.

        Iteration stops after a page whose ``is_end_field`` is truthy or
        missing, or whose ``next_cursor_field`` is empty or missing.

        Args:
            path: API path.
            body: Base request parameters (without cursor/limit).
            items_key: Name of the list field inside each page.
            cursor_field: Request field that carries the cursor.
            next_cursor_field: Response field that carries the next cursor.
            is_end_field: Response field that flags the last page.
            limit: Page size.
            limit_field: Request field that carries the page size.
            initial_cursor: Cursor used for the first request.

        Yields:
            The payload returned by :meth:`post` for each page.
        """
        cursor = initial_cursor
        page = 0
        while True:
            page += 1
            req_body = {**body, cursor_field: cursor, limit_field: limit}
            logger.log_step(f"翻页请求 (第 {page} 页)", f"cursor={cursor!r}")
            data = self.post(path, req_body)
            # ``or []`` also covers an items field that is present but null.
            items_count = len(data.get(items_key) or [])
            logger.log(f" 本页获取 {items_count} 条 {items_key}")
            yield data

            if data.get(is_end_field, True):
                logger.log(f" 已到最后一页 (共 {page} 页)")
                break
            next_cursor = data.get(next_cursor_field, "")
            if not next_cursor:
                logger.log(f" 无下一页游标,结束翻页 (共 {page} 页)")
                break
            cursor = next_cursor

    def paginate_items(
        self,
        path: str,
        body: dict,
        items_key: str,
        item_factory: Callable[[dict], Any],
        **kwargs: Any,
    ) -> List[Any]:
        """Fetch every page and collect all converted items into one list.

        Args:
            path: API path.
            body: Base request parameters.
            items_key: Name of the list field inside each page.
            item_factory: Callable that converts one item dict into an object.
            **kwargs: Extra keyword arguments forwarded to :meth:`paginate`.

        Returns:
            The converted items of all pages, in page order.
        """
        results: List[Any] = []
        for data in self.paginate(path, body, items_key, **kwargs):
            results.extend(item_factory(item) for item in (data.get(items_key) or []))
        return results
|
ima_sdk/cos_uploader.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""腾讯云 COS 文件上传模块。
|
|
2
|
+
|
|
3
|
+
纯 Python 实现 COS PUT Object 操作,使用 HMAC-SHA1 签名。
|
|
4
|
+
参考: https://cloud.tencent.com/document/product/436/7778
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import hmac
|
|
11
|
+
import math
|
|
12
|
+
import time
|
|
13
|
+
import urllib.request
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
from .types import CosCredential
|
|
17
|
+
from . import logger
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _hmac_sha1(key: str, data: str) -> str:
|
|
21
|
+
"""HMAC-SHA1 签名,返回十六进制字符串。"""
|
|
22
|
+
return hmac.new(
|
|
23
|
+
key.encode("utf-8"), data.encode("utf-8"), hashlib.sha1
|
|
24
|
+
).hexdigest()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _sha1(data: str) -> str:
|
|
28
|
+
"""SHA1 哈希,返回十六进制字符串。"""
|
|
29
|
+
return hashlib.sha1(data.encode("utf-8")).hexdigest()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _build_authorization(
    secret_id: str,
    secret_key: str,
    method: str,
    pathname: str,
    headers: dict[str, str],
    start_time: int,
    expired_time: int,
) -> str:
    """Build the value of the COS ``Authorization`` header.

    Steps:
        1. SignKey      = HMAC-SHA1(SecretKey, KeyTime)
        2. HttpString   = method LF pathname LF params LF headers LF
        3. StringToSign = "sha1" LF KeyTime LF SHA1(HttpString) LF
        4. Signature    = HMAC-SHA1(SignKey, StringToSign)

    Args:
        secret_id: Access key id, emitted as ``q-ak``.
        secret_key: Secret key used to derive the signing key.
        method: HTTP method; lower-cased before signing.
        pathname: Request path, signed as given.
        headers: Headers taking part in the signature. Names are lower-cased
            and values are percent-encoded.
        start_time: Start of the validity window (epoch seconds).
        expired_time: End of the validity window (epoch seconds).

    Returns:
        The ``q-sign-algorithm=...&q-signature=...`` authorization string.
        The URL parameter list is always empty.
    """
    from urllib.parse import quote

    key_time = f"{start_time};{expired_time}"

    # 1. SignKey
    sign_key = _hmac_sha1(secret_key, key_time)

    # 2. HttpString -- no query parameters are signed.
    # Sort on the lower-cased name: that is the form that ends up in both the
    # signed string and q-header-list, so ordering by the raw name would
    # disagree with it whenever a caller passes mixed-case header names.
    signed_headers = sorted(
        ((name.lower(), value) for name, value in headers.items()),
        key=lambda pair: pair[0],
    )
    http_headers = "&".join(
        f"{name}={quote(value, safe='')}" for name, value in signed_headers
    )
    http_string = f"{method.lower()}\n{pathname}\n\n{http_headers}\n"

    # 3. StringToSign
    string_to_sign = f"sha1\n{key_time}\n{_sha1(http_string)}\n"

    # 4. Signature
    signature = _hmac_sha1(sign_key, string_to_sign)

    # 5. Assemble the Authorization value
    header_list = ";".join(name for name, _ in signed_headers)
    return (
        f"q-sign-algorithm=sha1"
        f"&q-ak={secret_id}"
        f"&q-sign-time={key_time}"
        f"&q-key-time={key_time}"
        f"&q-header-list={header_list}"
        f"&q-url-param-list="
        f"&q-signature={signature}"
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def cos_upload(
    file_path: str,
    credential: CosCredential,
    content_type: str = "application/octet-stream",
    start_time: Optional[int] = None,
    expired_time: Optional[int] = None,
) -> None:
    """Upload a local file to Tencent Cloud COS with a signed PUT request.

    The whole file is read into memory before it is sent.

    Args:
        file_path: Path of the local file.
        credential: COS upload credential; supplies bucket, region, object
            key, secret id/key, security token and optional sign times.
        content_type: MIME type sent as ``Content-Type``.
        start_time: Signature start time (epoch seconds). Falls back to
            ``credential.start_time`` and then to the current time.
        expired_time: Signature expiry time (epoch seconds). Falls back to
            ``credential.expired_time`` and then to the current time + 1 hour.

    Raises:
        RuntimeError: The server answered with a non-2xx status.
        FileNotFoundError: ``file_path`` does not exist.
    """
    with open(file_path, "rb") as stream:
        file_content = stream.read()
    content_length = str(len(file_content))

    logger.log_step("COS 上传", f"file={file_path} size={len(file_content)} bytes")

    # Resolve the signing window: explicit argument > credential > default.
    now = math.floor(time.time())
    if start_time is None:
        start_time = credential.start_time or now
    if expired_time is None:
        expired_time = credential.expired_time or (now + 3600)

    hostname = f"{credential.bucket_name}.cos.{credential.region}.myqcloud.com"
    pathname = f"/{credential.cos_key}"

    # Only content-length and host take part in the signature.
    authorization = _build_authorization(
        secret_id=credential.secret_id,
        secret_key=credential.secret_key,
        method="PUT",
        pathname=pathname,
        headers={"content-length": content_length, "host": hostname},
        start_time=start_time,
        expired_time=expired_time,
    )

    url = f"https://{hostname}{pathname}"
    logger.log(f" PUT {url}")
    logger.log(f" content_type={content_type}")

    request_headers = {
        "Content-Type": content_type,
        "Content-Length": content_length,
        "Authorization": authorization,
        "x-cos-security-token": credential.token,
    }
    req = urllib.request.Request(
        url, data=file_content, headers=request_headers, method="PUT"
    )

    try:
        with urllib.request.urlopen(req) as resp:
            status = resp.status
            if not 200 <= status < 300:
                body = resp.read().decode("utf-8", errors="replace")
                logger.log_error(f"COS 上传失败 (HTTP {status}): {body}")
                raise RuntimeError(f"COS 上传失败 (HTTP {status}): {body}")
            logger.log(f" 上传成功 (HTTP {status})")
    except urllib.error.HTTPError as e:
        # Error statuses raised by urllib are reported as RuntimeError too.
        body = e.read().decode("utf-8", errors="replace")
        logger.log_error(f"COS 上传失败 (HTTP {e.code}): {body}")
        raise RuntimeError(f"COS 上传失败 (HTTP {e.code}): {body}") from e
|
ima_sdk/file_checker.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""文件预检模块。
|
|
2
|
+
|
|
3
|
+
在上传文件到知识库前,校验文件类型和大小是否符合要求。
|
|
4
|
+
完整移植自 preflight-check.cjs 的逻辑。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
from .types import FileCheckResult, MediaType
|
|
13
|
+
from . import logger
|
|
14
|
+
|
|
15
|
+
# ── File extension → media_type + content_type ──────────────────────────────

EXT_MAP: dict[str, dict] = {
    "pdf": {
        "media_type": MediaType.PDF,
        "content_type": "application/pdf",
    },
    "doc": {
        "media_type": MediaType.WORD,
        "content_type": "application/msword",
    },
    "docx": {
        "media_type": MediaType.WORD,
        "content_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    },
    "ppt": {
        "media_type": MediaType.PPT,
        "content_type": "application/vnd.ms-powerpoint",
    },
    "pptx": {
        "media_type": MediaType.PPT,
        "content_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    },
    "xls": {
        "media_type": MediaType.EXCEL,
        "content_type": "application/vnd.ms-excel",
    },
    "xlsx": {
        "media_type": MediaType.EXCEL,
        "content_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    },
    "csv": {
        "media_type": MediaType.EXCEL,
        "content_type": "text/csv",
    },
    "md": {
        "media_type": MediaType.MARKDOWN,
        "content_type": "text/markdown",
    },
    "markdown": {
        "media_type": MediaType.MARKDOWN,
        "content_type": "text/markdown",
    },
    "png": {
        "media_type": MediaType.IMAGE,
        "content_type": "image/png",
    },
    "jpg": {
        "media_type": MediaType.IMAGE,
        "content_type": "image/jpeg",
    },
    "jpeg": {
        "media_type": MediaType.IMAGE,
        "content_type": "image/jpeg",
    },
    "webp": {
        "media_type": MediaType.IMAGE,
        "content_type": "image/webp",
    },
    "txt": {
        "media_type": MediaType.TXT,
        "content_type": "text/plain",
    },
    "xmind": {
        "media_type": MediaType.XMIND,
        "content_type": "application/x-xmind",
    },
    "mp3": {
        "media_type": MediaType.AUDIO,
        "content_type": "audio/mpeg",
    },
    "m4a": {
        "media_type": MediaType.AUDIO,
        "content_type": "audio/x-m4a",
    },
    "wav": {
        "media_type": MediaType.AUDIO,
        "content_type": "audio/wav",
    },
    "aac": {
        "media_type": MediaType.AUDIO,
        "content_type": "audio/aac",
    },
}

# ── content_type → media_type (reverse lookup) ──────────────────────────────

# Derived from EXT_MAP; when several extensions share a content type the
# first one encountered wins.
CONTENT_TYPE_MAP: dict[str, int] = {}
for _info in EXT_MAP.values():
    ct = _info["content_type"]
    CONTENT_TYPE_MAP.setdefault(ct, _info["media_type"])

# Additional content_type aliases (these overwrite any derived entry).
CONTENT_TYPE_MAP["text/x-markdown"] = MediaType.MARKDOWN
CONTENT_TYPE_MAP["application/md"] = MediaType.MARKDOWN
CONTENT_TYPE_MAP["application/markdown"] = MediaType.MARKDOWN
CONTENT_TYPE_MAP["application/vnd.xmind.workbook"] = MediaType.XMIND
CONTENT_TYPE_MAP["application/zip"] = MediaType.XMIND  # an xmind file may be reported as zip

# ── File size limits (bytes) ─────────────────────────────────────────────────

MB = 1024 * 1024

# Per-media-type limits; types not listed here use DEFAULT_SIZE_LIMIT.
SIZE_LIMITS: dict[int, int] = {
    MediaType.EXCEL: 10 * MB,
    MediaType.MARKDOWN: 10 * MB,
    MediaType.TXT: 10 * MB,
    MediaType.XMIND: 10 * MB,
    MediaType.IMAGE: 30 * MB,
}
DEFAULT_SIZE_LIMIT = 200 * MB

# ── Unsupported video types ──────────────────────────────────────────────────

UNSUPPORTED_VIDEO_EXTS = {
    "mp4",
    "avi",
    "mov",
    "mkv",
    "wmv",
    "flv",
    "webm",
    "m4v",
    "rmvb",
    "rm",
    "3gp",
}

UNSUPPORTED_VIDEO_CONTENT_TYPES = {
    "video/mp4",
    "video/x-msvideo",
    "video/quicktime",
    "video/x-matroska",
    "video/x-ms-wmv",
    "video/x-flv",
    "video/webm",
}
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _format_size(size_bytes: int) -> str:
    """Render a byte count with one decimal place, in KB below 1 MB, else in MB."""
    unit, divisor = ("KB", 1024) if size_bytes < MB else ("MB", MB)
    return f"{size_bytes / divisor:.1f} {unit}"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def check_file(
    file_path: str, content_type: Optional[str] = None
) -> FileCheckResult:
    """Run the pre-upload checks on a file.

    Verifies that the file exists, that its type is supported and that it
    does not exceed the size limit for that type. Every rejection is
    recorded in the result's ``reason`` and reported via
    ``logger.log_error``.

    Type resolution order:
        1. ``content_type`` is recognised -> use it.
        2. ``content_type`` is given but not recognised -> fall back to the
           file extension.
        3. No ``content_type`` -> use the file extension.
        4. Neither can be recognised -> the check fails.

    Args:
        file_path: Path of the file; it is converted to an absolute path.
        content_type: Optional MIME type.

    Returns:
        A ``FileCheckResult``. ``passed=True`` means every check succeeded;
        otherwise ``reason`` explains the rejection.
    """
    abs_path = os.path.abspath(file_path)
    file_name = os.path.basename(abs_path)
    logger.log_step("文件预检", f"{file_name}")

    # Extension without the dot, lower-cased ("" when there is none).
    ext = os.path.splitext(file_name)[1].lstrip(".").lower()

    base_result = FileCheckResult(
        file_path=abs_path, file_name=file_name, file_ext=ext
    )

    def reject(reason: str) -> FileCheckResult:
        # Single exit for failures, so each one is logged the same way.
        base_result.reason = reason
        logger.log_error(reason)
        return base_result

    # 1. The file must exist.
    if not os.path.isfile(abs_path):
        return reject(f"文件不存在: {abs_path}")

    # 2. Video files are rejected, by extension or by content type.
    if ext in UNSUPPORTED_VIDEO_EXTS:
        return reject(
            f"不支持视频文件 (.{ext}),视频仅支持在 IMA 桌面端添加。"
        )
    if content_type and content_type in UNSUPPORTED_VIDEO_CONTENT_TYPES:
        return reject(
            f"不支持视频文件 ({content_type}),视频仅支持在 IMA 桌面端添加。"
        )

    # 3. Resolve media_type and content_type.
    media_type: Optional[int] = None
    resolved_ct: Optional[str] = None

    ct_media_type = CONTENT_TYPE_MAP.get(content_type) if content_type else None
    ext_mapping = EXT_MAP.get(ext) if ext else None

    if ct_media_type is not None:
        # A recognised content_type wins over the extension.
        media_type = ct_media_type
        resolved_ct = content_type
    elif ext_mapping:
        # content_type missing or unrecognised: fall back to the extension.
        media_type = ext_mapping["media_type"]
        resolved_ct = ext_mapping["content_type"]
    elif content_type:
        ext_hint = f"和文件扩展名 .{ext}" if ext else ""
        return reject(
            f"无法识别的 content_type: {content_type}{ext_hint},"
            "该文件类型不受支持。"
        )
    elif ext:
        return reject(f"无法识别的文件扩展名 .{ext},该文件类型不受支持。")
    else:
        return reject(
            "文件没有扩展名且未提供 content_type,无法判断文件类型。"
        )

    # 4. Enforce the size limit for the resolved type.
    file_size = os.path.getsize(abs_path)
    size_limit = SIZE_LIMITS.get(media_type, DEFAULT_SIZE_LIMIT)

    if file_size > size_limit:
        # Keep the resolved details on the failed result.
        base_result.file_size = file_size
        base_result.media_type = media_type
        base_result.content_type = resolved_ct
        return reject(
            f"文件大小 {_format_size(file_size)} "
            f"超过此类型的限制 {_format_size(size_limit)}。"
        )

    # 5. All checks passed.
    logger.log(f" 预检通过: ext={ext} media_type={media_type} "
               f"content_type={resolved_ct} size={_format_size(file_size)}")
    return FileCheckResult(
        passed=True,
        file_path=abs_path,
        file_name=file_name,
        file_ext=ext,
        file_size=file_size,
        media_type=media_type,
        content_type=resolved_ct,
    )
|