magic-pdf 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. magic_pdf/dict2md/ocr_mkcontent.py +1 -1
  2. magic_pdf/libs/Constants.py +3 -1
  3. magic_pdf/libs/config_reader.py +1 -1
  4. magic_pdf/libs/draw_bbox.py +10 -4
  5. magic_pdf/libs/version.py +1 -1
  6. magic_pdf/model/pdf_extract_kit.py +42 -310
  7. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  8. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  9. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  10. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  11. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  12. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  13. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  14. magic_pdf/model/sub_modules/model_init.py +144 -0
  15. magic_pdf/model/sub_modules/model_utils.py +51 -0
  16. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  17. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
  19. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
  20. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  21. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  24. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  26. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
  27. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  28. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  29. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  30. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  31. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
  32. magic_pdf/para/para_split_v3.py +13 -15
  33. magic_pdf/pdf_parse_union_core_v2.py +56 -19
  34. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  35. magic_pdf/tools/common.py +47 -3
  36. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +35 -25
  37. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
  38. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
  39. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  40. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  41. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  42. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  43. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  44. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  45. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  46. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  47. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  48. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  49. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  50. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  51. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  52. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  53. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  54. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  55. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  56. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  57. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  58. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  59. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  60. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  61. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  62. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  63. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  64. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  65. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
  66. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,144 @@
1
+ from loguru import logger
2
+
3
+ from magic_pdf.libs.Constants import MODEL_NAME
4
+ from magic_pdf.model.model_list import AtomicModel
5
+ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
6
+ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
7
+ from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
8
+
9
+ from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
10
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
11
+ # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
12
+ from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
13
+ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
14
+ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
15
+
16
+
17
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Weight location; forwarded to StructEqTable and TableMaster
            (RapidTable is constructed without arguments).
        max_time: Generation time limit; used only by StructEqTable.
        _device_: Device string; used only by TableMaster.

    Returns:
        The constructed table model instance.

    Logs an error and terminates the process with exit status 1 when
    ``table_model_type`` is not one of the supported names.
    """
    import sys

    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            "model_dir": model_path,
            "device": _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error("table model type not allow")
        # sys.exit rather than the site-injected ``exit`` builtin, which is
        # intended for the interactive shell and is missing under ``python -S``.
        sys.exit(1)

    return table_model
33
+
34
+
35
def mfd_model_init(weight, device='cpu'):
    """Return a ``YOLOv8MFDModel`` (formula detection) built from ``weight`` on ``device``."""
    return YOLOv8MFDModel(weight, device)
38
+
39
+
40
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Return a ``UnimernetModel`` (MFR) built from ``weight_dir`` and ``cfg_path`` on ``device``."""
    return UnimernetModel(weight_dir, cfg_path, device)
43
+
44
+
45
def layout_model_init(weight, config_file, device):
    """Return a LayoutLMv3 layout predictor built from ``weight`` and ``config_file`` on ``device``."""
    return Layoutlmv3_Predictor(weight, config_file, device)
48
+
49
+
50
def doclayout_yolo_model_init(weight, device='cpu'):
    """Return a DocLayout-YOLO layout model built from ``weight`` on ``device``."""
    return DocLayoutYOLOModel(weight, device)
53
+
54
+
55
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``ModifiedPaddleOCR`` instance.

    ``lang`` is forwarded only when it is not None; otherwise the constructor's
    own default for ``lang`` applies. ``use_angle_cls`` is not passed, so it
    also stays at the constructor default.
    """
    ocr_kwargs = dict(
        show_log=show_log,
        det_db_box_thresh=det_db_box_thresh,
        use_dilation=use_dilation,
        det_db_unclip_ratio=det_db_unclip_ratio,
    )
    if lang is not None:
        ocr_kwargs["lang"] = lang
    return ModifiedPaddleOCR(**ocr_kwargs)
78
+
79
+
80
class AtomModelSingleton:
    """Single shared instance that caches constructed models.

    ``__new__`` always returns the same instance, and the cache lives in the
    class attribute ``_models``, so every caller shares it. No locking is
    performed.
    """

    _instance = None
    # Cache of built models keyed by (atom_model_name, layout_model_name, lang, table_model_name).
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for these selectors, building it on first use.

        The key is ``(atom_model_name, layout_model_name, lang, table_model_name)``.
        ``table_model_name`` is part of the key so that requesting a different
        table backend builds that backend instead of returning a previously cached
        one. Other kwargs (weights paths, device, ...) are not part of the key:
        the first call for a given key decides them.
        """
        lang = kwargs.get("lang", None)
        layout_model_name = kwargs.get("layout_model_name", None)
        table_model_name = kwargs.get("table_model_name", None)
        key = (atom_model_name, layout_model_name, lang, table_model_name)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
