py2ls 0.1.9.9__py3-none-any.whl → 0.1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. py2ls/.git/COMMIT_EDITMSG +1 -1
  2. py2ls/.git/FETCH_HEAD +1 -1
  3. py2ls/.git/index +0 -0
  4. py2ls/.git/logs/HEAD +1 -0
  5. py2ls/.git/logs/refs/heads/main +1 -0
  6. py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
  7. py2ls/.git/logs/refs/remotes/origin/main +1 -0
  8. py2ls/.git/objects/27/aa6074f652bc6f7078f8647489d9ee8e24f0e2 +0 -0
  9. py2ls/.git/objects/28/c2969d785c1b892c2a96b3f00eba63a59811b3 +0 -0
  10. py2ls/.git/objects/2a/fdf45791a26d42ccead35ace76a8f0b2a56561 +0 -0
  11. py2ls/.git/objects/34/b6f3a2ee84f39bed4eee57f2c0e0afb994feb1 +0 -0
  12. py2ls/.git/objects/35/1a5f491ab97eee9d1ee699478d75a8bb5d3dc2 +0 -0
  13. py2ls/.git/objects/39/b13be65125556784e44c7a1d9821703c7ab67e +0 -0
  14. py2ls/.git/objects/3b/507acc7f23391644cc0b824b1e79fd2677a362 +0 -0
  15. py2ls/.git/objects/3d/9d10d27724657a436c65a6254bfd213d4b3562 +0 -0
  16. py2ls/.git/objects/47/6cbd5a7c5e35cddef2f8a38bdc4896d403b095 +0 -0
  17. py2ls/.git/objects/78/063f4c863fc371ec0313303c0a81283b35d9b6 +0 -0
  18. py2ls/.git/objects/82/70b319ce4046854fbe7dc41054b6c2d112dab2 +0 -0
  19. py2ls/.git/objects/85/aee46f478e9afdb84d50a05242c53b04ed2e21 +0 -0
  20. py2ls/.git/objects/86/e288b46f8fe179907e4413f665aeb5053fddb1 +0 -0
  21. py2ls/.git/objects/94/f7dbe88e80c4205a901b71eb8f181974376bba +0 -0
  22. py2ls/.git/objects/9b/ec5ee2236ee2d5532c36bfd132e23c58fdb69c +0 -0
  23. py2ls/.git/objects/b3/4f7f271c6d6105e35a6556ffda71d03afe8c96 +0 -0
  24. py2ls/.git/objects/b3/69579064bde9de9a19d114fc33e4e48cc8c0e4 +0 -0
  25. py2ls/.git/objects/bf/b54d65922ce1dfda1aaa014913a54e7172d0bc +0 -0
  26. py2ls/.git/objects/c1/397c6ed72c4e20ef6b9ab83163e9a6baba5b45 +0 -0
  27. py2ls/.git/objects/cc/45df1d317a2eb63ff1ff3a5f3b4a9f98fd92b5 +0 -0
  28. py2ls/.git/objects/d6/39e8af592cd75a318d8affddd1bcc70c2095f2 +0 -0
  29. py2ls/.git/objects/db/3f2cd643292057936230b95cf7ec3046affe11 +0 -0
  30. py2ls/.git/objects/de/214c626ac2dd2685bfaa0bc0fc20f528d014d7 +0 -0
  31. py2ls/.git/objects/e4/6c715352db9fe3c887a635f1916df4ca1f4ff9 +0 -0
  32. py2ls/.git/objects/e5/0580a0bd1e1b3d29f834382b80fceb61d5cf0c +0 -0
  33. py2ls/.git/objects/ec/d980279432b13f0374b90ca439a6329cdece0f +0 -0
  34. py2ls/.git/objects/ee/cee64eacaff022dcdc509c0c2b1da492f21060 +0 -0
  35. py2ls/.git/objects/f5/61c3c1bf1c9ea9c9d1f556a7be2869f71f3bdf +0 -0
  36. py2ls/.git/refs/heads/main +1 -1
  37. py2ls/.git/refs/remotes/origin/main +1 -1
  38. py2ls/batman.py +198 -0
  39. py2ls/ich2ls.py +539 -85
  40. py2ls/ips.py +1 -1
  41. py2ls/netfinder.py +105 -3
  42. py2ls/ocr.py +557 -0
  43. py2ls/plot.py +68 -11
  44. {py2ls-0.1.9.9.dist-info → py2ls-0.1.10.1.dist-info}/METADATA +1 -1
  45. {py2ls-0.1.9.9.dist-info → py2ls-0.1.10.1.dist-info}/RECORD +46 -16
  46. {py2ls-0.1.9.9.dist-info → py2ls-0.1.10.1.dist-info}/WHEEL +0 -0
