kotonebot 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. kotonebot/__init__.py +39 -39
  2. kotonebot/backend/bot.py +312 -312
  3. kotonebot/backend/color.py +525 -525
  4. kotonebot/backend/context/__init__.py +3 -3
  5. kotonebot/backend/context/task_action.py +183 -183
  6. kotonebot/backend/core.py +129 -129
  7. kotonebot/backend/debug/entry.py +89 -89
  8. kotonebot/backend/debug/mock.py +78 -78
  9. kotonebot/backend/debug/server.py +222 -222
  10. kotonebot/backend/debug/vars.py +351 -351
  11. kotonebot/backend/dispatch.py +227 -227
  12. kotonebot/backend/flow_controller.py +196 -196
  13. kotonebot/backend/ocr.py +535 -529
  14. kotonebot/backend/preprocessor.py +103 -103
  15. kotonebot/client/__init__.py +9 -9
  16. kotonebot/client/device.py +528 -503
  17. kotonebot/client/fast_screenshot.py +377 -377
  18. kotonebot/client/host/__init__.py +43 -12
  19. kotonebot/client/host/adb_common.py +107 -103
  20. kotonebot/client/host/custom.py +118 -114
  21. kotonebot/client/host/leidian_host.py +196 -201
  22. kotonebot/client/host/mumu12_host.py +353 -358
  23. kotonebot/client/host/protocol.py +214 -213
  24. kotonebot/client/host/windows_common.py +58 -58
  25. kotonebot/client/implements/__init__.py +71 -15
  26. kotonebot/client/implements/adb.py +89 -85
  27. kotonebot/client/implements/adb_raw.py +162 -158
  28. kotonebot/client/implements/nemu_ipc/__init__.py +11 -7
  29. kotonebot/client/implements/nemu_ipc/external_renderer_ipc.py +284 -284
  30. kotonebot/client/implements/nemu_ipc/nemu_ipc.py +327 -327
  31. kotonebot/client/implements/remote_windows.py +188 -188
  32. kotonebot/client/implements/uiautomator2.py +85 -81
  33. kotonebot/client/implements/windows.py +176 -172
  34. kotonebot/client/protocol.py +69 -69
  35. kotonebot/client/registration.py +24 -24
  36. kotonebot/config/base_config.py +96 -96
  37. kotonebot/config/manager.py +36 -36
  38. kotonebot/errors.py +76 -71
  39. kotonebot/interop/win/__init__.py +10 -3
  40. kotonebot/interop/win/_mouse.py +311 -0
  41. kotonebot/interop/win/message_box.py +313 -313
  42. kotonebot/interop/win/reg.py +37 -37
  43. kotonebot/interop/win/shortcut.py +43 -43
  44. kotonebot/interop/win/task_dialog.py +513 -513
  45. kotonebot/logging/__init__.py +2 -2
  46. kotonebot/logging/log.py +17 -17
  47. kotonebot/primitives/__init__.py +17 -17
  48. kotonebot/primitives/geometry.py +862 -290
  49. kotonebot/primitives/visual.py +63 -63
  50. kotonebot/tools/mirror.py +354 -354
  51. kotonebot/ui/file_host/sensio.py +36 -36
  52. kotonebot/ui/file_host/tmp_send.py +54 -54
  53. kotonebot/ui/pushkit/__init__.py +3 -3
  54. kotonebot/ui/pushkit/image_host.py +88 -87
  55. kotonebot/ui/pushkit/protocol.py +13 -13
  56. kotonebot/ui/pushkit/wxpusher.py +54 -53
  57. kotonebot/ui/user.py +148 -148
  58. kotonebot/util.py +436 -436
  59. {kotonebot-0.4.0.dist-info → kotonebot-0.5.0.dist-info}/METADATA +82 -81
  60. kotonebot-0.5.0.dist-info/RECORD +71 -0
  61. {kotonebot-0.4.0.dist-info → kotonebot-0.5.0.dist-info}/licenses/LICENSE +673 -673
  62. kotonebot-0.4.0.dist-info/RECORD +0 -70
  63. {kotonebot-0.4.0.dist-info → kotonebot-0.5.0.dist-info}/WHEEL +0 -0
  64. {kotonebot-0.4.0.dist-info → kotonebot-0.5.0.dist-info}/top_level.txt +0 -0