96
+
97
+
98
def atom_model_init(model_name: str, **kwargs):
    """Construct the model identified by ``model_name`` from keyword arguments.

    Dispatches on ``model_name`` (an ``AtomicModel`` member) and forwards the
    relevant kwargs to the matching ``*_model_init`` helper:

    * ``AtomicModel.Layout``: ``layout_model_name`` selects LayoutLMv3
      (``layout_weights``, ``layout_config_file``, ``device``) or DocLayout-YOLO
      (``doclayout_yolo_weights``, ``device``).
    * ``AtomicModel.MFD``: ``mfd_weights``, ``device``.
    * ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
    * ``AtomicModel.OCR``: ``ocr_show_log`` (default False),
      ``det_db_box_thresh`` (default 0.3), ``lang``.
    * ``AtomicModel.Table``: ``table_model_name``, ``table_model_path``,
      ``table_max_time``, ``device``.

    Returns:
        The constructed model.

    Logs an error and terminates the process with exit status 1 when
    ``model_name`` is unknown or no model was built (e.g. an unrecognised
    ``layout_model_name``).
    """
    import sys

    atom_model = None
    if model_name == AtomicModel.Layout:
        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get("layout_weights"),
                kwargs.get("layout_config_file"),
                kwargs.get("device")
            )
        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get("doclayout_yolo_weights"),
                kwargs.get("device")
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights"),
            kwargs.get("device")
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get("mfr_weight_dir"),
            kwargs.get("mfr_cfg_path"),
            kwargs.get("device")
        )
    elif model_name == AtomicModel.OCR:
        # Fall back to ocr_model_init's own defaults when the caller omits these,
        # instead of forwarding None into the OCR constructor.
        atom_model = ocr_model_init(
            kwargs.get("ocr_show_log", False),
            kwargs.get("det_db_box_thresh", 0.3),
            kwargs.get("lang")
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_name"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            kwargs.get("device")
        )
    else:
        logger.error("model name not allow")
        sys.exit(1)

    if atom_model is None:
        logger.error("model init failed")
        sys.exit(1)
    return atom_model
@@ -0,0 +1,51 @@
1
+ import time
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from loguru import logger
6
+
7
+ from magic_pdf.libs.clean_memory import clean_memory
8
+
9
+
10
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Cut a layout region out of a page image and pad it with a white margin.

    Args:
        input_res: Layout result whose ``'poly'`` entry gives the region;
            ``poly[0], poly[1]`` are read as the top-left corner and
            ``poly[4], poly[5]`` as the bottom-right corner (truncated with ``int``).
        input_pil_img: PIL image of the page.
        crop_paste_x: White margin added on both the left and right, in pixels.
        crop_paste_y: White margin added on both the top and bottom, in pixels.

    Returns:
        ``(image, useful_list)``: a new RGB image, and
        ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``,
        the values needed to map crop coordinates back to the page.
    """
    poly = input_res['poly']
    crop_xmin, crop_ymin = int(poly[0]), int(poly[1])
    crop_xmax, crop_ymax = int(poly[4]), int(poly[5])

    # Canvas = region size plus the margin on both sides of each axis.
    crop_new_width = (crop_xmax - crop_xmin) + 2 * crop_paste_x
    crop_new_height = (crop_ymax - crop_ymin) + 2 * crop_paste_y
    canvas = Image.new('RGB', (crop_new_width, crop_new_height), 'white')

    region = input_pil_img.crop((crop_xmin, crop_ymin, crop_xmax, crop_ymax))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    useful_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin,
                   crop_xmax, crop_ymax, crop_new_width, crop_new_height]
    return canvas, useful_list
24
+
25
+
26
def get_res_list_from_layout_res(layout_res):
    """Split layout results into OCR regions, table regions and formula boxes.

    Category ids 13 and 14 become formula boxes ``{"bbox": [x0, y0, x1, y1]}``
    built from ``poly[0], poly[1], poly[4], poly[5]`` (truncated with ``int``).
    Ids 0, 1, 2, 4, 6 and 7 are returned for OCR and id 5 for table recognition.
    Any other id is ignored.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    ocr_res_list, table_res_list, single_page_mfdetrec_res = [], [], []
    for res in layout_res:
        category = int(res['category_id'])
        if category in (13, 14):
            poly = res['poly']
            single_page_mfdetrec_res.append({
                "bbox": [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])],
            })
        elif category in (0, 1, 2, 4, 6, 7):
            ocr_res_list.append(res)
        elif category == 5:
            table_res_list.append(res)
    return ocr_res_list, table_res_list, single_page_mfdetrec_res
