gomyck-tools 1.5.3__py3-none-any.whl → 1.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctools/authcode_validator.py +429 -0
- ctools/cipher/sm_util.py +0 -1
- ctools/database/database.py +27 -9
- ctools/ml/image_process.py +121 -0
- ctools/ml/img_extractor.py +59 -0
- ctools/ml/ppi.py +276 -0
- ctools/stream/ckafka.py +2 -2
- ctools/sys_log.py +151 -77
- ctools/util/jb_cut.py +0 -1
- ctools/web/bottle_web_base.py +1 -0
- ctools/web/bottle_webserver.py +11 -7
- ctools/web/bottle_websocket.py +1 -1
- ctools/web/ctoken.py +1 -1
- {gomyck_tools-1.5.3.dist-info → gomyck_tools-1.5.5.dist-info}/METADATA +15 -12
- {gomyck_tools-1.5.3.dist-info → gomyck_tools-1.5.5.dist-info}/RECORD +18 -14
- {gomyck_tools-1.5.3.dist-info → gomyck_tools-1.5.5.dist-info}/WHEEL +1 -1
- {gomyck_tools-1.5.3.dist-info → gomyck_tools-1.5.5.dist-info}/licenses/LICENSE +0 -0
- {gomyck_tools-1.5.3.dist-info → gomyck_tools-1.5.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: UTF-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
授权码验证工具类
|
|
5
|
+
仅用于验证授权码的有效性
|
|
6
|
+
"""
|
|
7
|
+
__author__ = 'haoyang'
|
|
8
|
+
__date__ = '2026/1/26'
|
|
9
|
+
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
import json
|
|
13
|
+
import ipaddress
|
|
14
|
+
|
|
15
|
+
from ctools.cipher import sm_util
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuthCodeValidator:
    """Validator for authorization codes (授权码).

    An authorization code is a JSON object with ``version``, ``body`` and
    ``signature`` keys. :meth:`validate` checks the structure and the SM2
    signature over ``body``; each ``check_*`` method tests one restriction
    stored in ``body``; :meth:`validate_all` combines them.

    Every check fails closed: malformed input (invalid JSON, a top level or
    ``body`` that is not a JSON object) yields ``False`` instead of raising.
    """

    def __init__(self, public_key: Optional[str] = None):
        """Create a validator.

        Args:
            public_key: SM2 public key used for signature verification.

        Raises:
            Exception: if ``public_key`` is empty or ``None``.
        """
        if not public_key:
            raise Exception("未提供公钥,无法初始化验证器。请传入 public_key 参数。")
        self.public_key = public_key

    def validate(self, authcode_json: str) -> bool:
        """Check the structure and signature of an authorization code.

        Args:
            authcode_json: authorization code as a JSON string.

        Returns:
            bool: True only if the JSON is an object containing ``version``,
            ``body`` (an object) and ``signature``, and the signature verifies.
        """
        authcode_obj = self._load_json(authcode_json)
        # A non-object top level (e.g. a JSON string) would turn the ``in``
        # tests below into substring tests and then crash on ``.get``.
        if not isinstance(authcode_obj, dict):
            return False

        if not all(key in authcode_obj for key in ('version', 'body', 'signature')):
            return False

        if not isinstance(authcode_obj.get('body'), dict):
            return False

        return self._verify_signature(authcode_json)

    def _verify_signature(self, authcode_json: str) -> bool:
        """Verify the SM2 signature of an authorization code.

        Args:
            authcode_json: authorization code as a JSON string.

        Returns:
            bool: the result of ``sm_util.verify_with_sm2``, or False if
            anything along the way raises.

        Raises:
            Exception: if no public key is configured.
        """
        if not self.public_key:
            raise Exception("未初始化公钥,无法验证签名。请传入 public_key 参数。")

        try:
            authcode_obj = json.loads(authcode_json)
            body = authcode_obj.get('body', {})
            signature = authcode_obj.get('signature', '')
            version = authcode_obj.get('version', 'v1_1')

            # NOTE: sm_util keeps one module-level cipher, and its init()
            # returns immediately once that cipher exists. The key actually
            # used is therefore whichever key first initialised sm_util in
            # this process; validators with different public keys in one
            # process share that first key. The public key is passed for both
            # slots, presumably because only verification happens here.
            sm_util.init(self.public_key, self.public_key)

            final_val = self._build_sign_string(body, version)
            return sm_util.verify_with_sm2(signature, final_val)
        except Exception:
            # Verification is deterministic, so there is nothing to retry.
            # Any failure (bad key, malformed signature, unjoinable list value,
            # ...) means the signature is not valid.
            return False

    def _build_sign_string(self, body: dict, version: str) -> str:
        """Build the canonical string covered by the signature.

        This must match the signer exactly; the original docstring says it
        mirrors the signing logic in app.py. Keys are sorted, and each entry
        is rendered as ``key:value``, with list values joined by ``,``.
        Version ``'v1'`` concatenates entries directly. Every other version,
        including ``'v1_1'``, ends each entry with a newline.

        Args:
            body: the code's ``body`` object.
            version: the code's format version.

        Returns:
            str: the string to verify the signature against.
        """
        terminator = '' if version == 'v1' else '\n'
        entries = []
        for key, value in sorted(body.items()):
            # List items must be strings (join raises TypeError otherwise);
            # this is kept unchanged so the canonical form matches the signer.
            value_str = ",".join(value) if isinstance(value, list) else str(value)
            entries.append(key + ":" + value_str + terminator)
        return "".join(entries)

    def check_expired(self, authcode_json: str) -> bool:
        """Check that the code is inside its validity window.

        Despite the name, True means the code is NOT expired. The
        ``expired_time`` and ``effect_time`` values are compared against the
        naive local time from ``datetime.now()``.

        Args:
            authcode_json: authorization code as a JSON string.

        Returns:
            bool: True if the current time is not after ``expired_time`` and
            not before ``effect_time`` (each is checked only when present).
            False if either value is present but unparseable, or the input is
            malformed.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        now = datetime.now()

        if 'expired_time' in body:
            expired_dt = self._parse_datetime(body['expired_time'])
            if expired_dt is None or now > expired_dt:
                return False

        if 'effect_time' in body:
            effect_dt = self._parse_datetime(body['effect_time'])
            if effect_dt is None or now < effect_dt:
                return False

        return True

    def check_ip(self, authcode_json: str, client_ip: str) -> bool:
        """Check whether a client IP is inside the authorized ranges.

        ``body['ip_range']`` may be a comma-separated string or a list. Each
        entry is ``"*"`` (any IP), a CIDR network, or a single address.
        Addresses are compared as parsed IP objects, so different spellings
        of the same address (e.g. IPv6 ``::1``) match.

        Args:
            authcode_json: authorization code as a JSON string.
            client_ip: client IP address.

        Returns:
            bool: True if there is no ``ip_range`` restriction or the IP
            matches one entry. False otherwise, and also for an empty range,
            an invalid client IP, or malformed input.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        if 'ip_range' not in body:
            return True

        ip_list = self._as_list(body['ip_range'])
        if not ip_list:
            return False

        if '*' in ip_list:
            return True

        try:
            client_ip_obj = ipaddress.ip_address(client_ip)
        except (TypeError, ValueError):
            return False

        for ip_pattern in ip_list:
            if not isinstance(ip_pattern, str):
                continue
            ip_pattern = ip_pattern.strip()
            try:
                if '/' in ip_pattern:
                    # CIDR notation; host bits are tolerated (strict=False).
                    if client_ip_obj in ipaddress.ip_network(ip_pattern, strict=False):
                        return True
                elif client_ip_obj == ipaddress.ip_address(ip_pattern):
                    return True
            except ValueError:
                # Skip entries that are not valid addresses or networks.
                continue

        return False

    def check_machine_code(self, authcode_json: str, machine_code: str) -> bool:
        """Check whether a machine code is authorized.

        Args:
            authcode_json: authorization code as a JSON string.
            machine_code: machine code to look up.

        Returns:
            bool: True if there is no ``machine_codes`` restriction, the list
            contains ``"*"``, or it contains ``machine_code``. False otherwise.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        if 'machine_codes' not in body:
            return True

        code_list = self._as_list(body['machine_codes'])
        if not code_list:
            return False

        return '*' in code_list or machine_code in code_list

    def check_module(self, authcode_json: str, module_name: str) -> bool:
        """Check whether a module is authorized.

        Args:
            authcode_json: authorization code as a JSON string.
            module_name: module name to look up.

        Returns:
            bool: True if there is no ``modules`` restriction, the list
            contains ``"*"``, or it contains ``module_name``. False otherwise.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        if 'modules' not in body:
            return True

        module_list = self._as_list(body['modules'])
        if not module_list:
            return False

        return '*' in module_list or module_name in module_list

    def check_artifact(self, authcode_json: str, artifact_name: str) -> bool:
        """Check whether an artifact name is authorized.

        Args:
            authcode_json: authorization code as a JSON string.
            artifact_name: artifact name to check.

        Returns:
            bool: True if there is no ``artifact`` restriction, it is ``"*"``,
            or it equals ``artifact_name``. False otherwise.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        if 'artifact' not in body:
            return True

        authorized_artifact = body['artifact']
        return authorized_artifact == '*' or artifact_name == authorized_artifact

    def check_version(self, authcode_json: str, version: str) -> bool:
        """Check whether a product version is authorized.

        This is ``body['version']``, not the top-level format ``version``
        used for signing.

        Args:
            authcode_json: authorization code as a JSON string.
            version: version string to check.

        Returns:
            bool: True if there is no ``version`` restriction, it is ``"*"``,
            or it equals ``version``. False otherwise.
        """
        body = self._load_body(authcode_json)
        if body is None:
            return False

        if 'version' not in body:
            return True

        authorized_version = body['version']
        return authorized_version == '*' or version == authorized_version

    def validate_all(self, authcode_json: str, client_ip: str = None,
                     machine_code: str = None, artifact: str = None,
                     version: str = None, module: str = None) -> bool:
        """Run the signature, time and all requested restriction checks.

        Optional checks run only when their argument is truthy. For example,
        omitting ``client_ip`` skips the IP restriction entirely.

        Args:
            authcode_json: authorization code as a JSON string.
            client_ip: client IP (optional).
            machine_code: machine code (optional).
            artifact: artifact name (optional).
            version: product version (optional).
            module: module name (optional).

        Returns:
            bool: True only if every performed check passes.
        """
        if not self.validate(authcode_json):
            return False

        if not self.check_expired(authcode_json):
            return False

        if client_ip and not self.check_ip(authcode_json, client_ip):
            return False

        if machine_code and not self.check_machine_code(authcode_json, machine_code):
            return False

        if artifact and not self.check_artifact(authcode_json, artifact):
            return False

        if version and not self.check_version(authcode_json, version):
            return False

        if module and not self.check_module(authcode_json, module):
            return False

        return True

    # ============= private helpers =============

    @staticmethod
    def _load_json(authcode_json):
        """Parse JSON text, returning None if it cannot be parsed."""
        try:
            return json.loads(authcode_json)
        except (TypeError, ValueError, RecursionError):
            return None

    def _load_body(self, authcode_json: str) -> Optional[dict]:
        """Return the code's ``body`` object, or None if the input is malformed.

        A missing ``body`` yields ``{}`` (no restrictions). A top level or
        ``body`` that is not a JSON object yields None, so callers fail closed
        instead of doing substring tests on strings.
        """
        authcode_obj = self._load_json(authcode_json)
        if not isinstance(authcode_obj, dict):
            return None
        body = authcode_obj.get('body', {})
        return body if isinstance(body, dict) else None

    @staticmethod
    def _as_list(value) -> Optional[list]:
        """Normalize a restriction value to a list.

        Returns None for falsy values and for types other than str and list.
        Strings are split on ``,`` and stripped; lists are returned unchanged.
        """
        if not value:
            return None
        if isinstance(value, str):
            return [item.strip() for item in value.split(',')]
        if isinstance(value, list):
            return value
        return None

    def _parse_datetime(self, date_str: str) -> Optional[datetime]:
        """Parse a date/time string in one of several formats.

        Returns a naive datetime, or None if ``date_str`` is empty or matches
        none of the supported formats.
        """
        if not date_str:
            return None

        formats = [
            '%Y-%m-%dT%H:%M',     # ISO: 2026-01-26T09:23
            '%Y-%m-%dT%H:%M:%S',  # ISO: 2026-01-26T09:23:45
            '%Y-%m-%d %H:%M:%S',  # standard: 2026-01-26 09:23:45
            '%Y-%m-%d %H:%M',     # standard: 2026-01-26 09:23
            '%Y-%m-%d',           # date only: 2026-01-26
        ]

        for fmt in formats:
            try:
                return datetime.strptime(date_str, fmt)
            except (TypeError, ValueError):
                continue

        return None
|
ctools/cipher/sm_util.py
CHANGED
|
@@ -9,7 +9,6 @@ sm2_crypt: sm2.CryptSM2 = None
|
|
|
9
9
|
def init(private_key: str, public_key: str):
    """Create the module-wide SM2 cipher on the first call.

    The cipher is stored in the module global ``sm2_crypt``. Once it exists,
    later calls return immediately and their arguments are ignored, so the
    keys from the first call stay in effect for the rest of the process.

    Args:
        private_key: SM2 private key passed to ``sm2.CryptSM2``.
        public_key: SM2 public key passed to ``sm2.CryptSM2``.
    """
    global sm2_crypt
    already_initialised = sm2_crypt is not None
    if already_initialised:
        # Silent no-op: calling again with different keys does NOT replace them.
        return
    sm2_crypt = sm2.CryptSM2(
        private_key=private_key,
        public_key=public_key,
        asn1=True,
        mode=1,
    )
|
|
15
14
|
|
ctools/database/database.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
import datetime
|
|
3
3
|
import math
|
|
4
|
-
import threading
|
|
5
4
|
|
|
6
|
-
from sqlalchemy import create_engine, BigInteger, Column
|
|
7
|
-
from sqlalchemy.
|
|
8
|
-
from sqlalchemy.orm import sessionmaker, Session
|
|
5
|
+
from sqlalchemy import create_engine, BigInteger, Column
|
|
6
|
+
from sqlalchemy.orm import sessionmaker, Session, declarative_base
|
|
9
7
|
from sqlalchemy.sql import text
|
|
10
8
|
|
|
11
9
|
from ctools import call
|
|
@@ -14,17 +12,37 @@ from ctools.pools.thread_pool import thread_local
|
|
|
14
12
|
from ctools.web.bottle_web_base import PageInfo
|
|
15
13
|
|
|
16
14
|
"""
|
|
15
|
+
from time import sleep
|
|
16
|
+
|
|
17
|
+
from sqlalchemy import text, Column, BigInteger, String
|
|
18
|
+
|
|
19
|
+
from ctools import cid, cjson
|
|
20
|
+
from ctools.database import database
|
|
21
|
+
from ctools.database.database import BaseMixin
|
|
22
|
+
|
|
17
23
|
class XXXX(BaseMixin):
|
|
18
24
|
__tablename__ = 't_xxx_info'
|
|
19
25
|
__table_args__ = {'comment': 'xxx信息表'}
|
|
20
26
|
server_content: Column = Column(String(50), nullable=True, default='', comment='123123')
|
|
21
27
|
server_ip: Column = Column(String(30), index=True)
|
|
22
28
|
user_id: Column = Column(BigInteger)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
with database.get_session('source') as s:
|
|
26
|
-
|
|
27
|
-
|
|
29
|
+
database.init_db('postgresql://<user>:<password>@<host>:<port>/postgres', default_schema='public', db_key='source', pool_size=100, auto_gen_table=True)
|
|
30
|
+
while True:
|
|
31
|
+
with database.get_session('source') as s:
|
|
32
|
+
params = {
|
|
33
|
+
'obj_id': cid.get_snowflake_id(),
|
|
34
|
+
'server_ip': cid.get_random_str(5),
|
|
35
|
+
'user_id': 123,
|
|
36
|
+
'server_content': cid.get_random_str(5),
|
|
37
|
+
}
|
|
38
|
+
s.execute(text('insert into t_xxx_info (obj_id, server_ip, user_id) values (:obj_id, :server_ip, :user_id)'), params)
|
|
39
|
+
s.commit()
|
|
40
|
+
sleep(0.2)
|
|
41
|
+
res = s.query(XXXX.obj_id, XXXX.server_ip, XXXX.user_id).all()
|
|
42
|
+
for item in res:
|
|
43
|
+
print(item._asdict())
|
|
44
|
+
print(len(res))
|
|
45
|
+
sleep(1)
|
|
28
46
|
"""
|
|
29
47
|
|
|
30
48
|
Base = None
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import random
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image, ImageDraw, ImageFont
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def preprocess(img, img_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Turn an image array into a normalized, batched CHW float array.

    Steps: resize to ``img_size`` x ``img_size``, reverse the channel order
    (RGB -> BGR), scale to [0, 1], normalize with ``mean``/``std``, transpose
    HWC -> CHW, and prepend a batch axis.

    Args:
        img: image array accepted by ``PIL.Image.fromarray`` (presumably
            uint8 HxWx3 RGB — TODO confirm against callers).
        img_size: side length of the square model input.
        mean: per-channel mean, applied after the channel flip. Defaults to
            the values that were previously hard-coded.
        std: per-channel standard deviation, in the same order as ``mean``.

    Returns:
        float32 array of shape (1, C, img_size, img_size).
    """
    img = resize(img, img_size)
    img = img[:, :, ::-1].astype('float32')  # reverse channels: RGB -> BGR
    img = normalize(img, mean, std)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    return img[np.newaxis, :]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def resize(img, target_size):
    """Resize an image array to a square ``target_size`` x ``target_size``.

    The array goes through PIL (bilinear resampling), so it must be something
    ``Image.fromarray`` accepts (presumably uint8 HxWx3 — TODO confirm).

    Args:
        img: numpy.ndarray of shape (H, W, 3), BGR or RGB.
        target_size: side length of the output image.

    Returns:
        numpy.ndarray of shape (target_size, target_size, 3).
    """
    square = (target_size, target_size)
    resized = Image.fromarray(img).resize(square, Image.BILINEAR)
    return np.array(resized)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def normalize(img, mean, std):
    """Scale pixel values to [0, 1], then standardize each channel.

    Computes ``(img / 255 - mean) / std`` with ``mean`` and ``std``
    broadcast along the last (channel) axis. The caller's array is not
    modified; a float32 input stays float32.

    Args:
        img: HWC image array.
        mean: per-channel means.
        std: per-channel standard deviations.

    Returns:
        Normalized array with the same shape as ``img``.
    """
    out = img / 255.0
    channel_mean = np.array(mean)[None, None, :]
    channel_std = np.array(std)[None, None, :]
    np.subtract(out, channel_mean, out=out)
    np.divide(out, channel_std, out=out)
    return out
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def show_preprocess(chw_img, mean, std):
    """Undo normalization on a CHW image and display it (debug helper).

    Args:
        chw_img: (3, H, W) float32 array, normalized with ``mean``/``std``.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.

    NOTE(review): preprocess() reverses the channel order (RGB -> BGR) before
    normalizing, and this helper does not reverse it back. Colours may
    therefore appear channel-swapped — confirm this is intended.
    """
    hwc = chw_img.copy().transpose(1, 2, 0)  # CHW -> HWC
    pixels = (hwc * std + mean) * 255.0  # de-normalize back to 0..255
    Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8)).show()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def draw_bbox(img, result, threshold=0.5, save_name='res.jpg', scale_factor=None, im_size=320, class_names=None):
    """Draw detection boxes (and class labels) on ``img`` and save it.

    ``img`` is drawn on in place and then written to ``save_name``.

    Args:
        img: PIL image to annotate.
        result: detections, each ``[cat_id, score, x0, y0, x1, y1]``. Box
            coordinates are multiplied by ``im_size`` and divided by the scale
            factors, i.e. treated as normalized to the model input (TODO
            confirm against the model's output format).
        threshold: detections with ``score < threshold`` are skipped.
        save_name: output file path.
        scale_factor: ``[[h_scale, w_scale]]`` mapping the original image to
            the model input, or ``None`` for no rescaling.
        im_size: side length of the model input.
        class_names: class names indexed by ``cat_id``. When given, each box
            gets a per-class colour and a ``name:score`` label, and detections
            whose ``cat_id`` is out of range are skipped. Without it, boxes
            are drawn in red with no labels.
    """
    draw = ImageDraw.Draw(img)

    if scale_factor is not None:
        h_scale, w_scale = scale_factor[0]
    else:
        h_scale = w_scale = 1.

    # One colour per class. A fixed-seed generator keeps the colours
    # random-looking but identical from call to call.
    category_colors = {}
    if class_names is not None:
        rng = random.Random(0)
        for cls in class_names:
            category_colors[cls] = tuple(rng.randint(0, 255) for _ in range(3))

    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except (OSError, ImportError):
        # arial.ttf is not installed everywhere (or FreeType support is missing).
        font = ImageFont.load_default()

    for res in result:
        cat_id, score, bbox = res[0], res[1], res[2:]
        if score < threshold:
            continue

        # normalized bbox -> model input size -> original image
        xmin = bbox[0] * im_size / w_scale
        ymin = bbox[1] * im_size / h_scale
        xmax = bbox[2] * im_size / w_scale
        ymax = bbox[3] * im_size / h_scale

        if class_names is not None:
            class_idx = int(cat_id)
            if not 0 <= class_idx < len(class_names):
                # Unknown class id: skip it instead of raising IndexError, or
                # silently wrapping around to the last class for negative ids.
                continue
            class_name = class_names[class_idx]
            color = category_colors[class_name]
            label = f"{class_name}:{score:.2f}"

            # Text height, compatible with old and new Pillow.
            try:
                text_height = font.getsize(label)[1]  # removed in Pillow 10
            except AttributeError:
                left, top, right, bottom = font.getbbox(label)
                text_height = bottom - top

            text_origin = (xmin, max(0, ymin - text_height))  # label above the box
            draw.text(text_origin, label, fill=color, font=font)
        else:
            color = 'red'

        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)

    img.save(save_name)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def image_show(base64_str):
    """Decode a base64-encoded image and open it with PIL's default viewer.

    Args:
        base64_str: base64 text of an encoded image file (PNG, JPEG, ...).
    """
    import base64
    from io import BytesIO

    raw_bytes = base64.b64decode(base64_str)
    Image.open(BytesIO(raw_bytes)).show()
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: UTF-8 -*-
|
|
3
|
+
__author__ = 'haoyang'
|
|
4
|
+
__date__ = '2026/1/20 15:02'
|
|
5
|
+
|
|
6
|
+
import base64
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
|
|
9
|
+
class ClassRegionBase64ExtractorPIL:
    """Crop detected regions out of a PIL image and return them as base64."""

    def __init__(self, class_names, target_classes=None, threshold=0.5):
        """
        Args:
            class_names: model class names, indexed by class id.
            target_classes: names of the classes to extract; ``None`` means all.
            threshold: minimum score for a detection to be extracted.
        """
        self.class_names = class_names
        self.target_classes = target_classes
        self.threshold = threshold

    @staticmethod
    def image_to_base64(img, format='PNG'):
        """Encode a PIL image in ``format`` and return it as a base64 ``str``."""
        buffer = BytesIO()
        img.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def extract(self, img, results, scale_factor=None, im_size=320):
        """Crop each qualifying detection from ``img`` and base64-encode it.

        A detection is skipped if:
        - its score is below ``threshold``;
        - its class id is outside ``[0, len(class_names))``;
        - its class is not in ``target_classes``;
        - its box is empty or inverted.

        Args:
            img: PIL image. It is copied once (lazily) before cropping, so the
                caller's object is only read through that copy.
            results: model output ``[[cat_id, score, xmin, ymin, xmax, ymax], ...]``.
            scale_factor: ``np.array([[h_scale, w_scale]])`` or None. Currently
                unused.
            im_size: model input size. Currently unused.

        NOTE(review): box values are used directly as pixel coordinates of
        ``img``; ``scale_factor``/``im_size`` are ignored. Confirm this matches
        the model's output format.

        Returns:
            List[Dict]: ``[{"class": class_name, "score": score, "base64": base64_str}, ...]``
        """
        outputs = []
        max_class_id = len(self.class_names) - 1
        source = None  # private copy of img, made on the first crop only
        for res in results:
            cat_id, score, bbox = res[0], res[1], res[2:]
            if score < self.threshold or not 0 <= cat_id <= max_class_id:
                continue
            class_name = self.class_names[int(cat_id)]
            if self.target_classes is not None and class_name not in self.target_classes:
                continue

            xmin, ymin, xmax, ymax = bbox[0], bbox[1], bbox[2], bbox[3]
            if xmax <= xmin or ymax <= ymin:
                # An empty or inverted box cannot be cropped or encoded.
                continue

            if source is None:
                source = img.copy()
            cropped = source.crop((xmin, ymin, xmax, ymax))

            outputs.append({
                "class": class_name,
                "score": float(score),
                "base64": self.image_to_base64(cropped)
            })
        return outputs
|