screenshot-vision-algorithm 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. screenshot_vision_algorithm/__init__.py +48 -0
  2. screenshot_vision_algorithm/_config.py +61 -0
  3. screenshot_vision_algorithm/android/__init__.py +1 -0
  4. screenshot_vision_algorithm/android/wechat/__init__.py +1 -0
  5. screenshot_vision_algorithm/android/wechat/algorithms/__init__.py +0 -0
  6. screenshot_vision_algorithm/android/wechat/algorithms/avatar_column.py +209 -0
  7. screenshot_vision_algorithm/android/wechat/algorithms/badge_detection.py +275 -0
  8. screenshot_vision_algorithm/android/wechat/algorithms/card_bbox.py +1000 -0
  9. screenshot_vision_algorithm/android/wechat/algorithms/phash_utils.py +267 -0
  10. screenshot_vision_algorithm/android/wechat/algorithms/speaker_band.py +290 -0
  11. screenshot_vision_algorithm/android/wechat/algorithms/template_matching.py +2163 -0
  12. screenshot_vision_algorithm/android/wechat/algorithms/title_ocr.py +143 -0
  13. screenshot_vision_algorithm/android/wechat/merge/__init__.py +0 -0
  14. screenshot_vision_algorithm/android/wechat/merge/multipage.py +157 -0
  15. screenshot_vision_algorithm/android/wechat/ocr/__init__.py +0 -0
  16. screenshot_vision_algorithm/android/wechat/ocr/avatar_guard.py +434 -0
  17. screenshot_vision_algorithm/android/wechat/ocr/badge_ocr.py +232 -0
  18. screenshot_vision_algorithm/android/wechat/ocr/nickname_binding.py +1888 -0
  19. screenshot_vision_algorithm/android/wechat/ocr/text_ocr_adapter.py +625 -0
  20. screenshot_vision_algorithm/android/wechat/profiles/__init__.py +0 -0
  21. screenshot_vision_algorithm/android/wechat/profiles/android.py +53 -0
  22. screenshot_vision_algorithm/android/wechat/profiles/harmony.py +10 -0
  23. screenshot_vision_algorithm/android/wechat/profiles/ios.py +53 -0
  24. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/chat_back_chevron.png +0 -0
  25. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/chat_input_emoji_smile.png +0 -0
  26. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/chat_input_plus.png +0 -0
  27. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/chat_input_voice.png +0 -0
  28. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/chat_title_more_dots.png +0 -0
  29. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/favorite_label.png +0 -0
  30. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/new_messages_hint_suffix.png +0 -0
  31. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/unread_divider_hint.png +0 -0
  32. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/unread_divider_hint_v2_textonly.png +0 -0
  33. screenshot_vision_algorithm/android/wechat/templates/android/8.0.69/wechat_note_header.png +0 -0
  34. screenshot_vision_algorithm/android/xhs/__init__.py +4 -0
  35. screenshot_vision_algorithm/android/zhihu/__init__.py +4 -0
  36. screenshot_vision_algorithm/png_utils.py +86 -0
  37. screenshot_vision_algorithm-0.3.0.dist-info/METADATA +425 -0
  38. screenshot_vision_algorithm-0.3.0.dist-info/RECORD +40 -0
  39. screenshot_vision_algorithm-0.3.0.dist-info/WHEEL +5 -0
  40. screenshot_vision_algorithm-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,434 @@