42
+
43
+
44
def clean_vram(device, vram_threshold=8):
    """Run ``clean_memory`` on small GPUs and log how long it took.

    Does nothing unless CUDA is available and ``device`` is not ``'cpu'``, and
    the device's total memory is at most ``vram_threshold`` GB.
    """
    if not (torch.cuda.is_available() and device != 'cpu'):
        return
    # Convert bytes to GB.
    total_memory_gb = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    if total_memory_gb > vram_threshold:
        return
    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f"gc time: {gc_time}")
File without changes
File without changes
@@ -0,0 +1,259 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ from loguru import logger
5
+
6
+ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
7
+ from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
8
+
9
+
10
def bbox_to_points(bbox):
    """Convert ``[x0, y0, x1, y1]`` into a float32 array of the four corners.

    Corner order is top-left, top-right, bottom-right, bottom-left.
    """
    x0, y0, x1, y1 = bbox
    corners = [
        [x0, y0],
        [x1, y0],
        [x1, y1],
        [x0, y1],
    ]
    return np.array(corners).astype(np.float32)
14
+
15
+
16
def points_to_bbox(points):
    """Convert four corner points into ``[x0, y0, x1, y1]``.

    Takes x0/y0 from the first point, x1 from the second point and y1 from the
    third point; the fourth point is not read.
    """
    left, top = points[0]
    right, _ = points[1]
    _, bottom = points[2]
    return [left, top, right, bottom]
22
+
23
+
24
def merge_intervals(intervals):
    """Merge overlapping closed ``[start, end]`` intervals.

    Two intervals are merged when the next one (in start order) begins at or
    before the end of the current merged run.

    Args:
        intervals: Iterable of ``[start, end]`` pairs (lists or tuples). Neither
            the container nor its elements are modified.

    Returns:
        A new list of ``[start, end]`` lists, sorted by start, none overlapping.
    """
    merged = []
    # sorted() instead of list.sort() so the caller's list is left untouched.
    for interval in sorted(intervals, key=lambda x: x[0]):
        start, end = interval[0], interval[1]
        if not merged or merged[-1][1] < start:
            # Fresh list so later merges never write into the caller's interval.
            merged.append([start, end])
        else:
            merged[-1][1] = max(merged[-1][1], end)
    return merged
39
+
40
+
41
def remove_intervals(original, masks):
    """Return the parts of ``original = [start, end]`` not covered by ``masks``.

    ``masks`` are first combined with ``merge_intervals``. Ranges are treated as
    closed integer ranges: a kept piece stops at ``mask_start - 1`` and the scan
    resumes at ``mask_end + 1``.

    Returns:
        List of ``[start, end]`` pieces in increasing order (possibly empty).
    """
    merged_masks = merge_intervals(masks)
    cursor, original_end = original

    kept = []
    for mask_start, mask_end in merged_masks:
        # Masks entirely right of the range, or entirely left of the unscanned
        # remainder, cut nothing.
        if mask_start > original_end or mask_end < cursor:
            continue
        if cursor < mask_start:
            kept.append([cursor, mask_start - 1])
        cursor = max(mask_end + 1, cursor)

    if cursor <= original_end:
        kept.append([cursor, original_end])
    return kept
70
+
71
+
72
def update_det_boxes(dt_boxes, mfd_res):
    """Split text detection boxes horizontally around formula boxes.

    For each text box, every formula box in ``mfd_res`` (dicts with a ``'bbox'``)
    that passes ``__is_overlaps_y_exceeds_threshold`` contributes its x-range
    as a mask. The text box's x-range minus those masks (``remove_intervals``)
    is turned back into corner-point boxes with the original top/bottom. A
    text box that is fully covered yields no output.

    Returns:
        New list of 4x2 point arrays.
    """
    new_dt_boxes = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        masks_list = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]
        kept_ranges = remove_intervals([text_bbox[0], text_bbox[2]], masks_list)
        new_dt_boxes.extend(
            bbox_to_points([start, text_bbox[1], end, text_bbox[3]])
            for start, end in kept_ranges
        )
    return new_dt_boxes