kotonebot/backend/ocr.py CHANGED
@@ -1,529 +1,535 @@
1
- import re
2
- import time
3
- import logging
4
- import unicodedata
5
- from functools import lru_cache
6
- from dataclasses import dataclass
7
- import warnings
8
- from typing_extensions import Self, deprecated
9
- from typing import Callable, NamedTuple
10
-
11
- import cv2
12
- import numpy as np
13
- from cv2.typing import MatLike
14
- from thefuzz import fuzz as _fuzz
15
- from rapidocr_onnxruntime import RapidOCR
16
-
17
-
18
- from ..util import lf_path
19
- from ..primitives import Rect, Point
20
- from .core import HintBox, Image, unify_image
21
- from .debug import result as debug_result, debug
22
-
23
- logger = logging.getLogger(__name__)
24
- StringMatchFunction = Callable[[str], bool]
25
- REGEX_NUMBERS = re.compile(r'\d+')
26
-
27
- global_character_mapping: dict[str, str] = {
28
- 'ó': '6',
29
- 'ą': 'a',
30
- }
31
- """
32
- 全局字符映射表。某些字符可能在某些情况下被错误地识别,此时可以在这里添加映射。
33
- """
34
-
35
- def sanitize_text(text: str) -> str:
36
- """
37
- 对识别结果进行清理。此函数将被所有 OCR 引擎调用。
38
-
39
- 默认使用 `global_character_mapping` 中的映射数据进行清理。
40
- 可以重写此函数以实现自定义的清理逻辑。
41
- """
42
- for k, v in global_character_mapping.items():
43
- text = text.replace(k, v)
44
- return text
45
-
46
- @dataclass
47
- class OcrResult:
48
- text: str
49
- rect: Rect
50
- confidence: float
51
- original_rect: Rect
52
- """
53
- 识别结果在原图中的区域坐标。
54
-
55
- 如果识别时没有设置 `rect` 或 `hint` 参数,则此属性值与 `rect` 相同。
56
- """
57
-
58
- def __repr__(self) -> str:
59
- return f'OcrResult(text="{self.text}", rect={self.rect}, confidence={self.confidence})'
60
-
61
- def replace(self, old: str, new: str, count: int = -1) -> Self:
62
- """
63
- 替换识别结果中的文本。
64
- """
65
- self.text = self.text.replace(old, new, count)
66
- return self
67
-
68
- def regex(self, pattern: re.Pattern | str) -> list[str]:
69
- """
70
- 提取识别结果中符合正则表达式的文本。
71
- """
72
- if isinstance(pattern, str):
73
- pattern = re.compile(pattern)
74
- return pattern.findall(self.text)
75
-
76
- def numbers(self) -> list[int]:
77
- """
78
- 提取识别结果中的数字。
79
- """
80
- return [int(x) for x in REGEX_NUMBERS.findall(self.text)]
81
-
82
- class OcrResultList(list[OcrResult]):
83
- def squash(self, remove_newlines: bool = True) -> OcrResult:
84
- """
85
- 将所有识别结果合并为一个大结果。
86
- """
87
- if not self:
88
- return OcrResult('', Rect(0, 0, 0, 0), 0, Rect(0, 0, 0, 0))
89
- text = [r.text for r in self]
90
- confidence = sum(r.confidence for r in self) / len(self)
91
- points = []
92
- for r in self:
93
- points.append(Point(r.rect.x1, r.rect.y1))
94
- points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1))
95
- points.append(Point(r.rect.x1, r.rect.y1 + r.rect.h))
96
- points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h))
97
- rect = Rect(xywh=bounding_box(points))
98
- text = '\n'.join(text)
99
- if remove_newlines:
100
- text = text.replace('\n', '')
101
- return OcrResult(
102
- text=text,
103
- rect=rect,
104
- confidence=confidence,
105
- original_rect=rect,
106
- )
107
-
108
- def first(self) -> OcrResult | None:
109
- """
110
- 返回第一个识别结果。
111
- """
112
- return self[0] if self else None
113
-
114
- def where(self, pattern: StringMatchFunction) -> 'OcrResultList':
115
- """
116
- 返回符合条件的识别结果。
117
- """
118
- return OcrResultList([x for x in self if pattern(x.text)])
119
-
120
- class TextNotFoundError(Exception):
121
- def __init__(self, pattern: str | re.Pattern | StringMatchFunction, image: 'MatLike'):
122
- self.pattern = pattern
123
- self.image = image
124
- super().__init__(f"Expected text not found: {pattern}")
125
-
126
- class TextComparator:
127
- def __init__(self, name: str, text: str, func: Callable[[str], bool]):
128
- self.name = name
129
- self.text = text
130
- self.func = func
131
-
132
- def __call__(self, text: str) -> bool:
133
- return self.func(text)
134
-
135
- def __repr__(self) -> str:
136
- return f'{self.name}("{self.text}")'
137
-
138
- @deprecated("即将移除")
139
- @lru_cache(maxsize=1000)
140
- def fuzz(text: str) -> TextComparator:
141
- """返回 fuzzy 算法的字符串匹配函数。"""
142
- func = lambda s: _fuzz.ratio(s, text) > 90
143
- return TextComparator("fuzzy", text, func)
144
-
145
- @lru_cache(maxsize=1000)
146
- def regex(regex: str) -> TextComparator:
147
- """返回正则表达式字符串匹配函数。"""
148
- func = lambda s: re.match(regex, s) is not None
149
- return TextComparator("regex", regex, func)
150
-
151
- @lru_cache(maxsize=1000)
152
- def contains(text: str, *, ignore_case: bool = False) -> TextComparator:
153
- """返回包含指定文本的函数。"""
154
- if ignore_case:
155
- func = lambda s: text.lower() in s.lower()
156
- else:
157
- func = lambda s: text in s
158
- return TextComparator("contains", text, func)
159
-
160
- @lru_cache(maxsize=1000)
161
- def equals(
162
- text: str,
163
- *,
164
- remove_space: bool = False,
165
- ignore_case: bool = True,
166
- ) -> TextComparator:
167
- """
168
- 返回等于指定文本的函数。
169
-
170
- :param text: 要比较的文本。
171
- :param remove_space: 是否忽略空格。默认为 False。
172
- :param ignore_case: 是否忽略大小写。默认为 True。
173
- """
174
- def compare(s: str) -> bool:
175
- nonlocal text
176
-
177
- if ignore_case:
178
- text = text.lower()
179
- s = s.lower()
180
- if remove_space:
181
- text = text.replace(' ', '').replace(' ', '')
182
- s = s.replace(' ', '').replace(' ', '')
183
-
184
- return text == s
185
- return TextComparator("equals", text, compare)
186
-
187
- def grayscaled(img: 'MatLike | str | Image') -> MatLike:
188
- img = unify_image(img)
189
- return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
190
-
191
- def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction | TextComparator) -> bool:
192
- if isinstance(pattern, re.Pattern):
193
- return pattern.match(text) is not None
194
- elif callable(pattern):
195
- return pattern(text)
196
- else:
197
- return text == pattern
198
-
199
- # https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points
200
- def _bounding_box(points):
201
- x_coordinates, y_coordinates = zip(*points)
202
-
203
- return [(min(x_coordinates), min(y_coordinates)), (max(x_coordinates), max(y_coordinates))]
204
-
205
- def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
206
- """
207
- 计算点集的外接矩形。
208
-
209
- :param points: 点集。以左上角为原点,向下向右为正方向。
210
- :return: 外接矩形的左上角坐标和宽高
211
- """
212
- topleft, bottomright = _bounding_box(points)
213
- return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
214
-
215
- def pad_to(img: MatLike, target_size: int, rgb: tuple[int, int, int] = (255, 255, 255)) -> tuple[MatLike, tuple[int, int]]:
216
- """
217
- 将图像居中填充到指定大小。缺少部分使用指定颜色填充。
218
-
219
- :return: 填充后的图像和填充的偏移量 (x, y)
220
- """
221
- h, w = img.shape[:2]
222
-
223
- # 计算需要填充的宽高
224
- pad_h = max(0, target_size - h)
225
- pad_w = max(0, target_size - w)
226
-
227
- # 如果不需要填充则直接返回
228
- if pad_h == 0 and pad_w == 0:
229
- return img, (0, 0)
230
-
231
- # 创建目标画布并填充
232
- if len(img.shape) == 2:
233
- # 灰度图像
234
- ret = np.full((h + pad_h, w + pad_w), rgb[0], dtype=np.uint8)
235
- else:
236
- # RGB图像
237
- ret = np.full((h + pad_h, w + pad_w, 3), rgb, dtype=np.uint8)
238
-
239
- # 将原图像居中放置
240
- if len(img.shape) == 2:
241
- ret[
242
- pad_h // 2:pad_h // 2 + h,
243
- pad_w // 2:pad_w // 2 + w] = img
244
- else:
245
- ret[
246
- pad_h // 2:pad_h // 2 + h,
247
- pad_w // 2:pad_w // 2 + w, :] = img
248
- return ret, (pad_w // 2, pad_h // 2)
249
-
250
- def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
251
- import numpy as np
252
- from PIL import Image, ImageDraw, ImageFont
253
-
254
- # 转换为PIL图像
255
- result_image = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
256
- pil_image = Image.fromarray(result_image)
257
- draw = ImageDraw.Draw(pil_image, 'RGBA')
258
-
259
- # 加载字体
260
- try:
261
- font = ImageFont.truetype(lf_path('res/fonts/SourceHanSansHW-Regular.otf'), 16)
262
- except:
263
- font = ImageFont.load_default()
264
-
265
- for r in result:
266
- # 画矩形框
267
- draw.rectangle(
268
- [r.rect.x1, r.rect.y1, r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h],
269
- outline=(255, 0, 0),
270
- width=2
271
- )
272
-
273
- # 获取文本大小
274
- text = r.text + f" ({r.confidence:.2f})" # 添加置信度显示
275
- text_bbox = draw.textbbox((0, 0), text, font=font)
276
- text_width = text_bbox[2] - text_bbox[0]
277
- text_height = text_bbox[3] - text_bbox[1]
278
-
279
- # 计算文本位置
280
- text_x = r.rect.x1
281
- text_y = r.rect.y1 - text_height - 5 if r.rect.y1 > text_height + 5 else r.rect.y1 + r.rect.h + 5
282
-
283
- # 添加padding
284
- padding = 4
285
- bg_rect = [
286
- text_x - padding,
287
- text_y - padding,
288
- text_x + text_width + padding,
289
- text_y + text_height + padding
290
- ]
291
-
292
- # 画半透明背景
293
- draw.rectangle(
294
- bg_rect,
295
- fill=(0, 0, 0, 128)
296
- )
297
-
298
- # 画文字
299
- draw.text(
300
- (text_x, text_y),
301
- text,
302
- font=font,
303
- fill=(255, 255, 255)
304
- )
305
-
306
- # 转回OpenCV格式
307
- result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
308
- return result_image
309
-
310
- class Ocr:
311
- def __init__(self, engine: RapidOCR):
312
- self.__engine = engine
313
-
314
- # TODO: 考虑缓存 OCR 结果,避免重复调用。
315
- def ocr(
316
- self,
317
- img: 'MatLike',
318
- *,
319
- rect: Rect | None = None,
320
- pad: bool = True,
321
- ) -> OcrResultList:
322
- """
323
- OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
324
-
325
-
326
- :param rect: 如果指定,则只识别指定矩形区域。
327
- :param pad:
328
- 是否将过小的图像(尺寸 < 631x631)的图像填充到 631x631。
329
- 默认为 True。
330
-
331
- 对于 PaddleOCR 模型,图片尺寸太小会降低准确率。
332
- 将图片周围填充放大,有助于提高准确率,降低耗时。
333
- :return: 所有识别结果
334
- """
335
- if rect is not None:
336
- x, y, w, h = rect.xywh
337
- img = img[y:y+h, x:x+w]
338
- original_img = img
339
- if pad:
340
- # TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
341
- # https://blog.csdn.net/YY007H/article/details/124973777
342
- original_img = img.copy()
343
- img, pos_in_padded_img = pad_to(img, 631)
344
- else:
345
- pos_in_padded_img = (0, 0)
346
- img_content = img
347
- result, elapse = self.__engine(img_content)
348
- if result is None:
349
- return OcrResultList()
350
- ret = []
351
- for r in result:
352
- text = sanitize_text(r[1])
353
- # r[0] = [左上, 右上, 右下, 左下]
354
- # 这里有个坑,返回的点不一定是矩形,只能保证是四边形
355
- # 所以这里需要计算出四个点的外接矩形
356
- result_rect = tuple(int(x) for x in bounding_box(r[0])) # type: ignore
357
- # result_rect (x, y, w, h)
358
- if rect is not None:
359
- original_rect = (
360
- result_rect[0] + rect.x1 - pos_in_padded_img[0],
361
- result_rect[1] + rect.y1 - pos_in_padded_img[1],
362
- result_rect[2],
363
- result_rect[3]
364
- )
365
- else:
366
- original_rect = result_rect
367
- if not len(original_rect) == 4:
368
- raise ValueError(f'Invalid original_rect: {original_rect}')
369
- if not len(result_rect) == 4:
370
- raise ValueError(f'Invalid result_rect: {result_rect}')
371
- confidence = float(r[2])
372
- ret.append(OcrResult(
373
- text=text,
374
- rect=Rect(xywh=result_rect),
375
- original_rect=Rect(xywh=original_rect),
376
- confidence=confidence
377
- ))
378
- ret = OcrResultList(ret)
379
- if debug.enabled:
380
- result_image = _draw_result(img, ret)
381
- elapse = elapse or [0, 0, 0]
382
- debug_result(
383
- 'ocr',
384
- [result_image, original_img],
385
- f"pad={pad}\n" + \
386
- f"rect={rect}\n" + \
387
- f"elapsed: det={elapse[0]:.3f}s cls={elapse[1]:.3f}s rec={elapse[2]:.3f}s\n" + \
388
- f"result: \n" + \
389
- "<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
390
- "\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.3f}</td></tr>" for r in ret]) + \
391
- "</table>"
392
- )
393
- return ret
394
-
395
- def find(
396
- self,
397
- img: 'MatLike',
398
- text: str | re.Pattern | StringMatchFunction,
399
- *,
400
- hint: HintBox | None = None,
401
- rect: Rect | None = None,
402
- pad: bool = True,
403
- ) -> OcrResult | None:
404
- """
405
- 识别图像中的文本,并寻找满足指定要求的文本。
406
-
407
- :param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
408
- :param rect: 如果指定,则只识别指定矩形区域。此参数优先级低于 `hint`。
409
- :param pad: `ocr` 的 `pad` 参数。
410
- :return: 找到的文本,如果未找到则返回 None
411
- """
412
- if hint is not None:
413
- warnings.warn("使用 `rect` 参数代替")
414
- if ret := self.find(img, text, rect=Rect(xywh=hint.rect)):
415
- logger.debug(f"find: {text} SUCCESS [hint={hint}]")
416
- return ret
417
- logger.debug(f"find: {text} FAILED [hint={hint}]")
418
-
419
- start_time = time.time()
420
- results = self.ocr(img, rect=rect, pad=pad)
421
- end_time = time.time()
422
- target = None
423
- for result in results:
424
- if _is_match(result.text, text):
425
- target = result
426
- break
427
- logger.debug(
428
- f"find: {text} {'SUCCESS' if target else 'FAILED'} " + \
429
- f"[elapsed={end_time - start_time:.3f}s] [rect={rect}]"
430
- )
431
- return target
432
-
433
- def find_all(
434
- self,
435
- img: 'MatLike',
436
- texts: list[str | re.Pattern | StringMatchFunction],
437
- *,
438
- hint: HintBox | None = None,
439
- rect: Rect | None = None,
440
- pad: bool = True,
441
- ) -> list[OcrResult | None]:
442
- """
443
- 识别图像中的文本,并寻找多个满足指定要求的文本。
444
-
445
- :return:
446
- 所有找到的文本,结果顺序与输入顺序相同。
447
- 若某个文本未找到,则该位置为 None
448
- """
449
- # HintBox 处理
450
- if hint is not None:
451
- warnings.warn("使用 `rect` 参数代替")
452
- result = self.find_all(img, texts, rect=Rect(xywh=hint.rect), pad=pad)
453
- if all(result):
454
- return result
455
-
456
- ret: list[OcrResult | None] = []
457
- ocr_results = self.ocr(img, rect=rect, pad=pad)
458
- logger.debug(f"ocr_results: {ocr_results}")
459
- for text in texts:
460
- for result in ocr_results:
461
- if _is_match(result.text, text):
462
- ret.append(result)
463
- break
464
- else:
465
- ret.append(None)
466
- return ret
467
-
468
- def expect(
469
- self,
470
- img: 'MatLike',
471
- text: str | re.Pattern | StringMatchFunction,
472
- *,
473
- hint: HintBox | None = None,
474
- rect: Rect | None = None,
475
- pad: bool = True,
476
- ) -> OcrResult:
477
- """
478
- 识别图像中的文本,并寻找满足指定要求的文本。如果未找到则抛出异常。
479
-
480
- :param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
481
- :param rect: 如果指定,则只识别指定矩形区域。此参数优先级高于 `hint`。
482
- :param pad: 见 `ocr` 的 `pad` 参数。
483
- :return: 找到的文本
484
- """
485
- ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
486
- if ret is None:
487
- raise TextNotFoundError(text, img)
488
- return ret
489
-
490
- # TODO: 这个路径需要能够独立设置
491
- _engine_jp: RapidOCR | None = None
492
- _engine_en: RapidOCR | None = RapidOCR(
493
- rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
494
- use_det=True,
495
- use_cls=False,
496
- use_rec=True,
497
- )
498
-
499
- def jp() -> Ocr:
500
- """
501
- 日语 OCR 引擎。
502
- """
503
- global _engine_jp
504
- if _engine_jp is None:
505
- _engine_jp = RapidOCR(
506
- rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'),
507
- use_det=True,
508
- use_cls=False,
509
- use_rec=True,
510
- )
511
- return Ocr(_engine_jp)
512
-
513
- def en() -> Ocr:
514
- """
515
- 英语 OCR 引擎。
516
- """
517
- global _engine_en
518
- if _engine_en is None:
519
- _engine_en = RapidOCR(
520
- rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
521
- use_det=True,
522
- use_cls=False,
523
- use_rec=True,
524
- )
525
- return Ocr(_engine_en)
526
-
527
-
528
- if __name__ == '__main__':
529
- pass
1
+ import re
2
+ import time
3
+ import logging
4
+ import unicodedata
5
+ from functools import lru_cache
6
+ from dataclasses import dataclass
7
+ import warnings
8
+ from typing_extensions import Self, deprecated
9
+ from typing import Callable, NamedTuple
10
+
11
+ import cv2
12
+ import numpy as np
13
+ from cv2.typing import MatLike
14
+ from thefuzz import fuzz as _fuzz
15
+ from rapidocr_onnxruntime import RapidOCR
16
+
17
+
18
+ from ..util import lf_path
19
+ from ..primitives import Rect, Point
20
+ from .core import HintBox, Image, unify_image
21
+ from .debug import result as debug_result, debug
22
+
23
+ logger = logging.getLogger(__name__)
24
+ StringMatchFunction = Callable[[str], bool]
25
+ REGEX_NUMBERS = re.compile(r'\d+')
26
+
27
+ global_character_mapping: dict[str, str] = {
28
+ 'ó': '6',
29
+ 'ą': 'a',
30
+ }
31
+ """
32
+ 全局字符映射表。某些字符可能在某些情况下被错误地识别,此时可以在这里添加映射。
33
+ """
34
+
35
+ def sanitize_text(text: str) -> str:
36
+ """
37
+ 对识别结果进行清理。此函数将被所有 OCR 引擎调用。
38
+
39
+ 默认行为为先将文本 `Unicode 规范化`_,然后使用 `global_character_mapping` 中的映射数据进行清理。
40
+ 可以重写此函数以实现自定义的清理逻辑。
41
+
42
+ .. note::
43
+ Unicode 规范化最常见的一个行为是将全角字符转换为半角字符。
44
+
45
+ .. _Unicode 规范化: https://docs.python.org/zh-cn/3.14/library/unicodedata.html#unicodedata.normalize
46
+ """
47
+ text = unicodedata.normalize('NFKC', text)
48
+ for k, v in global_character_mapping.items():
49
+ text = text.replace(k, v)
50
+ return text
51
+
52
+ @dataclass
53
+ class OcrResult:
54
+ text: str
55
+ rect: Rect
56
+ confidence: float
57
+ original_rect: Rect
58
+ """
59
+ 识别结果在原图中的区域坐标。
60
+
61
+ 如果识别时没有设置 `rect` `hint` 参数,则此属性值与 `rect` 相同。
62
+ """
63
+
64
+ def __repr__(self) -> str:
65
+ return f'OcrResult(text="{self.text}", rect={self.rect}, confidence={self.confidence})'
66
+
67
+ def replace(self, old: str, new: str, count: int = -1) -> Self:
68
+ """
69
+ 替换识别结果中的文本。
70
+ """
71
+ self.text = self.text.replace(old, new, count)
72
+ return self
73
+
74
+ def regex(self, pattern: re.Pattern | str) -> list[str]:
75
+ """
76
+ 提取识别结果中符合正则表达式的文本。
77
+ """
78
+ if isinstance(pattern, str):
79
+ pattern = re.compile(pattern)
80
+ return pattern.findall(self.text)
81
+
82
+ def numbers(self) -> list[int]:
83
+ """
84
+ 提取识别结果中的数字。
85
+ """
86
+ return [int(x) for x in REGEX_NUMBERS.findall(self.text)]
87
+
88
+ class OcrResultList(list[OcrResult]):
89
+ def squash(self, remove_newlines: bool = True) -> OcrResult:
90
+ """
91
+ 将所有识别结果合并为一个大结果。
92
+ """
93
+ if not self:
94
+ return OcrResult('', Rect(0, 0, 0, 0), 0, Rect(0, 0, 0, 0))
95
+ text = [r.text for r in self]
96
+ confidence = sum(r.confidence for r in self) / len(self)
97
+ points = []
98
+ for r in self:
99
+ points.append(Point(r.rect.x1, r.rect.y1))
100
+ points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1))
101
+ points.append(Point(r.rect.x1, r.rect.y1 + r.rect.h))
102
+ points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h))
103
+ rect = Rect(xywh=bounding_box(points))
104
+ text = '\n'.join(text)
105
+ if remove_newlines:
106
+ text = text.replace('\n', '')
107
+ return OcrResult(
108
+ text=text,
109
+ rect=rect,
110
+ confidence=confidence,
111
+ original_rect=rect,
112
+ )
113
+
114
+ def first(self) -> OcrResult | None:
115
+ """
116
+ 返回第一个识别结果。
117
+ """
118
+ return self[0] if self else None
119
+
120
+ def where(self, pattern: StringMatchFunction) -> 'OcrResultList':
121
+ """
122
+ 返回符合条件的识别结果。
123
+ """
124
+ return OcrResultList([x for x in self if pattern(x.text)])
125
+
126
+ class TextNotFoundError(Exception):
127
+ def __init__(self, pattern: str | re.Pattern | StringMatchFunction, image: 'MatLike'):
128
+ self.pattern = pattern
129
+ self.image = image
130
+ super().__init__(f"Expected text not found: {pattern}")
131
+
132
+ class TextComparator:
133
+ def __init__(self, name: str, text: str, func: Callable[[str], bool]):
134
+ self.name = name
135
+ self.text = text
136
+ self.func = func
137
+
138
+ def __call__(self, text: str) -> bool:
139
+ return self.func(text)
140
+
141
+ def __repr__(self) -> str:
142
+ return f'{self.name}("{self.text}")'
143
+
144
+ @deprecated("即将移除")
145
+ @lru_cache(maxsize=1000)
146
+ def fuzz(text: str) -> TextComparator:
147
+ """返回 fuzzy 算法的字符串匹配函数。"""
148
+ func = lambda s: _fuzz.ratio(s, text) > 90
149
+ return TextComparator("fuzzy", text, func)
150
+
151
+ @lru_cache(maxsize=1000)
152
+ def regex(regex: str) -> TextComparator:
153
+ """返回正则表达式字符串匹配函数。"""
154
+ func = lambda s: re.match(regex, s) is not None
155
+ return TextComparator("regex", regex, func)
156
+
157
+ @lru_cache(maxsize=1000)
158
+ def contains(text: str, *, ignore_case: bool = False) -> TextComparator:
159
+ """返回包含指定文本的函数。"""
160
+ if ignore_case:
161
+ func = lambda s: text.lower() in s.lower()
162
+ else:
163
+ func = lambda s: text in s
164
+ return TextComparator("contains", text, func)
165
+
166
+ @lru_cache(maxsize=1000)
167
+ def equals(
168
+ text: str,
169
+ *,
170
+ remove_space: bool = False,
171
+ ignore_case: bool = True,
172
+ ) -> TextComparator:
173
+ """
174
+ 返回等于指定文本的函数。
175
+
176
+ :param text: 要比较的文本。
177
+ :param remove_space: 是否忽略空格。默认为 False。
178
+ :param ignore_case: 是否忽略大小写。默认为 True。
179
+ """
180
+ def compare(s: str) -> bool:
181
+ nonlocal text
182
+
183
+ if ignore_case:
184
+ text = text.lower()
185
+ s = s.lower()
186
+ if remove_space:
187
+ text = text.replace(' ', '').replace(' ', '')
188
+ s = s.replace(' ', '').replace(' ', '')
189
+
190
+ return text == s
191
+ return TextComparator("equals", text, compare)
192
+
193
+ def grayscaled(img: 'MatLike | str | Image') -> MatLike:
194
+ img = unify_image(img)
195
+ return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
196
+
197
+ def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction | TextComparator) -> bool:
198
+ if isinstance(pattern, re.Pattern):
199
+ return pattern.match(text) is not None
200
+ elif callable(pattern):
201
+ return pattern(text)
202
+ else:
203
+ return text == pattern
204
+
205
+ # https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points
206
+ def _bounding_box(points):
207
+ x_coordinates, y_coordinates = zip(*points)
208
+
209
+ return [(min(x_coordinates), min(y_coordinates)), (max(x_coordinates), max(y_coordinates))]
210
+
211
+ def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
212
+ """
213
+ 计算点集的外接矩形。
214
+
215
+ :param points: 点集。以左上角为原点,向下向右为正方向。
216
+ :return: 外接矩形的左上角坐标和宽高
217
+ """
218
+ topleft, bottomright = _bounding_box(points)
219
+ return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
220
+
221
+ def pad_to(img: MatLike, target_size: int, rgb: tuple[int, int, int] = (255, 255, 255)) -> tuple[MatLike, tuple[int, int]]:
222
+ """
223
+ 将图像居中填充到指定大小。缺少部分使用指定颜色填充。
224
+
225
+ :return: 填充后的图像和填充的偏移量 (x, y)
226
+ """
227
+ h, w = img.shape[:2]
228
+
229
+ # 计算需要填充的宽高
230
+ pad_h = max(0, target_size - h)
231
+ pad_w = max(0, target_size - w)
232
+
233
+ # 如果不需要填充则直接返回
234
+ if pad_h == 0 and pad_w == 0:
235
+ return img, (0, 0)
236
+
237
+ # 创建目标画布并填充
238
+ if len(img.shape) == 2:
239
+ # 灰度图像
240
+ ret = np.full((h + pad_h, w + pad_w), rgb[0], dtype=np.uint8)
241
+ else:
242
+ # RGB图像
243
+ ret = np.full((h + pad_h, w + pad_w, 3), rgb, dtype=np.uint8)
244
+
245
+ # 将原图像居中放置
246
+ if len(img.shape) == 2:
247
+ ret[
248
+ pad_h // 2:pad_h // 2 + h,
249
+ pad_w // 2:pad_w // 2 + w] = img
250
+ else:
251
+ ret[
252
+ pad_h // 2:pad_h // 2 + h,
253
+ pad_w // 2:pad_w // 2 + w, :] = img
254
+ return ret, (pad_w // 2, pad_h // 2)
255
+
256
+ def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
257
+ import numpy as np
258
+ from PIL import Image, ImageDraw, ImageFont
259
+
260
+ # 转换为PIL图像
261
+ result_image = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
262
+ pil_image = Image.fromarray(result_image)
263
+ draw = ImageDraw.Draw(pil_image, 'RGBA')
264
+
265
+ # 加载字体
266
+ try:
267
+ font = ImageFont.truetype(lf_path('res/fonts/SourceHanSansHW-Regular.otf'), 16)
268
+ except:
269
+ font = ImageFont.load_default()
270
+
271
+ for r in result:
272
+ # 画矩形框
273
+ draw.rectangle(
274
+ [r.rect.x1, r.rect.y1, r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h],
275
+ outline=(255, 0, 0),
276
+ width=2
277
+ )
278
+
279
+ # 获取文本大小
280
+ text = r.text + f" ({r.confidence:.2f})" # 添加置信度显示
281
+ text_bbox = draw.textbbox((0, 0), text, font=font)
282
+ text_width = text_bbox[2] - text_bbox[0]
283
+ text_height = text_bbox[3] - text_bbox[1]
284
+
285
+ # 计算文本位置
286
+ text_x = r.rect.x1
287
+ text_y = r.rect.y1 - text_height - 5 if r.rect.y1 > text_height + 5 else r.rect.y1 + r.rect.h + 5
288
+
289
+ # 添加padding
290
+ padding = 4
291
+ bg_rect = [
292
+ text_x - padding,
293
+ text_y - padding,
294
+ text_x + text_width + padding,
295
+ text_y + text_height + padding
296
+ ]
297
+
298
+ # 画半透明背景
299
+ draw.rectangle(
300
+ bg_rect,
301
+ fill=(0, 0, 0, 128)
302
+ )
303
+
304
+ # 画文字
305
+ draw.text(
306
+ (text_x, text_y),
307
+ text,
308
+ font=font,
309
+ fill=(255, 255, 255)
310
+ )
311
+
312
+ # 转回OpenCV格式
313
+ result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
314
+ return result_image
315
+
316
+ class Ocr:
317
+ def __init__(self, engine: RapidOCR):
318
+ self.__engine = engine
319
+
320
+ # TODO: 考虑缓存 OCR 结果,避免重复调用。
321
+ def ocr(
322
+ self,
323
+ img: 'MatLike',
324
+ *,
325
+ rect: Rect | None = None,
326
+ pad: bool = True,
327
+ ) -> OcrResultList:
328
+ """
329
+ OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
330
+
331
+
332
+ :param rect: 如果指定,则只识别指定矩形区域。
333
+ :param pad:
334
+ 是否将过小的图像(尺寸 < 631x631)的图像填充到 631x631。
335
+ 默认为 True。
336
+
337
+ 对于 PaddleOCR 模型,图片尺寸太小会降低准确率。
338
+ 将图片周围填充放大,有助于提高准确率,降低耗时。
339
+ :return: 所有识别结果
340
+ """
341
+ if rect is not None:
342
+ x, y, w, h = rect.xywh
343
+ img = img[y:y+h, x:x+w]
344
+ original_img = img
345
+ if pad:
346
+ # TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
347
+ # https://blog.csdn.net/YY007H/article/details/124973777
348
+ original_img = img.copy()
349
+ img, pos_in_padded_img = pad_to(img, 631)
350
+ else:
351
+ pos_in_padded_img = (0, 0)
352
+ img_content = img
353
+ result, elapse = self.__engine(img_content)
354
+ if result is None:
355
+ return OcrResultList()
356
+ ret = []
357
+ for r in result:
358
+ text = sanitize_text(r[1])
359
+ # r[0] = [左上, 右上, 右下, 左下]
360
+ # 这里有个坑,返回的点不一定是矩形,只能保证是四边形
361
+ # 所以这里需要计算出四个点的外接矩形
362
+ result_rect = tuple(int(x) for x in bounding_box(r[0])) # type: ignore
363
+ # result_rect (x, y, w, h)
364
+ if rect is not None:
365
+ original_rect = (
366
+ result_rect[0] + rect.x1 - pos_in_padded_img[0],
367
+ result_rect[1] + rect.y1 - pos_in_padded_img[1],
368
+ result_rect[2],
369
+ result_rect[3]
370
+ )
371
+ else:
372
+ original_rect = result_rect
373
+ if not len(original_rect) == 4:
374
+ raise ValueError(f'Invalid original_rect: {original_rect}')
375
+ if not len(result_rect) == 4:
376
+ raise ValueError(f'Invalid result_rect: {result_rect}')
377
+ confidence = float(r[2])
378
+ ret.append(OcrResult(
379
+ text=text,
380
+ rect=Rect(xywh=result_rect),
381
+ original_rect=Rect(xywh=original_rect),
382
+ confidence=confidence
383
+ ))
384
+ ret = OcrResultList(ret)
385
+ if debug.enabled:
386
+ result_image = _draw_result(img, ret)
387
+ elapse = elapse or [0, 0, 0]
388
+ debug_result(
389
+ 'ocr',
390
+ [result_image, original_img],
391
+ f"pad={pad}\n" + \
392
+ f"rect={rect}\n" + \
393
+ f"elapsed: det={elapse[0]:.3f}s cls={elapse[1]:.3f}s rec={elapse[2]:.3f}s\n" + \
394
+ f"result: \n" + \
395
+ "<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
396
+ "\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.3f}</td></tr>" for r in ret]) + \
397
+ "</table>"
398
+ )
399
+ return ret
400
+
401
+ def find(
402
+ self,
403
+ img: 'MatLike',
404
+ text: str | re.Pattern | StringMatchFunction,
405
+ *,
406
+ hint: HintBox | None = None,
407
+ rect: Rect | None = None,
408
+ pad: bool = True,
409
+ ) -> OcrResult | None:
410
+ """
411
+ 识别图像中的文本,并寻找满足指定要求的文本。
412
+
413
+ :param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
414
+ :param rect: 如果指定,则只识别指定矩形区域。此参数优先级低于 `hint`。
415
+ :param pad: `ocr` 的 `pad` 参数。
416
+ :return: 找到的文本,如果未找到则返回 None
417
+ """
418
+ if hint is not None:
419
+ warnings.warn("使用 `rect` 参数代替")
420
+ if ret := self.find(img, text, rect=Rect(xywh=hint.rect)):
421
+ logger.debug(f"find: {text} SUCCESS [hint={hint}]")
422
+ return ret
423
+ logger.debug(f"find: {text} FAILED [hint={hint}]")
424
+
425
+ start_time = time.time()
426
+ results = self.ocr(img, rect=rect, pad=pad)
427
+ end_time = time.time()
428
+ target = None
429
+ for result in results:
430
+ if _is_match(result.text, text):
431
+ target = result
432
+ break
433
+ logger.debug(
434
+ f"find: {text} {'SUCCESS' if target else 'FAILED'} " + \
435
+ f"[elapsed={end_time - start_time:.3f}s] [rect={rect}]"
436
+ )
437
+ return target
438
+
439
+ def find_all(
440
+ self,
441
+ img: 'MatLike',
442
+ texts: list[str | re.Pattern | StringMatchFunction],
443
+ *,
444
+ hint: HintBox | None = None,
445
+ rect: Rect | None = None,
446
+ pad: bool = True,
447
+ ) -> list[OcrResult | None]:
448
+ """
449
+ 识别图像中的文本,并寻找多个满足指定要求的文本。
450
+
451
+ :return:
452
+ 所有找到的文本,结果顺序与输入顺序相同。
453
+ 若某个文本未找到,则该位置为 None。
454
+ """
455
+ # HintBox 处理
456
+ if hint is not None:
457
+ warnings.warn("使用 `rect` 参数代替")
458
+ result = self.find_all(img, texts, rect=Rect(xywh=hint.rect), pad=pad)
459
+ if all(result):
460
+ return result
461
+
462
+ ret: list[OcrResult | None] = []
463
+ ocr_results = self.ocr(img, rect=rect, pad=pad)
464
+ logger.debug(f"ocr_results: {ocr_results}")
465
+ for text in texts:
466
+ for result in ocr_results:
467
+ if _is_match(result.text, text):
468
+ ret.append(result)
469
+ break
470
+ else:
471
+ ret.append(None)
472
+ return ret
473
+
474
+ def expect(
475
+ self,
476
+ img: 'MatLike',
477
+ text: str | re.Pattern | StringMatchFunction,
478
+ *,
479
+ hint: HintBox | None = None,
480
+ rect: Rect | None = None,
481
+ pad: bool = True,
482
+ ) -> OcrResult:
483
+ """
484
+ 识别图像中的文本,并寻找满足指定要求的文本。如果未找到则抛出异常。
485
+
486
+ :param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
487
+ :param rect: 如果指定,则只识别指定矩形区域。此参数优先级高于 `hint`。
488
+ :param pad: 见 `ocr` 的 `pad` 参数。
489
+ :return: 找到的文本
490
+ """
491
+ ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
492
+ if ret is None:
493
+ raise TextNotFoundError(text, img)
494
+ return ret
495
+
496
+ # TODO: 这个路径需要能够独立设置
497
+ _engine_jp: RapidOCR | None = None
498
+ _engine_en: RapidOCR | None = RapidOCR(
499
+ rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
500
+ use_det=True,
501
+ use_cls=False,
502
+ use_rec=True,
503
+ )
504
+
505
+ def jp() -> Ocr:
506
+ """
507
+ 日语 OCR 引擎。
508
+ """
509
+ global _engine_jp
510
+ if _engine_jp is None:
511
+ _engine_jp = RapidOCR(
512
+ rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'),
513
+ use_det=True,
514
+ use_cls=False,
515
+ use_rec=True,
516
+ )
517
+ return Ocr(_engine_jp)
518
+
519
+ def en() -> Ocr:
520
+ """
521
+ 英语 OCR 引擎。
522
+ """
523
+ global _engine_en
524
+ if _engine_en is None:
525
+ _engine_en = RapidOCR(
526
+ rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
527
+ use_det=True,
528
+ use_cls=False,
529
+ use_rec=True,
530
+ )
531
+ return Ocr(_engine_en)
532
+
533
+
534
+ if __name__ == '__main__':
535
+ pass