1
+ """昵称行左侧「头像栏」影像守门:标准 ROI + 可选 Hough 头像锚点对齐 ROI。
2
+
3
+ 纹理(Laplacian + RGB)与列边缘形状须 **同时** 满足(布尔过滤,非算分)。
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Literal, Optional
9
+
10
+ Side = Literal["left", "right"]
11
+
12
+ import numpy as np
13
+ from loguru import logger
14
+
15
+ try:
16
+ import cv2
17
+ except ImportError:
18
+ cv2 = None # type: ignore[assignment]
19
+
20
+ try:
21
+ from processor.left_avatar_column import AvatarCentroid
22
+ except ImportError:
23
+ AvatarCentroid = None # type: ignore[misc, assignment]
24
+
25
+
26
def bgr_roi_laplacian_variance(roi: np.ndarray) -> float:
    """Return the Laplacian variance of a lightly blurred grayscale ``roi``.

    The ROI is converted BGR -> gray, smoothed with a 3x3 Gaussian, and the
    variance of its 64-bit Laplacian is returned.

    Returns ``0.0`` when OpenCV is not importable, when the ROI is empty or
    has fewer than 6 rows or columns, or when any OpenCV call raises.
    """
    if cv2 is None:
        return 0.0
    if roi.size == 0 or roi.shape[0] < 6 or roi.shape[1] < 6:
        return 0.0
    try:
        packed = np.ascontiguousarray(roi, dtype=np.uint8)
        smoothed = cv2.GaussianBlur(
            cv2.cvtColor(packed, cv2.COLOR_BGR2GRAY), (3, 3), 0
        )
        response = cv2.Laplacian(smoothed, cv2.CV_64F)
        return float(response.var())
    except Exception:
        # Best-effort metric: any OpenCV failure counts as "no texture".
        return 0.0
36
+
37
+
38
def bgr_roi_color_std_mean(roi: np.ndarray) -> float:
    """Return the mean of the per-channel standard deviations of ``roi``.

    The ROI is flattened to rows of 3 values (so its total size must be
    divisible by 3), cast to ``float32``, and the standard deviation of each
    of the three columns is averaged. An empty ROI yields ``0.0``.
    """
    pixels = roi.reshape(-1, 3).astype(np.float32)
    if not pixels.size:
        return 0.0
    per_channel_std = pixels.std(axis=0)
    return float(per_channel_std.mean())
43
+
44
+
45
def bgr_roi_edge_shape_ok(
    roi: np.ndarray,
    *,
    sig_cols_min: float = 15.0,
    right_left_ratio_min: float = 1.2,
) -> bool:
    """Check the column-wise horizontal-gradient profile of ``roi``.

    The per-column mean of ``|Sobel_x|`` is computed on the grayscale ROI.
    The check passes only when BOTH hold:

    * at least ``sig_cols_min`` columns exceed
      ``max(60th percentile of the profile, 3.0)``;
    * the mean of the columns from index ``2 * (n // 3)`` onward, divided by
      ``max(mean of the first n // 3 columns, 1.0)``, is strictly greater
      than ``right_left_ratio_min``.

    Returns ``False`` when OpenCV is unavailable, when the ROI is empty or
    has fewer than 6 rows or columns, or when an OpenCV call raises.
    """
    if cv2 is None:
        return False
    if roi.size == 0 or roi.shape[0] < 6 or roi.shape[1] < 6:
        return False
    try:
        packed = np.ascontiguousarray(roi, dtype=np.uint8)
        gray_f32 = cv2.cvtColor(packed, cv2.COLOR_BGR2GRAY).astype(np.float32)
        grad_x = cv2.Sobel(gray_f32, cv2.CV_64F, 1, 0, ksize=3)
    except Exception:
        return False

    column_energy = np.abs(grad_x).mean(axis=0)
    n_cols = column_energy.size
    if n_cols == 0:
        return False

    band = max(1, n_cols // 3)
    left_mean = float(column_energy[:band].mean())
    right_mean = float(column_energy[2 * band:].mean())
    threshold = max(float(np.percentile(column_energy, 60)), 3.0)
    strong_cols = float(np.count_nonzero(column_energy > threshold))
    # Denominator is floored at 1.0 so a flat left band cannot inflate the ratio.
    ratio = right_mean / max(left_mean, 1.0)

    if not (strong_cols >= sig_cols_min):
        return False
    return bool(ratio > right_left_ratio_min)
69
+
70
+
71
def _roi_passes_texture(
    roi: np.ndarray,
    *,
    laplacian_min: float,
    rgb_std_min: float,
) -> bool:
    """Return True when ``roi`` meets both texture thresholds.

    Requires ``bgr_roi_laplacian_variance(roi) >= laplacian_min`` and
    ``bgr_roi_color_std_mean(roi) >= rgb_std_min``. Both metrics are always
    computed before either comparison is made.
    """
    sharpness = bgr_roi_laplacian_variance(roi)
    colour_spread = bgr_roi_color_std_mean(roi)
    if not (sharpness >= float(laplacian_min)):
        return False
    return colour_spread >= float(rgb_std_min)
80
+
81
+
82
def _roi_passes_texture_and_shape(
    roi: np.ndarray,
    *,
    laplacian_min: float,
    rgb_std_min: float,
    edge_sig_cols_min: float,
    edge_right_left_ratio_min: float,
) -> bool:
    """Return True when ``roi`` passes the texture gate AND the edge-shape gate.

    The texture gate (``_roi_passes_texture``) is evaluated first; the
    edge-shape gate (``bgr_roi_edge_shape_ok``) runs only if texture passes.
    """
    textured = _roi_passes_texture(
        roi, laplacian_min=laplacian_min, rgb_std_min=rgb_std_min
    )
    return textured and bgr_roi_edge_shape_ok(
        roi,
        sig_cols_min=edge_sig_cols_min,
        right_left_ratio_min=edge_right_left_ratio_min,
    )
97
+
98
+
99
def _crop_standard_roi(
    bgr: np.ndarray,
    *,
    nickname_x1: int,
    y_top: int,
    y_bottom: int,
    gutter_px: int,
    min_roi_width_px: int,
) -> Optional[np.ndarray]:
    """Crop the strip left of the nickname, from column 0 to ``nickname_x1 - gutter_px``.

    The right edge is clamped to ``[0, image width]``. The top row is clamped
    to ``[0, height - 1]`` and the bottom row is at least ``top + 4``.

    Returns a contiguous copy of the crop, or ``None`` when the strip would
    be narrower than ``min_roi_width_px`` or the crop is empty.
    """
    img_h, img_w = bgr.shape[:2]
    right_edge = int(max(0, min(int(nickname_x1) - int(gutter_px), img_w)))
    min_width = int(min_roi_width_px)
    if right_edge < min_width:
        return None

    row_start = int(max(0, min(y_top, img_h - 1)))
    row_stop = int(max(row_start + 4, min(y_bottom, img_h)))
    window = np.ascontiguousarray(bgr[row_start:row_stop, 0:right_edge])
    if window.size == 0 or window.shape[1] < min_width:
        return None
    return window
118
+
119
+
120
def _crop_anchor_aligned_roi(
    bgr: np.ndarray,
    *,
    anchor: "AvatarCentroid",
    nickname_x1: int,
    y_bottom: int,
    gutter_px: int,
    min_roi_width_px: int,
) -> Optional[np.ndarray]:
    """Crop from the avatar's top edge down to the bottom of the nickname row.

    Same geometry as the standard left strip (columns ``0`` to
    ``nickname_x1 - gutter_px``), except the top row comes from
    ``anchor.y_top`` instead of the nickname row's own top.

    Returns a contiguous copy of the crop, or ``None`` when the strip would
    be narrower than ``min_roi_width_px`` or the crop is empty.
    """
    img_h, img_w = bgr.shape[:2]
    right_edge = int(max(0, min(int(nickname_x1) - int(gutter_px), img_w)))
    min_width = int(min_roi_width_px)
    if right_edge < min_width:
        return None

    row_start = int(max(0, min(anchor.y_top, img_h - 1)))
    row_stop = int(max(row_start + 4, min(y_bottom, img_h)))
    window = np.ascontiguousarray(bgr[row_start:row_stop, 0:right_edge])
    if window.size == 0 or window.shape[1] < min_width:
        return None
    return window
140
+
141
+
142
def _crop_avatar_disk_roi(
    bgr: np.ndarray,
    *,
    anchor: "AvatarCentroid",
    nickname_x1: int,
    gutter_px: int,
    min_roi_width_px: int,
) -> Optional[np.ndarray]:
    """Crop a compact window around the avatar circle described by ``anchor``.

    The window spans ``anchor.cy +/- (anchor.r + pad)`` vertically, where
    ``pad = max(4, round(0.15 * anchor.r))``. Horizontally it runs from
    ``anchor.cx - anchor.r - pad`` to ``nickname_x1 - gutter_px``. Both
    ranges are clamped to the image.

    Returns a contiguous copy of the crop, or ``None`` when the window is
    narrower than ``min_roi_width_px``, not taller than 4 rows, or empty.
    """
    img_h, img_w = bgr.shape[:2]
    margin = max(4, int(round(anchor.r * 0.15)))
    reach = anchor.r + margin

    col_stop = int(max(0, min(int(nickname_x1) - int(gutter_px), img_w)))
    col_start = int(max(0, anchor.cx - reach))
    min_width = int(min_roi_width_px)
    if col_stop - col_start < min_width:
        return None

    row_start = int(max(0, anchor.cy - reach))
    row_stop = int(min(img_h, anchor.cy + reach))
    if row_stop <= row_start + 4:
        return None

    window = np.ascontiguousarray(bgr[row_start:row_stop, col_start:col_stop])
    if window.size == 0 or window.shape[1] < min_width:
        return None
    return window
165
+
166
+
167
def _crop_standard_roi_right(
    bgr: np.ndarray,
    *,
    nickname_x2: int,
    y_top: int,
    y_bottom: int,
    gutter_px: int,
    min_roi_width_px: int,
) -> Optional[np.ndarray]:
    """Crop the strip right of the nickname, from ``nickname_x2 + gutter_px`` to the right image edge.

    Mirror of ``_crop_standard_roi`` for right-aligned speakers. The left
    edge is clamped to ``[0, image width]``. The top row is clamped to
    ``[0, height - 1]`` and the bottom row is at least ``top + 4``.

    Returns a contiguous copy of the crop, or ``None`` when the strip would
    be narrower than ``min_roi_width_px`` or the crop is empty.
    """
    img_h, img_w = bgr.shape[:2]
    left_edge = int(min(img_w, max(0, int(nickname_x2) + int(gutter_px))))
    min_width = int(min_roi_width_px)
    if img_w - left_edge < min_width:
        return None

    row_start = int(max(0, min(y_top, img_h - 1)))
    row_stop = int(max(row_start + 4, min(y_bottom, img_h)))
    window = np.ascontiguousarray(bgr[row_start:row_stop, left_edge:img_w])
    if window.size == 0 or window.shape[1] < min_width:
        return None
    return window
187
+
188
+
189
def nickname_row_passes_prd_avatar_guard(
    bgr: Optional[np.ndarray],
    *,
    nickname_x1: int,
    nickname_x2: int,
    y_top: int,
    y_bottom: int,
    side: Side,
    avatar_anchor: Optional["AvatarCentroid"] = None,
    gutter_px: int = 10,
    min_roi_width_px: int = 40,
    laplacian_min: float = 48.0,
    rgb_std_min: float = 4.0,
) -> bool:
    """Strict avatar guard: a Hough avatar anchor is mandatory.

    Decision order:

    1. Reject when OpenCV is unavailable, the image is missing or empty, or
       no ``avatar_anchor`` was supplied.
    2. Accept when both the anchor-aligned strip and the avatar-disk window
       can be cropped and both pass the texture gate.
    3. Otherwise fall back to the standard strip on the given ``side`` and
       accept only if it passes texture AND edge shape (with fixed
       thresholds ``15.0`` / ``1.2``).

    A row without an anchor is never accepted on the standard strip alone.
    """
    if cv2 is None or bgr is None or bgr.size == 0 or avatar_anchor is None:
        return False

    is_left = side == "left"

    # The anchor crops take the reference x via their ``nickname_x1`` keyword;
    # right-side rows pass ``nickname_x2`` there.
    # NOTE(review): both anchor crop helpers end their window at
    # ``reference_x - gutter_px``, i.e. to the LEFT of the reference. For a
    # right-aligned avatar this looks like it may never produce a disk ROI —
    # confirm whether a mirrored crop is intended.
    reference_x = nickname_x1 if is_left else nickname_x2
    anchor_crop_args = dict(
        anchor=avatar_anchor,
        nickname_x1=reference_x,
        gutter_px=gutter_px,
        min_roi_width_px=min_roi_width_px,
    )
    strip = _crop_anchor_aligned_roi(bgr, y_bottom=y_bottom, **anchor_crop_args)
    disk = _crop_avatar_disk_roi(bgr, **anchor_crop_args)

    if (
        strip is not None
        and disk is not None
        and _roi_passes_texture(
            strip, laplacian_min=laplacian_min, rgb_std_min=rgb_std_min
        )
        and _roi_passes_texture(
            disk, laplacian_min=laplacian_min, rgb_std_min=rgb_std_min
        )
    ):
        return True

    # Anchor is bound but the anchor path did not pass: still require the
    # standard strip to pass texture AND shape.
    if is_left:
        fallback = _crop_standard_roi(
            bgr,
            nickname_x1=nickname_x1,
            y_top=y_top,
            y_bottom=y_bottom,
            gutter_px=gutter_px,
            min_roi_width_px=min_roi_width_px,
        )
    else:
        fallback = _crop_standard_roi_right(
            bgr,
            nickname_x2=nickname_x2,
            y_top=y_top,
            y_bottom=y_bottom,
            gutter_px=gutter_px,
            min_roi_width_px=min_roi_width_px,
        )
    if fallback is None:
        return False
    return _roi_passes_texture_and_shape(
        fallback,
        laplacian_min=laplacian_min,
        rgb_std_min=rgb_std_min,
        edge_sig_cols_min=15.0,
        edge_right_left_ratio_min=1.2,
    )
274
+
275
+
276
def nickname_row_passes_avatar_roi(
    bgr: Optional[np.ndarray],
    *,
    nickname_x1: int,
    y_top: int,
    y_bottom: int,
    avatar_anchor: Optional["AvatarCentroid"] = None,
    gutter_px: int = 10,
    min_roi_width_px: int = 40,
    laplacian_min: float = 48.0,
    rgb_std_min: float = 4.0,
    edge_sig_cols_min: float = 15.0,
    edge_right_left_ratio_min: float = 1.2,
    prd_strict: bool = False,
    nickname_x2: int = 0,
    side: Side = "left",
) -> bool:
    """Decide whether the area beside a nickname row looks like an avatar.

    With ``prd_strict=True`` the decision is delegated entirely to
    :func:`nickname_row_passes_prd_avatar_guard` (``nickname_x2`` falls back
    to ``nickname_x1`` when it is falsy).

    Otherwise:

    1. The standard left strip passes if it meets texture AND edge shape.
    2. Failing that, an ``avatar_anchor`` is required; the anchor-aligned
       strip and the avatar-disk window must both be croppable and both pass
       the texture gate.

    Returns ``False`` when OpenCV is unavailable or the image is missing or empty.
    """
    if prd_strict:
        return nickname_row_passes_prd_avatar_guard(
            bgr,
            nickname_x1=nickname_x1,
            nickname_x2=nickname_x2 or nickname_x1,
            y_top=y_top,
            y_bottom=y_bottom,
            side=side,
            avatar_anchor=avatar_anchor,
            gutter_px=gutter_px,
            min_roi_width_px=min_roi_width_px,
            laplacian_min=laplacian_min,
            rgb_std_min=rgb_std_min,
        )

    if cv2 is None or bgr is None or bgr.size == 0:
        return False

    # Path 1: standard strip, texture AND column-edge shape.
    standard = _crop_standard_roi(
        bgr,
        nickname_x1=nickname_x1,
        y_top=y_top,
        y_bottom=y_bottom,
        gutter_px=gutter_px,
        min_roi_width_px=min_roi_width_px,
    )
    if standard is not None and _roi_passes_texture_and_shape(
        standard,
        laplacian_min=laplacian_min,
        rgb_std_min=rgb_std_min,
        edge_sig_cols_min=edge_sig_cols_min,
        edge_right_left_ratio_min=edge_right_left_ratio_min,
    ):
        return True

    # Path 2: anchor-based crops, texture only.
    if avatar_anchor is None:
        return False

    anchor_crop_args = dict(
        anchor=avatar_anchor,
        nickname_x1=nickname_x1,
        gutter_px=gutter_px,
        min_roi_width_px=min_roi_width_px,
    )
    strip = _crop_anchor_aligned_roi(bgr, y_bottom=y_bottom, **anchor_crop_args)
    disk = _crop_avatar_disk_roi(bgr, **anchor_crop_args)
    if strip is None or disk is None:
        return False

    if not _roi_passes_texture(
        strip, laplacian_min=laplacian_min, rgb_std_min=rgb_std_min
    ):
        return False
    return _roi_passes_texture(
        disk, laplacian_min=laplacian_min, rgb_std_min=rgb_std_min
    )
356
+
357
+
358
def avatar_roi_pass(
    bgr: Optional[np.ndarray],
    *,
    nickname_x1: int,
    y_top: int,
    y_bottom: int,
    avatar_anchor: Optional["AvatarCentroid"] = None,
    gutter_px: int = 10,
    min_roi_width_px: int = 40,
    laplacian_min: float = 48.0,
    rgb_std_min: float = 4.0,
    edge_sig_cols_min: float = 15.0,
    edge_right_left_ratio_min: float = 1.2,
    require_edge_shape: bool = True,
    strict_narrow_only: bool = False,
    **legacy: Any,
) -> bool:
    """Backward-compatible alias of :func:`nickname_row_passes_avatar_roi`.

    All geometry and threshold arguments are forwarded unchanged. The
    following are accepted only so old call sites keep working and have no
    effect on the result:

    * ``require_edge_shape`` — passing ``False`` only emits a debug log; the
      texture AND shape check is still applied.
    * ``strict_narrow_only`` — ignored.
    * ``**legacy`` — any other keyword arguments are ignored.

    Returns:
        The result of :func:`nickname_row_passes_avatar_roi` in its
        non-strict mode.
    """
    if not require_edge_shape:
        logger.debug("avatar_roi_pass: require_edge_shape=False 已弃用,仍使用纹理∧形状")
    return nickname_row_passes_avatar_roi(
        bgr,
        nickname_x1=nickname_x1,
        y_top=y_top,
        y_bottom=y_bottom,
        avatar_anchor=avatar_anchor,
        gutter_px=gutter_px,
        min_roi_width_px=min_roi_width_px,
        laplacian_min=laplacian_min,
        rgb_std_min=rgb_std_min,
        edge_sig_cols_min=edge_sig_cols_min,
        edge_right_left_ratio_min=edge_right_left_ratio_min,
    )
394
+
395
+
396
def nickname_left_roi_passes_avatar_signal(
    bgr: Optional[np.ndarray],
    *,
    nickname_x1: int,
    y_top: int,
    y_bottom: int,
    avatar_anchor: Optional["AvatarCentroid"] = None,
    gutter_px: int = 10,
    min_roi_width_px: int = 40,
    laplacian_min: float = 48.0,
    rgb_std_min: float = 4.0,
    edge_sig_cols_min: float = 15.0,
    edge_right_left_ratio_min: float = 1.2,
    **kwargs: Any,
) -> bool:
    """Forward to :func:`nickname_row_passes_avatar_roi` in its non-strict mode.

    Every named argument is passed through unchanged; any extra keyword
    arguments in ``**kwargs`` are accepted and discarded.
    """
    _ = kwargs
    forwarded = dict(
        nickname_x1=nickname_x1,
        y_top=y_top,
        y_bottom=y_bottom,
        avatar_anchor=avatar_anchor,
        gutter_px=gutter_px,
        min_roi_width_px=min_roi_width_px,
        laplacian_min=laplacian_min,
        rgb_std_min=rgb_std_min,
        edge_sig_cols_min=edge_sig_cols_min,
        edge_right_left_ratio_min=edge_right_left_ratio_min,
    )
    return nickname_row_passes_avatar_roi(bgr, **forwarded)
425
+
426
+
427
def load_png_bgr(png_path: Any) -> Optional[np.ndarray]:
    """Read an image file with ``cv2.imread`` and return the decoded array.

    Args:
        png_path: Path-like value; it is converted with ``str()`` before
            being handed to OpenCV.

    Returns:
        Whatever ``cv2.imread`` returns for the path, or ``None`` when OpenCV
        is not importable, ``png_path`` is ``None``, or the read raises.
    """
    if cv2 is None or png_path is None:
        return None
    try:
        return cv2.imread(str(png_path))
    except Exception as e:
        # loguru formats with str.format-style braces; "%s" placeholders
        # would be printed literally and the arguments dropped.
        logger.warning("nickname_avatar_guard: imread 失败 path={} err={}", png_path, e)
        return None
@@ -0,0 +1,232 @@
1
+ """Optional PaddleOCR pass to read digits off unread badges on the main
2
+ conversation list.
3
+
4
+ Why this module exists
5
+ ----------------------
6
+ The bulk of :func:`TopGroupScanDriver.scan_pinned_groups` is **OCR-free** —
7
+ it only has to tell whether a row has an unread badge (red blob near the
8
+ avatar's top-right corner) and whether the row is pinned (grey cell
9
+ background). The digit count ("3", "12", "99+") is a **nice-to-have**
10
+ used by downstream prioritisation (iterate rows with larger counts
11
+ first) and logging. We keep it behind an opt-in flag so:
12
+
13
+ - The fast path stays fast — PaddleOCR takes ~1.5 s per cold init + a
14
+ few hundred ms per badge, which would slow every live scan loop.
15
+ - Unit tests don't need a working Paddle install to cover the badge
16
+ geometry path.
17
+
18
+ Usage
19
+ -----
20
+ Call :func:`ocr_unread_badge_digits` AFTER ``scan_pinned_groups``
21
+ produced its rows. The function lazily imports PaddleOCR, crops a small
22
+ padded window around each ``UnreadDotHit``, binarises (red-bg/white-fg
23
+ → black-bg/white-fg), upscales 6x, and returns a map
24
+ ``{hit_index → (digits, score)}`` (positional hit index ``0 …``) for every row whose OCR confidently
25
+ matches a digit pattern. Rows without a confident read are absent.
26
+
27
+ Digit regex
28
+ -----------
29
+ ``\\d+\\+?`` — matches "1", "23", "99+", "0" (the last is never shown
30
+ by WeChat but we still accept it to keep the regex portable).
31
+
32
+ Thresholds / pre-processing are the d4-auto-scan calibration values
33
+ that scored ~90%+ on the edb1a89f baseline when the previous scratch
34
+ tool ``tools/detect_unread_badges.py`` (since removed) was still
35
+ around. Kept as module-level constants so Day 7 calibration can bump
36
+ them without touching the driver.
37
+ """
38
+ from __future__ import annotations
39
+
40
+ import re
41
+ from typing import Optional
42
+
43
+ import cv2 # type: ignore
44
+ import numpy as np # type: ignore
45
+ from loguru import logger
46
+
47
+ from screenshot_vision_algorithm.android.wechat.algorithms.template_matching import UnreadDotHit
48
+
49
+ _DIGIT_RE = re.compile(r"\d+\+?")
50
+
51
+ # Pre-processing ------------------------------------------------------------
52
+
53
+ #: Padding (in pixels AT RAW SCALE) around each badge bbox before cropping.
54
+ #: PaddleOCR's detection head needs ~20% whitespace margin around the glyph.
55
+ #: We take ``max(8, w // 3)`` / ``max(8, h // 3)`` so both 27-px solid dots
56
+ #: and 50-px digit badges get proportionally sized margins.
57
+ PAD_MIN = 8
58
+ PAD_DIV = 3
59
+
60
+ #: Near-white HSV cut (``cv2.inRange`` low/high). Matches the white-pixel
61
+ #: detector used downstream — pixels that are BRIGHT and DESATURATED get
62
+ #: promoted to white, everything else (red badge body, avatar bleed,
63
+ #: anti-aliased purple) falls to black.
64
+ WHITE_HSV_LO: tuple[int, int, int] = (0, 0, 200)
65
+ WHITE_HSV_HI: tuple[int, int, int] = (179, 60, 255)
66
+
67
+ #: Fallback grey-scale brightness threshold: if HSV mask misses an AA
68
+ #: near-white pixel (V < 200 but > 180 visually white), the classic
69
+ #: ``threshold`` pass recaptures it.
70
+ GRAY_THRESH = 180
71
+
72
+ #: Upscale factor applied to the binarised badge crop. PP-OCRv5 mobile
73
+ #: wants ~32 px min glyph height; a single-digit 27-px badge at 6x
74
+ #: becomes 162 px, well above that floor.
75
+ OCR_UPSCALE = 6.0
76
+
77
+
78
def _binarise(crop_bgr: np.ndarray) -> np.ndarray:
    """Turn a badge crop into a single-channel white-on-black mask.

    Two masks are combined with a bitwise OR:

    1. An HSV ``inRange`` between ``WHITE_HSV_LO`` and ``WHITE_HSV_HI``.
    2. A binary threshold of the grayscale crop at ``GRAY_THRESH``.

    The caller is expected to expand the result to 3 channels and upscale it
    before handing it to OCR.
    """
    hsv_lo = np.asarray(WHITE_HSV_LO, dtype=np.uint8)
    hsv_hi = np.asarray(WHITE_HSV_HI, dtype=np.uint8)

    gray_view = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    hsv_view = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2HSV)

    near_white = cv2.inRange(hsv_view, hsv_lo, hsv_hi)
    bright_mask = cv2.threshold(gray_view, GRAY_THRESH, 255, cv2.THRESH_BINARY)[1]
    return cv2.bitwise_or(near_white, bright_mask)
96
+
97
+
98
+ # PaddleOCR init (lazy, cached) --------------------------------------------
99
+
100
+ _OCR_SINGLETON = None
101
+ _OCR_VERSION_MAJOR: Optional[int] = None
102
+
103
+
104
def _get_ocr():
    """Return the cached ``(PaddleOCR instance, major version)`` pair.

    The engine is built on first call and stored in the module-level
    ``_OCR_SINGLETON`` / ``_OCR_VERSION_MAJOR`` globals; later calls return
    the cached pair without re-importing or re-initialising.

    Returns:
        A tuple ``(ocr, ver_major)``. ``ver_major`` is parsed from
        ``paddleocr.__version__`` and falls back to ``2`` if parsing fails.

    Raises:
        RuntimeError: if ``paddleocr`` cannot be imported.
    """
    global _OCR_SINGLETON, _OCR_VERSION_MAJOR
    if _OCR_SINGLETON is not None:
        return _OCR_SINGLETON, _OCR_VERSION_MAJOR

    # Imported lazily so the module can be used without paddleocr installed
    # as long as this function is never called.
    try:
        from paddleocr import PaddleOCR  # type: ignore
        import paddleocr as _pkg  # type: ignore
    except ImportError as e:
        raise RuntimeError(
            "ocr_unread_badge_digits requires paddleocr; "
            "pip install paddleocr"
        ) from e

    try:
        ver_major = int(str(getattr(_pkg, "__version__", "2.0.0")).split(".")[0])
    except Exception:  # pragma: no cover
        ver_major = 2

    # loguru uses brace-style formatting; a "%d" placeholder would be logged
    # literally and the version dropped.
    logger.info("initialising PaddleOCR v{} (once-per-process)", ver_major)
    if ver_major >= 3:
        ocr = PaddleOCR(
            use_textline_orientation=False,
            lang="ch", device="cpu",
            text_detection_model_name="PP-OCRv5_mobile_det",
            text_recognition_model_name="PP-OCRv5_mobile_rec",
        )
    else:  # pragma: no cover — legacy 2.x fallback, project pins 3.x
        ocr = PaddleOCR(use_angle_cls=False, lang="ch", use_gpu=False)

    _OCR_SINGLETON = ocr
    _OCR_VERSION_MAJOR = ver_major
    return ocr, ver_major
138
+
139
+
140
+ # Public API ---------------------------------------------------------------
141
+
142
+
143
def ocr_unread_badge_digits(
    screen_bgr: np.ndarray,
    hits: tuple[UnreadDotHit, ...] | list[UnreadDotHit],
) -> dict[int, tuple[str, float]]:
    """Read digit(s) off each badge.

    Args:
        screen_bgr: The FULL conversation-list screenshot (the same
            ``np.ndarray`` the driver passed to ``detect_unread_dots``).
            We crop from the raw image so the badge stays at native
            resolution before the upscale.
        hits: Sequence of :class:`UnreadDotHit`. Indexing in the
            returned dict is POSITIONAL (``0..len(hits)-1``), NOT
            ``UnreadDotHit`` identity — this keeps the function
            ``dataclass(frozen=True)``-safe.

    Returns:
        ``{i: (digits, score)}`` for each hit index that produced a
        digit match with a score above zero; ``score`` is rounded to 3
        decimals. Hits without a match are absent from the dict (rather
        than mapped to ``None``) so callers can use ``dict.get(i)`` freely.

    Raises:
        RuntimeError: if ``paddleocr`` is not installed.
    """
    if not hits:
        return {}

    ocr, ver_major = _get_ocr()
    H, W = screen_bgr.shape[:2]
    out: dict[int, tuple[str, float]] = {}

    for i, b in enumerate(hits):
        # Pad the badge bbox proportionally, then clamp to the screenshot.
        pad_x = max(PAD_MIN, b.w // PAD_DIV)
        pad_y = max(PAD_MIN, b.h // PAD_DIV)
        cx1 = max(0, b.x - pad_x)
        cy1 = max(0, b.y - pad_y)
        cx2 = min(W, b.x + b.w + pad_x)
        cy2 = min(H, b.y + b.h + pad_y)
        crop = screen_bgr[cy1:cy2, cx1:cx2]
        if crop.size == 0:
            continue

        # Binarise, expand to 3 channels, and upscale for the OCR engine.
        binary = _binarise(crop)
        bin_bgr = cv2.merge([binary, binary, binary])
        upscaled = cv2.resize(
            bin_bgr, None, fx=OCR_UPSCALE, fy=OCR_UPSCALE,
            interpolation=cv2.INTER_CUBIC,
        )

        texts: list[str] = []
        scores: list[float] = []
        try:
            if ver_major >= 3:
                results = list(ocr.predict(upscaled))
                if results:
                    r = results[0]
                    texts = list(r.get("rec_texts", []))
                    scores = [float(s) for s in r.get("rec_scores", [])]
            else:  # pragma: no cover
                raw = ocr.ocr(upscaled, cls=False)
                if raw and raw[0]:
                    texts = [line[1][0] for line in raw[0]]
                    scores = [float(line[1][1]) for line in raw[0]]
        except Exception as e:  # pragma: no cover — OCR backend failure
            # loguru formats with braces; "%d"/"%s" would be logged literally.
            logger.debug("badge {} OCR failed: {}", i, e)
            continue

        # Keep the highest-scoring recognised line that contains a digit run.
        best_digits: Optional[str] = None
        best_score = 0.0
        for text, score in zip(texts, scores):
            m = _DIGIT_RE.search(text.strip())
            if not m:
                continue
            if score > best_score:
                best_digits = m.group(0)
                best_score = float(score)
        if best_digits is not None:
            out[i] = (best_digits, round(best_score, 3))
            logger.debug(
                "badge[{}] bbox=({},{},{},{}) digits={!r} score={:.3f}",
                i, b.x, b.y, b.w, b.h, best_digits, best_score,
            )

    logger.info(
        "ocr_unread_badge_digits: {}/{} badge(s) produced a digit match",
        len(out), len(hits),
    )
    return out
230
+
231
+
232
+ __all__ = ["ocr_unread_badge_digits"]