py2ls/ocr.py ADDED
@@ -0,0 +1,557 @@
1
+ import easyocr
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from py2ls.ips import (
6
+ strcmp,
7
+ detect_angle,
8
+ ) # Ensure this function is defined in your 'ips' module
9
+ from spellchecker import SpellChecker
10
+ import re
11
+
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import PIL.PngImagePlugin
14
+
15
+ """
16
+ Optical Character Recognition (OCR)
17
+ """
18
+
19
+ # Valid language codes
20
lang_valid = {
    "english": "en",
    "thai": "th",
    "chinese_traditional": "ch_tra",
    "chinese": "ch_sim",
    "japanese": "ja",
    "korean": "ko",
    "tamil": "ta",
    "telugu": "te",
    "kannada": "kn",
    "german": "de",
}


def lang_auto_detect(lang):
    """Translate language names (or EasyOCR codes) into EasyOCR codes.

    Args:
        lang: A single string or an iterable of strings. Each entry may be
            an EasyOCR code that appears in ``lang_valid`` (e.g. ``"en"``),
            a language name that is a key of ``lang_valid`` (matched
            case-insensitively), or an approximate name.

    Returns:
        A list with one EasyOCR language code per input entry, in order.

    Codes and exact names are resolved directly. Only entries that match
    neither are handed to ``strcmp`` for fuzzy matching against the known
    language names, so a code such as ``"en"`` can no longer be mapped to
    an unrelated language by the fuzzy matcher.
    """
    if isinstance(lang, str):
        lang = [lang]

    known_codes = set(lang_valid.values())
    res_lang = []
    for entry in lang:
        # Already a valid EasyOCR code: keep it untouched.
        if entry in known_codes:
            res_lang.append(entry)
            continue
        # Exact (case-insensitive) language name.
        name = entry.strip().lower()
        if name in lang_valid:
            res_lang.append(lang_valid[name])
            continue
        # Fuzzy fallback. NOTE(review): strcmp comes from py2ls.ips and is
        # assumed to return a sequence whose first item is the best key.
        res_lang.append(lang_valid[strcmp(entry, list(lang_valid))[0]])
    return res_lang