89
+
90
+
91
def merge_overlapping_spans(spans):
    """
    Merges overlapping spans on the same line.

    Spans are visited in order of their left edge. A span whose left edge is not
    beyond the right edge of the previous merged span is merged into it, and
    the result is the tuple of the enclosing box. Spans that were never merged
    are returned as the original objects.

    :param spans: A sequence of span coordinates [(x1, y1, x2, y2), ...];
        it is not modified.
    :return: A new list of merged spans
    """
    if not spans:
        return []

    merged = []
    # sorted() rather than spans.sort(): do not reorder the caller's list
    # (and accept tuples, which have no .sort()).
    for span in sorted(spans, key=lambda x: x[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            # No horizontal overlap with the previous span.
            merged.append(span)
        else:
            # Horizontal overlap: replace the previous span with the enclosing box.
            last_span = merged.pop()
            merged.append((
                min(last_span[0], x1),
                min(last_span[1], y1),
                max(last_span[2], x2),
                max(last_span[3], y2),
            ))
    return merged
126
+
127
+
128
def merge_det_boxes(dt_boxes):
    """
    Merge detection boxes.

    Each corner-point box is reduced to ``[x0, y0, x1, y1]`` via ``points_to_bbox``.
    Boxes whose reduced extent is degenerate (``x1 <= x0`` or ``y1 <= y0``) are
    set aside unchanged. The rest are grouped into lines with
    ``merge_spans_to_line``, horizontally overlapping spans within a line are
    merged with ``merge_overlapping_spans``, and the results are converted back to
    corner points.

    Parameters:
        dt_boxes (list): Text detection boxes, each given by four corner points.

    Returns:
        list: Merged boxes as corner-point arrays, followed by the set-aside
        degenerate boxes in their original form.
    """
    dt_boxes_dict_list = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        x0, y0, x1, y1 = text_bbox
        if x1 <= x0 or y1 <= y0:
            angle_boxes_list.append(text_box)
        else:
            dt_boxes_dict_list.append({'bbox': text_bbox, 'type': 'text'})

    new_dt_boxes = []
    for line in merge_spans_to_line(dt_boxes_dict_list):
        merged_spans = merge_overlapping_spans([span['bbox'] for span in line])
        new_dt_boxes.extend(bbox_to_points(span) for span in merged_spans)

    new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes
175
+
176
+
177
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Translate page-level formula boxes into the coordinates of a padded crop.

    ``useful_list`` is ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
    new_height]``. Each box is shifted by ``-xmin + paste_x`` / ``-ymin + paste_y``.
    Boxes that end before 0 on either axis or start beyond ``new_width`` /
    ``new_height`` are dropped.

    Returns:
        List of ``{"bbox": [x0, y0, x1, y1]}`` dicts in crop coordinates.
    """
    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
    adjusted = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y
        outside_crop = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
        if not outside_crop:
            adjusted.append({"bbox": [x0, y0, x1, y1]})
    return adjusted
196
+
197
+
198
def get_ocr_result_list(ocr_res, useful_list):
    """Convert OCR output for a cropped region into page-level span dicts.

    Args:
        ocr_res: Iterable of ``[points, (text, score)]`` items, where ``points``
            holds four ``[x, y]`` corners in crop coordinates.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, width, height]``
            describing the crop.

    Returns:
        A list of dicts with ``category_id`` 15, a flat 8-number ``poly`` in page
        coordinates, ``score`` rounded to 2 decimals, and ``text``.

    A box whose ``calculate_angle_degrees`` exceeds 0.5 is first replaced by an
    axis-aligned box with the same centre. Its height is the mean of the two
    vertical edge lengths and its width is ``p3.x - p1.x``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _crop_width, _crop_height = useful_list
    ocr_result_list = []
    for box_ocr_res in ocr_res:
        corners = box_ocr_res[0]
        p1, p2, p3, p4 = corners
        text, score = box_ocr_res[1]
        if calculate_angle_degrees(corners) > 0.5:
            # Tilted more than 0.5 degrees from the x axis: straighten the box.
            x_center = sum(point[0] for point in corners) / 4
            y_center = sum(point[1] for point in corners) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            left = x_center - box_width / 2
            right = x_center + box_width / 2
            top = y_center - box_height / 2
            bottom = y_center + box_height / 2
            p1, p2, p3, p4 = [left, top], [right, top], [right, bottom], [left, bottom]

        # Undo the white-margin paste and the crop offset.
        poly = []
        for corner in (p1, p2, p3, p4):
            poly.append(corner[0] - paste_x + xmin)
            poly.append(corner[1] - paste_y + ymin)

        ocr_result_list.append({
            'category_id': 15,
            'poly': poly,
            'score': float(round(score, 2)),
            'text': text,
        })

    return ocr_result_list