41
+
42
+
43
def determine_src_points(image):
    """Find the four corner points of the dominant quadrilateral in *image*.

    The image is converted to grayscale, binarised with an inverted Otsu
    threshold, and its external contours are extracted. The five largest
    contours (by area) are approximated as polygons; the first one with
    exactly four vertices is used.

    Returns:
        A ``(4, 2)`` float32 array ordered top-left, top-right,
        bottom-right, bottom-left. If no quadrilateral is found, the
        corners of the full image are returned instead.
    """
    grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(grey, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Only the five largest contours are considered as candidates.
    candidates = sorted(found, key=cv2.contourArea, reverse=True)[:5]

    quad = None
    for candidate in candidates:
        tolerance = 0.02 * cv2.arcLength(candidate, True)
        polygon = cv2.approxPolyDP(candidate, tolerance, True)
        if len(polygon) == 4:  # a quadrilateral
            quad = np.array(polygon, dtype="float32").reshape(4, 2)
            break

    if quad is None:
        # Nothing rectangular detected: fall back to the image corners.
        rows, cols = image.shape[:2]
        return np.array(
            [[0, 0], [cols - 1, 0], [cols - 1, rows - 1], [0, rows - 1]],
            dtype="float32",
        )

    # x + y is smallest at the top-left and largest at the bottom-right;
    # y - x is smallest at the top-right and largest at the bottom-left.
    coord_sum = quad.sum(axis=1)
    coord_diff = np.diff(quad, axis=1)
    ordered = np.zeros((4, 2), dtype="float32")
    ordered[0] = quad[np.argmin(coord_sum)]
    ordered[1] = quad[np.argmin(coord_diff)]
    ordered[2] = quad[np.argmax(coord_sum)]
    ordered[3] = quad[np.argmax(coord_diff)]
    return ordered
78
+
79
+
80
def get_default_camera_matrix(image_shape):
    """Build a default pinhole camera matrix for an image of the given shape.

    Args:
        image_shape: Sequence whose first two items are ``(height, width)``.

    Returns:
        ``(camera_matrix, dist_coeffs)`` where ``camera_matrix`` is a 3x3
        float32 array using the image width as focal length and the image
        centre as principal point, and ``dist_coeffs`` is a ``(4, 1)`` array
        of zeros (no lens distortion assumed).
    """
    rows, cols = image_shape[:2]
    cx, cy = cols / 2, rows / 2
    intrinsics = np.array(
        [
            [cols, 0, cx],
            [0, cols, cy],
            [0, 0, 1],
        ],
        dtype="float32",
    )
    no_distortion = np.zeros((4, 1))
    return intrinsics, no_distortion
90
+
91
+
92
def correct_perspective(image, src_points, width=1000, height=1000):
    """Warp the quadrilateral *src_points* of *image* onto an upright rectangle.

    Args:
        image: Source image as accepted by ``cv2.warpPerspective``.
        src_points: Four ``(x, y)`` corners ordered top-left, top-right,
            bottom-right, bottom-left. Any array-like with eight values is
            accepted; it is converted to the float32 ``(4, 2)`` layout that
            ``cv2.getPerspectiveTransform`` requires.
        width: Width of the output image in pixels (default 1000).
        height: Height of the output image in pixels (default 1000).

    Returns:
        The perspective-corrected image of size ``width`` x ``height``.
    """
    width, height = int(width), int(height)
    src = np.asarray(src_points, dtype="float32").reshape(4, 2)
    dst = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype="float32",
    )

    # Homography mapping the source quadrilateral onto the output rectangle.
    transform = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(image, transform, (width, height))
105
+
106
+
107
def detect_text_orientation(image):
    """Estimate the dominant line angle of *image* in degrees.

    Edges are extracted with Canny and fed to the standard Hough line
    transform. Each detected line's theta is converted to degrees and
    angles above 90 are wrapped by subtracting 180.

    Returns:
        The median of the wrapped angles, or ``0`` when no line is found.
    """
    grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edge_map = cv2.Canny(grey, 50, 150, apertureSize=3)
    hough = cv2.HoughLines(edge_map, 1, np.pi / 180, 200)

    if hough is None:
        return 0

    degrees = []
    for _rho, theta in hough[:, 0]:
        deg = theta * 180 / np.pi
        # Wrap into the (-90, 90] range.
        degrees.append(deg - 180 if deg > 90 else deg)

    return np.median(degrees)
124
+
125
+
126
def rotate_image(image, angle):
    """Rotate *image* by *angle* degrees about its centre, keeping its size.

    The rotation uses a scale of 1.0 and linear interpolation; the output
    has the same width and height as the input.
    """
    rows, cols = image.shape[0], image.shape[1]
    pivot = (cols // 2, rows // 2)
    matrix = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(image, matrix, (cols, rows), flags=cv2.INTER_LINEAR)
133
+
134
+
135
def correct_skew(image):
    """Deskew *image* using the minimum-area rectangle of its non-zero pixels.

    Args:
        image: BGR image, or an already single-channel image.

    Returns:
        The rotated image. If the image contains no non-zero pixels there is
        nothing to measure and the input is returned unchanged.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image

    # np.where yields int64 indices, which cv2.minAreaRect does not accept;
    # it needs a float32 (or int32) point set.
    coords = np.column_stack(np.where(gray > 0)).astype(np.float32)
    if coords.shape[0] == 0:
        return image

    angle = cv2.minAreaRect(coords)[-1]
    # NOTE(review): this mapping assumes minAreaRect reports angles in
    # [-90, 0); newer OpenCV releases appear to use a different range --
    # confirm against the installed version.
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle
    return rotate_image(image, angle)
144
+
145
+
146
def undistort_image(image, camera_matrix, dist_coeffs):
    """Remove lens distortion from *image*.

    Thin wrapper around ``cv2.undistort`` that forwards the camera matrix
    and distortion coefficients unchanged and returns its result.
    """
    undistorted = cv2.undistort(image, camera_matrix, dist_coeffs)
    return undistorted
148
+
149
+
150
def add_text_pil(
    image, text, position, font_size=10, color=(255, 0, 0), font_path=None
):
    """Draw *text* on an OpenCV image using Pillow (supports CJK glyphs).

    Args:
        image: BGR image as a NumPy array.
        text: String to draw.
        position: ``(x, y)`` anchor, interpreted like OpenCV's bottom-left
            text origin.
        font_size: Font size; rounded to a whole number of pixels because
            not every Pillow release accepts a float size.
        color: Fill colour. It is applied to the RGB Pillow image, so it is
            interpreted in RGB order.
        font_path: TrueType font file to use. Defaults to the macOS Songti
            font; if the font cannot be loaded, Pillow's default font is
            used instead.

    Returns:
        A new BGR image with the text drawn on it.
    """
    if font_path is None:
        font_path = "/System/Library/Fonts/Supplemental/Songti.ttc"
    size = max(1, int(round(font_size)))

    # OpenCV (BGR) -> Pillow (RGB)
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_image)

    try:
        font = ImageFont.truetype(font_path, size)
    except OSError:
        # Font file missing or unreadable on this platform.
        font = ImageFont.load_default()

    # Measure the rendered text so the anchor can be shifted upwards.
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_height = text_bbox[3] - text_bbox[1]
    # Extra upward shift of 50% of the text height.
    offset = int(0.5 * text_height)
    # Pillow anchors at the top-left; emulate OpenCV's bottom-left origin.
    adjusted_position = (position[0], position[1] - text_height - offset)

    draw.text(adjusted_position, text, font=font, fill=color)

    # Pillow (RGB) -> OpenCV (BGR)
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
178
+
179
+
180
def preprocess_img(
    image,
    grayscale=True,
    threshold=True,
    threshold_method="adaptive",
    rotate="auto",
    skew=True,
    denoise=True,
    blur_ksize=(5, 5),
    morph=True,
    morph_op="open",
    morph_kernel_size=(3, 3),
    enhance_contrast=True,
    clahe_clip=2.0,
    clahe_grid_size=(8, 8),
    edge_detection=False,
):
    """Pre-process an image to improve OCR accuracy.

    Steps, applied in this order when enabled: rotation, grayscale
    conversion, thresholding, Gaussian denoising, a morphological
    operation, CLAHE contrast enhancement and Canny edge detection.

    Args:
        image: File path, PIL image, or NumPy array. NumPy arrays are taken
            to be in OpenCV's BGR(A) channel order (or single-channel);
            PIL images are converted from RGB to BGR.
        grayscale: Convert the image to grayscale.
        threshold: Binarise the image.
        threshold_method: ``'adaptive'`` or ``'global'``; any other value
            skips thresholding.
        rotate: ``'auto'`` rotates by the angle reported by
            ``detect_angle(image, by="fft")``; any other value disables
            rotation.
        skew: Currently unused (skew correction is disabled).
        denoise: Apply a Gaussian blur.
        blur_ksize: Kernel size of the Gaussian blur.
        morph: Apply a morphological operation.
        morph_op: ``'open'``, ``'close'``, ``'dilate'`` or ``'erode'``.
        morph_kernel_size: Size of the rectangular structuring element.
        enhance_contrast: Apply CLAHE (contrast-limited adaptive histogram
            equalisation).
        clahe_clip: CLAHE clip limit.
        clahe_grid_size: CLAHE tile grid size.
        edge_detection: Run Canny edge detection as the last step.

    Returns:
        The processed image as a NumPy array.

    Raises:
        ValueError: If *image* is a path that OpenCV cannot read.
    """

    def _single_channel(arr):
        # adaptiveThreshold and CLAHE only accept single-channel input.
        return cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY) if arr.ndim == 3 else arr

    # ---- normalise the input to a BGR (or single-channel) ndarray ----
    if isinstance(image, str):
        path = image
        image = cv2.imread(path)  # already BGR
        if image is None:
            raise ValueError(f"Could not read image from path: {path!r}")
    elif isinstance(image, Image.Image):
        # Pillow stores RGB; OpenCV routines below expect BGR.
        image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    elif not isinstance(image, np.ndarray):
        image = np.array(image)

    # Drop an alpha channel (the channel count lives on axis 2, not axis 1).
    if image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)

    # ---- rotation ----
    if rotate == "auto":
        angle = detect_angle(image, by="fft")
        img_preprocessed = rotate_image(image, angle)
    else:
        img_preprocessed = image

    # ---- grayscale ----
    if grayscale:
        img_preprocessed = _single_channel(img_preprocessed)

    # ---- thresholding ----
    if threshold:
        if threshold_method == "adaptive":
            img_preprocessed = cv2.adaptiveThreshold(
                _single_channel(img_preprocessed),
                255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY,
                11,
                2,
            )
        elif threshold_method == "global":
            _, img_preprocessed = cv2.threshold(
                img_preprocessed, 127, 255, cv2.THRESH_BINARY
            )

    # ---- denoise ----
    if denoise:
        img_preprocessed = cv2.GaussianBlur(img_preprocessed, blur_ksize, 0)

    # ---- morphology ----
    if morph:
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, morph_kernel_size)
        if morph_op == "close":
            # Closing = dilation then erosion: fills small holes and gaps in
            # foreground objects while keeping their overall shape.
            img_preprocessed = cv2.morphologyEx(
                img_preprocessed, cv2.MORPH_CLOSE, kernel
            )
        elif morph_op == "open":
            # Opening = erosion then dilation: removes small specks of noise
            # while keeping larger objects intact.
            img_preprocessed = cv2.morphologyEx(
                img_preprocessed, cv2.MORPH_OPEN, kernel
            )
        elif morph_op == "dilate":
            # Dilation grows objects: each pixel becomes the maximum under
            # the kernel, closing small gaps and joining broken parts.
            img_preprocessed = cv2.dilate(img_preprocessed, kernel)
        elif morph_op == "erode":
            # Erosion shrinks objects: each pixel becomes the minimum under
            # the kernel, removing small bright specks and separating
            # touching objects.
            img_preprocessed = cv2.erode(img_preprocessed, kernel)

    # ---- contrast enhancement ----
    if enhance_contrast:
        clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=clahe_grid_size)
        img_preprocessed = clahe.apply(_single_channel(img_preprocessed))

    # ---- edge detection ----
    if edge_detection:
        img_preprocessed = cv2.Canny(img_preprocessed, 100, 200)

    return img_preprocessed
321
+
322
+
323
def text_postprocess(
    text,
    spell_check=True,
    clean=True,
    filter=dict(min_length=2),
    pattern=None,
    merge=True,
):
    """Clean up OCR text fragments.

    Args:
        text: A single string or a list of text fragments.
        spell_check: Replace each word with the spell checker's best
            correction (the word is kept when no correction is available).
        clean: Strip every character that is neither a word character nor
            whitespace.
        filter: Dict with a ``min_length`` entry; fragments shorter than
            that are dropped. Any other truthy value uses a minimum length
            of 2; a falsy value disables filtering.
        pattern: Optional regular expression; only fragments containing a
            match are kept.
        merge: Join the remaining fragments with single spaces.

    Returns:
        A string when *merge* is true, otherwise a list of strings.
    """

    def correct_spelling(text_list):
        spell = SpellChecker()
        corrected = []
        for fragment in text_list:
            # correction() may return None when it has no suggestion.
            words = [spell.correction(word) or word for word in fragment.split()]
            corrected.append(" ".join(words))
        return corrected

    def clean_text(text_list):
        return [re.sub(r"[^\w\s]", "", fragment) for fragment in text_list]

    def filter_text(text_list, min_length=2):
        return [fragment for fragment in text_list if len(fragment) >= min_length]

    def extract_patterns(text_list, regex):
        compiled = re.compile(regex)
        return [fragment for fragment in text_list if compiled.search(fragment)]

    def merge_fragments(text_list):
        return " ".join(text_list)

    # A bare string is one fragment, not a sequence of characters.
    results = [text] if isinstance(text, str) else list(text)

    if spell_check:
        results = correct_spelling(results)
    if clean:
        results = clean_text(results)
    if filter:
        min_length = filter.get("min_length", 2) if isinstance(filter, dict) else 2
        results = filter_text(results, min_length=min_length)
    if pattern:
        results = extract_patterns(results, pattern)
    if merge:
        results = merge_fragments(results)
    return results
368
+
369
+
370
+ # https://www.jaided.ai/easyocr/documentation/
371
+ # extract text from an image with EasyOCR
372
def _select_output(detections, output, thr, texts):
    """Return the part of the OCR result requested by *output*."""
    kind = output.lower()
    if output == "all":
        return detections
    if "t" in kind and "x" in kind:  # 'text', 'txt', ...
        return texts
    if "score" in kind or "prob" in kind:
        return [score_ for _, _, score_ in detections]
    if "box" in kind:
        # Bounding boxes, low-confidence detections removed.
        return [bbox_ for bbox_, _, score_ in detections if score_ >= thr]
    # Unknown selector: fall back to everything.
    return detections


def get_text(
    image,
    lang=["ch_sim", "en"],
    thr=0.25,
    gpu=True,
    decoder="wordbeamsearch",  #'greedy', 'beamsearch' and 'wordbeamsearch'(hightly accurate)
    output="all",
    preprocess=None,
    postprocess="not ready",
    show=True,
    ax=None,
    cmap=cv2.COLOR_BGR2RGB,  # draw_box
    font=cv2.FONT_HERSHEY_SIMPLEX,
    fontScale=0.8,
    thickness_text=2,  # Line thickness of 2 px
    color_box=(0, 255, 0),  # draw_box
    color_text=(0, 0, 255),  # draw_box
    **kwargs,
):
    """Recognise text in an image with EasyOCR.

    The image is optionally pre-processed, read with EasyOCR and, when
    *show* is true, displayed with bounding boxes and recognised text.

    Args:
        image: Image path or image data.
        lang: Language names or EasyOCR codes (resolved by
            ``lang_auto_detect``).
        thr: Confidence threshold; detections scoring below it are left out
            of the text and box outputs and are not drawn.
        gpu: Whether EasyOCR should use the GPU.
        decoder: EasyOCR decoder: 'greedy', 'beamsearch' or 'wordbeamsearch'.
        output: 'all' (every detection), 'text', 'score'/'prob' or 'box'.
        preprocess: Dict of keyword arguments for ``preprocess_img``;
            ``None`` uses its defaults.
        postprocess: Dict of keyword arguments for ``text_postprocess``, or
            ``None`` for a default configuration. Any non-dict value (such
            as the default ``"not ready"``) disables post-processing.
        show: Display the annotated image on *ax*.
        ax: Matplotlib axes to draw on; defaults to the current axes.
        cmap: OpenCV colour-conversion code applied before display.
        font, thickness_text: Unused; kept for backward compatibility.
        fontScale: Text size factor (font size is ``fontScale * 32``).
        color_box: Colour of the bounding boxes.
        color_text: Colour of the text. NOTE(review): this is forwarded to
            the Pillow-based text drawing as-is -- confirm the intended
            channel order.
        kwargs: Extra arguments forwarded to EasyOCR's ``readtext``.

    Returns:
        The selected output; when *show* is true a tuple ``(ax, output)``.

    Raises:
        ValueError: If *image* is a path that OpenCV cannot read.

    Usage:
        image_path = 'car_plate.jpg'  # replace with your image path
        results = get_text(
            image_path,
            lang=["en"],
            gpu=False,
            output="text",
            preprocess={
                "grayscale": True,
                "threshold": True,
                "threshold_method": 'adaptive',
                "denoise": True,
                "blur_ksize": (5, 5),
                "morph": True,
                "morph_op": 'close',
                "morph_kernel_size": (3, 3),
                "enhance_contrast": True,
                "clahe_clip": 2.0,
                "clahe_grid_size": (8, 8),
                "edge_detection": False
            },
            adjust_contrast=0.7
        )
    """
    lang = lang_auto_detect(lang)
    print(f"detecting language(s):{lang}")
    if isinstance(image, str):
        path = image
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"Could not read image from path: {path!r}")

    # ! preprocessing img
    if preprocess is None:
        preprocess = {}
    image_process = preprocess_img(image, **preprocess)

    # Perform OCR on the image
    reader = easyocr.Reader(lang, gpu=gpu)
    detections = reader.readtext(image_process, decoder=decoder, **kwargs)

    # Recognised text with low-confidence detections removed.
    texts = [text_ for _, text_, score_ in detections if score_ >= thr]

    if postprocess is None:
        postprocess = dict(
            spell_check=True,
            clean=True,
            filter=dict(min_length=2),
            pattern=None,
            merge=True,
        )
    # Only a mapping can be unpacked into text_postprocess; anything else
    # (e.g. the "not ready" default) means "no post-processing".
    if isinstance(postprocess, dict):
        text_corr = []
        for raw in texts:
            fixed = text_postprocess(raw, **postprocess)
            if fixed is None:
                text_corr.append(raw)
            elif isinstance(fixed, str):
                if fixed:
                    text_corr.append(fixed)
            else:
                text_corr.extend(fixed)
        texts = text_corr

    result = _select_output(detections, output, thr, texts)
    if not show:
        return result

    if ax is None:
        ax = plt.gca()
    # Draw on a copy so the caller's image is not modified.
    canvas = image.copy()
    font_px = int(round(fontScale * 32))
    for bbox, text, score in detections:
        if score >= thr:
            top_left = tuple(map(int, bbox[0]))
            bottom_right = tuple(map(int, bbox[2]))
            canvas = cv2.rectangle(canvas, top_left, bottom_right, color_box, 2)
            canvas = add_text_pil(
                canvas, text, top_left, font_size=font_px, color=color_text
            )

    ax.imshow(cv2.cvtColor(canvas, cmap))
    ax.axis("off")
    return ax, result
518
+
519
+
520
def draw_box(
    image,
    detections=None,
    thr=0.25,
    cmap=cv2.COLOR_BGR2RGB,
    color_box=(0, 255, 0),  # draw_box
    color_text=(0, 0, 255),  # draw_box
    font_scale=0.8,
    show=True,
    ax=None,
    **kwargs,
):
    """Draw OCR bounding boxes and recognised text on an image.

    Args:
        image: Image path or BGR image array.
        detections: Iterable of ``(bbox, text, score)`` tuples. When
            ``None``, they are obtained by calling ``get_text`` on *image*.
        thr: Detections with a score not above this value are skipped.
        cmap: OpenCV colour-conversion code applied to the annotated image.
        color_box: Colour of the bounding boxes.
        color_text: Colour of the text labels.
        font_scale: Text size factor (font size is ``font_scale * 32``).
        show: Display the result on *ax*.
        ax: Matplotlib axes; defaults to the current axes. Only looked up
            when *show* is true, so no figure is created otherwise.
        kwargs: Forwarded to ``get_text`` when detections are computed.

    Returns:
        The annotated image after the *cmap* colour conversion.

    Raises:
        ValueError: If *image* is a path that OpenCV cannot read.
    """
    if isinstance(image, str):
        path = image
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"Could not read image from path: {path!r}")
    if detections is None:
        detections = get_text(image=image, show=0, output="all", **kwargs)

    # Annotate a copy so the caller's array stays untouched.
    canvas = image.copy()
    # Whole-pixel font size; not every Pillow release accepts a float.
    font_px = int(round(font_scale * 32))
    for bbox, text, score in detections:
        if score > thr:
            top_left = tuple(map(int, bbox[0]))
            bottom_right = tuple(map(int, bbox[2]))
            canvas = cv2.rectangle(canvas, top_left, bottom_right, color_box, 2)
            canvas = add_text_pil(
                canvas, text, top_left, font_size=font_px, color=color_text
            )

    img_cmp = cv2.cvtColor(canvas, cmap)
    if show:
        if ax is None:
            ax = plt.gca()
        ax.imshow(img_cmp)
        ax.axis("off")
    return img_cmp
py2ls/plot.py CHANGED
@@ -1637,6 +1637,30 @@ def figsets(*args, **kwargs):
1637
1637
  ax.tick_params(
1638
1638
  labelsize=val
1639
1639
  ) # float, distance in points between tick and label
1640
+ if "text" in key.lower():
1641
+ if isinstance(value, dict):
1642
+ ax.text(**value)
1643
+ elif isinstance(value, list):
1644
+ if all([isinstance(i, dict) for i in value]):
1645
+ [ax.text(**value_) for value_ in value]
1646
+ # e.g.,
1647
+ # figsets(ax=ax,
1648
+ # text=[
1649
+ # dict(
1650
+ # x=1,
1651
+ # y=1.3,
1652
+ # s="Wake",
1653
+ # c="k",
1654
+ # bbox=dict(facecolor="0.8", edgecolor="none", boxstyle="round,pad=0.1"),
1655
+ # ),
1656
+ # dict(
1657
+ # x=1,
1658
+ # y=0.4,
1659
+ # s="Sleep",
1660
+ # c="k",
1661
+ # bbox=dict(facecolor="0.8", edgecolor="none", boxstyle="round,pad=0.05"),
1662
+ # ),
1663
+ # ])
1640
1664
 
1641
1665
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
1642
1666
  if "x" in value.lower() or "x" in key.lower():
@@ -1823,15 +1847,43 @@ def get_color(
1823
1847
  cmap = "grey"
1824
1848
  # Determine color list based on cmap parameter
1825
1849
  if "aut" in cmap:
1826
- colorlist = [
1827
- "#474747",
1828
- "#FF2C00",
1829
- "#0C5DA5",
1830
- "#845B97",
1831
- "#58BBCC",
1832
- "#FF9500",
1833
- "#D57DBE",
1834
- ]
1850
+ if n == 1:
1851
+ colorlist = ["#3A4453"]
1852
+ elif n == 2:
1853
+ colorlist = ["#3A4453", "#DF5932"]
1854
+ elif n == 3:
1855
+ colorlist = ["#3A4453", "#DF5932", "#299D8F"]
1856
+ elif n == 4:
1857
+ # colorlist = ["#3A4453", "#DF5932", "#EBAA00", "#0B4083"]
1858
+ colorlist = ["#81C6BD", "#FBAF63", "#F2675B", "#72A1C9"]
1859
+ elif n == 5:
1860
+ colorlist = [
1861
+ "#3A4453",
1862
+ "#427AB2",
1863
+ "#F09148",
1864
+ "#DBDB8D",
1865
+ "#C59D94",
1866
+ "#AFC7E8",
1867
+ ]
1868
+ elif n == 6:
1869
+ colorlist = [
1870
+ "#3A4453",
1871
+ "#427AB2",
1872
+ "#F09148",
1873
+ "#DBDB8D",
1874
+ "#C59D94",
1875
+ "#E53528",
1876
+ ]
1877
+ else:
1878
+ colorlist = [
1879
+ "#474747",
1880
+ "#FF2C00",
1881
+ "#0C5DA5",
1882
+ "#845B97",
1883
+ "#58BBCC",
1884
+ "#FF9500",
1885
+ "#D57DBE",
1886
+ ]
1835
1887
  by = "start"
1836
1888
  elif any(["cub" in cmap.lower(), "sns" in cmap.lower()]):
1837
1889
  if kwargs:
@@ -1909,6 +1961,10 @@ import matplotlib.pyplot as plt
1909
1961
 
1910
1962
 
1911
1963
  def stdshade(ax=None, *args, **kwargs):
1964
+ """
1965
+ usage:
1966
+ plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
1967
+ """
1912
1968
  # Separate kws_line and kws_fill if necessary
1913
1969
  kws_line = kwargs.pop("kws_line", {})
1914
1970
  kws_fill = kwargs.pop("kws_fill", {})
@@ -1946,8 +2002,9 @@ def stdshade(ax=None, *args, **kwargs):
1946
2002
  ax = plt.gca()
1947
2003
  if ax is None:
1948
2004
  ax = plt.gca()
1949
- alpha = 0.5
1950
- acolor = "k"
2005
+ alpha = kwargs.get("alpha", 0.2)
2006
+ acolor = kwargs.get("color", "k")
2007
+ acolor = kwargs.get("c", "k")
1951
2008
  paraStdSem = "sem"
1952
2009
  plotStyle = "-"
1953
2010
  plotMarker = "none"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: py2ls
3
- Version: 0.1.9.9
3
+ Version: 0.1.10.1
4
4
  Summary: py(thon)2(too)ls
5
5
  Author: Jianfeng
6
6
  Author-email: Jianfeng.Liu0413@gmail.com