233
+
234
+
235
def calculate_angle_degrees(poly):
    """Return the absolute mean angle, in degrees, of a quad's diagonals to the x axis.

    The diagonals are ``poly[0] -> poly[2]`` and ``poly[1] -> poly[3]``. A
    vertical diagonal counts as 90 degrees. For an axis-aligned rectangle the
    two signed angles cancel, giving 0.
    """
    def _diagonal_angle(start, end):
        # Signed angle of the segment start->end; a vertical segment has infinite slope.
        if end[0] != start[0]:
            slope = (end[1] - start[1]) / (end[0] - start[0])
        else:
            slope = float('inf')
        return math.degrees(math.atan(slope))

    first_angle = _diagonal_angle(poly[0], poly[2])
    second_angle = _diagonal_angle(poly[1], poly[3])
    return abs((first_angle + second_angle) / 2)
259
+
@@ -0,0 +1,168 @@
1
+ import copy
2
+ import time
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from paddleocr import PaddleOCR
7
+ from paddleocr.paddleocr import check_img, logger
8
+ from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
9
+ from paddleocr.tools.infer.predict_system import sorted_boxes
10
+ from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
11
+
12
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
13
+
14
+
15
class ModifiedPaddleOCR(PaddleOCR):
    """PaddleOCR subclass whose OCR can split text boxes around formula regions.

    ``ocr`` and ``__call__`` accept an extra ``mfd_res`` argument: a list of
    ``{"bbox": [x0, y0, x1, y1]}`` formula boxes (presumably in the same pixel
    coordinates as the image passed in; verify against callers). Detected text
    boxes that overlap a formula box in y are cut horizontally around it
    (``update_det_boxes``) before cropping and recognition.
    """

    def ocr(self,
            img,
            det=True,
            rec=True,
            cls=True,
            bin=False,
            inv=False,
            alpha_color=(255, 255, 255),
            mfd_res=None,
            ):
        """
        OCR with PaddleOCR
        args:
            img: img for OCR, support ndarray, img_path and list or ndarray
            det: use text detection or not. If False, only rec will be exec. Default is True
            rec: use text recognition or not. If False, only det will be exec. Default is True
            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
            bin: binarize image to black and white. Default is False.
            inv: invert image colors. Default is False.
            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
            mfd_res: optional list of formula boxes ``{"bbox": [x0, y0, x1, y1]}``.
                Only used when det and rec are both True.
        returns:
            One entry per processed image:
            * det and rec: ``[[box_points, (text, score)], ...]``, or None when nothing was found;
            * det only: ``[box_points, ...]``, or None when nothing was detected;
            * otherwise: the recognizer output per image, or the angle-classifier
              output per image when rec is False.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            # NOTE(review): exits with status 0 even though this is an error
            # path; confirm whether a non-zero status is wanted.
            exit(0)
        # If cls is requested but the angle classifier was not initialised, the
        # branches below skip classification because they also check
        # self.use_angle_cls.

        img = check_img(img)
        # for infer pdf file
        if isinstance(img, list):
            if self.page_num > len(img) or self.page_num == 0:
                self.page_num = len(img)
            imgs = img[:self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res]
                           for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
                # The detector's boxes are converted with .tolist(), i.e. it
                # returns array-like output; ``not dt_boxes`` on a numpy array
                # holding any box raises "truth value of an array ... is
                # ambiguous", so test for None / emptiness explicitly.
                if dt_boxes is None or len(dt_boxes) == 0:
                    ocr_res.append(None)
                    continue
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for idx, img in enumerate(imgs):
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res

    def __call__(self, img, cls=True, mfd_res=None):
        """Detect, optionally angle-classify, and recognise text in one image.

        Args:
            img: image array (``.copy()`` is called on it); None short-circuits.
            cls: run the angle classifier when ``self.use_angle_cls`` is also set.
            mfd_res: optional formula boxes ``{"bbox": [...]}``; when non-empty,
                text boxes are split around them with ``update_det_boxes``.

        Returns:
            ``(boxes, rec_results, time_dict)``. Only results whose score is
            ``>= self.drop_score`` are kept. ``time_dict`` holds the elapsed values
            for 'det', 'rec', 'cls' and 'all'. Returns ``(None, None, time_dict)``
            when ``img`` is None or the detector returns None.
        """
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict
        else:
            logger.debug("dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), elapse))
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # TODO: merging currently happens at the bbox level, which handles
        # tilted text lines poorly; it should become a poly-aware merge before
        # being re-enabled.
        # dt_boxes = merge_det_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapsed : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(
            len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict