magic-pdf 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +1 -1
- magic_pdf/libs/Constants.py +3 -1
- magic_pdf/libs/config_reader.py +1 -1
- magic_pdf/libs/draw_bbox.py +10 -4
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/pdf_extract_kit.py +42 -297
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
- magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
- magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
- magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
- magic_pdf/model/sub_modules/model_init.py +144 -0
- magic_pdf/model/sub_modules/model_utils.py +51 -0
- magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
- magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
- magic_pdf/model/sub_modules/table/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
- magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
- magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
- magic_pdf/model/sub_modules/table/table_utils.py +11 -0
- magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
- magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
- magic_pdf/para/para_split_v3.py +13 -15
- magic_pdf/pdf_parse_union_core_v2.py +56 -19
- magic_pdf/resources/model_config/model_configs.yaml +2 -1
- magic_pdf/tools/common.py +47 -3
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +9 -3
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
- magic_pdf/model/pek_sub_modules/post_process.py +0 -36
- magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
- /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
|
|
1
|
+
from loguru import logger
|
2
|
+
|
3
|
+
from magic_pdf.libs.Constants import MODEL_NAME
|
4
|
+
from magic_pdf.model.model_list import AtomicModel
|
5
|
+
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
|
6
|
+
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
|
7
|
+
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
8
|
+
|
9
|
+
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
10
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
|
11
|
+
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
12
|
+
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
|
13
|
+
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
|
14
|
+
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
|
15
|
+
|
16
|
+
|
17
|
+
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Passed to ``StructTableModel`` as its first argument, or
            used as ``"model_dir"`` in the TableMaster config. Not used for
            ``RapidTableModel``.
        max_time: Forwarded to ``StructTableModel`` as ``max_time``; the other
            back-ends do not receive it.
        _device_: Stored under ``"device"`` in the TableMaster config; the
            other back-ends do not receive it.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With status 1 when ``table_model_type`` is not one of the
            supported values.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            "model_dir": model_path,
            "device": _device_,
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error("table model type not allow")
        # The bare ``exit()`` builtin is injected by ``site`` for interactive
        # use and may be absent; raising SystemExit gives the same exit status.
        raise SystemExit(1)

    return table_model
|
33
|
+
|
34
|
+
|
35
|
+
def mfd_model_init(weight, device='cpu'):
    """Instantiate ``YOLOv8MFDModel`` from ``weight`` for the given ``device``."""
    return YOLOv8MFDModel(weight, device)
|
38
|
+
|
39
|
+
|
40
|
+
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Instantiate ``UnimernetModel`` from ``weight_dir`` and ``cfg_path`` for ``device``."""
    return UnimernetModel(weight_dir, cfg_path, device)
|
43
|
+
|
44
|
+
|
45
|
+
def layout_model_init(weight, config_file, device):
    """Instantiate ``Layoutlmv3_Predictor`` from ``weight`` and ``config_file`` for ``device``."""
    return Layoutlmv3_Predictor(weight, config_file, device)
|
48
|
+
|
49
|
+
|
50
|
+
def doclayout_yolo_model_init(weight, device='cpu'):
    """Instantiate ``DocLayoutYOLOModel`` from ``weight`` for the given ``device``."""
    return DocLayoutYOLOModel(weight, device)
|
53
|
+
|
54
|
+
|
55
|
+
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``ModifiedPaddleOCR`` instance.

    All arguments are forwarded as keyword arguments of the same name.
    ``lang`` is only forwarded when it is not ``None``, so that the OCR
    class falls back to its own default language otherwise.

    Returns:
        The constructed ``ModifiedPaddleOCR`` object.
    """
    ocr_options = {
        "show_log": show_log,
        "det_db_box_thresh": det_db_box_thresh,
    }
    if lang is not None:
        ocr_options["lang"] = lang
    ocr_options["use_dilation"] = use_dilation
    ocr_options["det_db_unclip_ratio"] = det_db_unclip_ratio
    # NOTE: ``use_angle_cls`` is deliberately not passed (it was commented out
    # in the original call as well).
    return ModifiedPaddleOCR(**ocr_options)
|
78
|
+
|
79
|
+
|
80
|
+
class AtomModelSingleton:
    """Singleton that lazily creates and caches atomic models.

    Every call to ``AtomModelSingleton()`` returns the same object, and the
    model cache is a class attribute, so models are shared by all users of
    the class within the process.
    """

    # The one shared instance, created on first construction.
    _instance = None
    # Cache of initialised models, keyed by the tuple built in get_atom_model.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this request, creating it on first use.

        Args:
            atom_model_name: Name of the atomic model; forwarded to
                ``atom_model_init`` as ``model_name``.
            **kwargs: Forwarded unchanged to ``atom_model_init``. The values
                of ``layout_model_name``, ``lang`` and ``table_model_name``
                (each ``None`` when absent) also form part of the cache key.

        Returns:
            The model object produced by ``atom_model_init`` for this key.
        """
        lang = kwargs.get("lang", None)
        layout_model_name = kwargs.get("layout_model_name", None)
        # Include the table back-end in the key; otherwise asking for a second
        # table model type would silently return the first one that was cached.
        table_model_name = kwargs.get("table_model_name", None)
        key = (atom_model_name, layout_model_name, lang, table_model_name)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
|
96
|
+
|
97
|
+
|
98
|
+
def atom_model_init(model_name: str, **kwargs):
    """Create the atomic model identified by ``model_name``.

    Args:
        model_name: One of ``AtomicModel.Layout``, ``AtomicModel.MFD``,
            ``AtomicModel.MFR``, ``AtomicModel.OCR`` or ``AtomicModel.Table``.
        **kwargs: Per-model settings read with ``kwargs.get``:
            Layout -> ``layout_model_name`` plus either ``layout_weights``,
            ``layout_config_file``, ``device`` (LayoutLMv3) or
            ``doclayout_yolo_weights``, ``device`` (DocLayout-YOLO);
            MFD -> ``mfd_weights``, ``device``;
            MFR -> ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``;
            OCR -> ``ocr_show_log``, ``det_db_box_thresh``, ``lang``;
            Table -> ``table_model_name``, ``table_model_path``,
            ``table_max_time``, ``device``.

    Returns:
        The initialised model object.

    Raises:
        SystemExit: With status 1 when ``model_name`` is unknown, or when no
            model was produced (e.g. an unrecognised ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get("layout_model_name")
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get("layout_weights"),
                kwargs.get("layout_config_file"),
                kwargs.get("device"),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get("doclayout_yolo_weights"),
                kwargs.get("device"),
            )
        # Any other layout_model_name leaves atom_model as None and is
        # reported below as an initialisation failure.
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights"),
            kwargs.get("device"),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get("mfr_weight_dir"),
            kwargs.get("mfr_cfg_path"),
            kwargs.get("device"),
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the options the caller actually supplied, so that
        # ocr_model_init's own defaults apply instead of an explicit None
        # (a None detection threshold is not a usable value).
        ocr_kwargs = {"lang": kwargs.get("lang")}
        if kwargs.get("ocr_show_log") is not None:
            ocr_kwargs["show_log"] = kwargs["ocr_show_log"]
        if kwargs.get("det_db_box_thresh") is not None:
            ocr_kwargs["det_db_box_thresh"] = kwargs["det_db_box_thresh"]
        atom_model = ocr_model_init(**ocr_kwargs)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_name"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            kwargs.get("device"),
        )
    else:
        logger.error("model name not allow")
        # ``exit()`` is a ``site`` convenience for interactive sessions;
        # raising SystemExit yields the same exit status without relying on it.
        raise SystemExit(1)

    if atom_model is None:
        logger.error("model init failed")
        raise SystemExit(1)
    return atom_model
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from PIL import Image
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from magic_pdf.libs.clean_memory import clean_memory
|
8
|
+
|
9
|
+
|
10
|
+
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Crop the region described by ``input_res['poly']`` onto a white canvas.

    The crop rectangle is taken from poly indices 0/1 (min corner) and 4/5
    (max corner), truncated to ``int``. The crop is pasted at
    ``(crop_paste_x, crop_paste_y)`` on a white RGB image that is
    ``2 * crop_paste_x`` wider and ``2 * crop_paste_y`` taller than the crop.

    Returns:
        ``(image, info)`` where ``info`` is the list
        ``[crop_paste_x, crop_paste_y, xmin, ymin, xmax, ymax, width, height]``
        with ``width``/``height`` being the size of the returned image.
    """
    poly = input_res['poly']
    x_min, y_min = int(poly[0]), int(poly[1])
    x_max, y_max = int(poly[4]), int(poly[5])

    # White canvas: crop size plus the requested padding on every side.
    canvas_width = x_max - x_min + crop_paste_x * 2
    canvas_height = y_max - y_min + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')

    # Cut the region out of the page image and place it inside the padding.
    region = input_pil_img.crop((x_min, y_min, x_max, y_max))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    crop_info = [
        crop_paste_x, crop_paste_y,
        x_min, y_min, x_max, y_max,
        canvas_width, canvas_height,
    ]
    return canvas, crop_info
|
24
|
+
|
25
|
+
|
26
|
+
# Select regions for OCR / formula regions / table regions
|
27
|
+
def get_res_list_from_layout_res(layout_res):
    """Split layout results into OCR, table and formula-box lists.

    Each entry's ``category_id`` (converted with ``int``) decides its bucket:
    13/14 -> formula boxes, reduced to ``{"bbox": [x0, y0, x1, y1]}`` taken
    from poly indices 0, 1, 4, 5; 0/1/2/4/6/7 -> OCR list; 5 -> table list.
    Entries with any other category are ignored.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    formula_categories = (13, 14)
    ocr_categories = (0, 1, 2, 4, 6, 7)
    table_categories = (5,)

    ocr_entries, table_entries, formula_boxes = [], [], []
    for entry in layout_res:
        category = int(entry['category_id'])
        if category in formula_categories:
            poly = entry['poly']
            formula_boxes.append({
                "bbox": [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])],
            })
        elif category in ocr_categories:
            ocr_entries.append(entry)
        elif category in table_categories:
            table_entries.append(entry)
    return ocr_entries, table_entries, formula_boxes
|
42
|
+
|
43
|
+
|
44
|
+
def clean_vram(device, vram_threshold=8):
    """Run ``clean_memory()`` when the CUDA device has little total memory.

    Nothing happens unless CUDA is available and ``device`` is not ``'cpu'``.
    Otherwise the device's total memory (bytes divided by 1024**3) is
    compared with ``vram_threshold``; at or below the threshold the memory
    cleanup runs and its duration is logged.
    """
    if not (torch.cuda.is_available() and device != 'cpu'):
        return

    bytes_per_gib = 1024 ** 3  # convert bytes to GB
    total_memory = torch.cuda.get_device_properties(device).total_memory / bytes_per_gib
    if total_memory <= vram_threshold:
        started_at = time.time()
        clean_memory()
        gc_time = round(time.time() - started_at, 2)
        logger.info(f"gc time: {gc_time}")
|
File without changes
|
File without changes
|
@@ -0,0 +1,259 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
7
|
+
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
8
|
+
|
9
|
+
|
10
|
+
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` bbox into an array of its four corners.

    The corners are returned in the order ``(x0, y0), (x1, y0), (x1, y1),
    (x0, y1)`` as a float32 array of shape (4, 2).
    """
    x_start, y_start, x_end, y_end = bbox
    corners = [
        [x_start, y_start],
        [x_end, y_start],
        [x_end, y_end],
        [x_start, y_end],
    ]
    return np.array(corners).astype('float32')
|
14
|
+
|
15
|
+
|
16
|
+
def points_to_bbox(points):
    """Convert a four-corner point array into ``[x0, y0, x1, y1]``.

    ``x0``/``y0`` come from the first point, ``x1`` from the second point and
    ``y1`` from the third point; the fourth point is not read.
    """
    first_corner, second_corner, third_corner = points[0], points[1], points[2]
    x0, y0 = first_corner
    x1, _unused_y = second_corner
    _unused_x, y1 = third_corner
    return [x0, y0, x1, y1]
|
22
|
+
|
23
|
+
|
24
|
+
def merge_intervals(intervals):
    """Merge overlapping ``[start, end]`` intervals.

    Two intervals are merged when the later one starts at or before the end
    of the earlier one. The input is left untouched: it is neither re-ordered
    nor are its items modified, and items may be any two-element sequence
    (lists or tuples).

    Args:
        intervals: Iterable of ``[start, end]`` pairs.

    Returns:
        A new list of ``[start, end]`` lists, sorted by start value.
    """
    merged = []
    # sorted() leaves the caller's list in its original order.
    for interval in sorted(intervals, key=lambda item: item[0]):
        if not merged or merged[-1][1] < interval[0]:
            # No overlap with the previous interval: start a new one. Copy it
            # so later end-point updates never write into the caller's data.
            merged.append(list(interval))
        else:
            # Overlap: extend the previous interval if this one reaches further.
            merged[-1][1] = max(merged[-1][1], interval[1])

    return merged
|
39
|
+
|
40
|
+
|
41
|
+
def remove_intervals(original, masks):
    """Subtract the ``masks`` intervals from the ``original`` interval.

    The masks are first merged with ``merge_intervals``. The pieces of
    ``original`` that are left are returned as ``[start, end]`` lists; a piece
    before a mask ends at ``mask_start - 1`` and the next piece begins at
    ``mask_end + 1``.
    """
    remaining = []
    cursor, range_end = original

    for mask_start, mask_end in merge_intervals(masks):
        # Masks lying completely outside the (remaining) range are irrelevant.
        if mask_start > range_end or mask_end < cursor:
            continue

        # Keep the part in front of the mask, if there is any.
        if cursor < mask_start:
            remaining.append([cursor, mask_start - 1])

        # Continue just behind the mask (never move backwards).
        cursor = max(mask_end + 1, cursor)

    # Whatever is left after the last mask is kept as well.
    if cursor <= range_end:
        remaining.append([cursor, range_end])

    return remaining
|
70
|
+
|
71
|
+
|
72
|
+
def update_det_boxes(dt_boxes, mfd_res):
    """Cut the x-ranges of formula boxes out of each text detection box.

    For every text box, the formula boxes in ``mfd_res`` for which
    ``__is_overlaps_y_exceeds_threshold`` is true contribute their x-range as
    a mask. The text box's x-range minus those masks is turned back into
    four-corner boxes that keep the text box's original y-range.

    Returns:
        A list of four-corner boxes; a text box fully covered by masks
        contributes nothing.
    """
    split_boxes = []
    for quad in dt_boxes:
        text_bbox = points_to_bbox(quad)
        formula_x_ranges = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]
        kept_x_ranges = remove_intervals([text_bbox[0], text_bbox[2]], formula_x_ranges)
        for x_range in kept_x_ranges:
            split_boxes.append(
                bbox_to_points([x_range[0], text_bbox[1], x_range[1], text_bbox[3]])
            )
    return split_boxes
|
89
|
+
|
90
|
+
|
91
|
+
def merge_overlapping_spans(spans):
    """Merge horizontally overlapping spans on the same line.

    Spans are processed in order of their starting x-coordinate. A span whose
    ``x1`` is not greater than the previous span's ``x2`` is fused with it
    into the bounding box of both. The input sequence is not modified and may
    be a list or a tuple.

    :param spans: A sequence of span coordinates [(x1, y1, x2, y2), ...]
    :return: A new list of merged spans; spans that were not merged are
        returned as the original objects, fused spans as tuples
    """
    # Return an empty list if the input is empty.
    if not spans:
        return []

    merged = []
    # sorted() leaves the caller's sequence in its original order.
    for span in sorted(spans, key=lambda item: item[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            # No horizontal overlap with the previous span: keep it as is.
            merged.append(span)
        else:
            # Overlap: replace the previous span with the union of both.
            last_span = merged.pop()
            merged.append((
                min(last_span[0], x1),
                min(last_span[1], y1),
                max(last_span[2], x2),
                max(last_span[3], y2),
            ))

    return merged
|
126
|
+
|
127
|
+
|
128
|
+
def merge_det_boxes(dt_boxes):
    """
    Merge text detection boxes into larger per-line regions.

    Each box is converted to an ``[x0, y0, x1, y1]`` bbox. Boxes whose bbox
    has non-positive width or height are set aside and appended unchanged at
    the end of the result. The remaining boxes are grouped into lines with
    ``merge_spans_to_line``; within each line, horizontally overlapping boxes
    are fused by ``merge_overlapping_spans`` and converted back to
    four-corner form.

    Parameters:
    dt_boxes (list): Text detection boxes, each given by four corner points.

    Returns:
    list: The merged regions as four-corner boxes, followed by the boxes that
    were set aside.
    """
    span_dicts = []
    passthrough_boxes = []
    for quad in dt_boxes:
        bbox = points_to_bbox(quad)
        has_no_area = bbox[2] <= bbox[0] or bbox[3] <= bbox[1]
        if has_no_area:
            passthrough_boxes.append(quad)
        else:
            span_dicts.append({'bbox': bbox, 'type': 'text'})

    merged_boxes = []
    for line in merge_spans_to_line(span_dicts):
        line_bboxes = [span['bbox'] for span in line]
        # Fuse overlapping boxes of this line, then go back to corner points.
        for merged_bbox in merge_overlapping_spans(line_bboxes):
            merged_boxes.append(bbox_to_points(merged_bbox))

    merged_boxes.extend(passthrough_boxes)
    return merged_boxes
|
175
|
+
|
176
|
+
|
177
|
+
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Translate formula boxes into the coordinate system of a cropped image.

    ``useful_list`` is ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
    new_height]``. Each box is shifted by ``-xmin + paste_x`` /
    ``-ymin + paste_y``. Boxes that end before 0 or start beyond
    ``new_width`` / ``new_height`` are dropped.

    Returns:
        A list of ``{"bbox": [x0, y0, x1, y1]}`` dicts in crop coordinates.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, new_width, new_height = useful_list

    adjusted = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        # Coordinates relative to the cropped (and padded) image.
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y

        # Skip formula boxes lying outside the cropped image.
        if x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height:
            continue
        adjusted.append({"bbox": [x0, y0, x1, y1]})
    return adjusted
|
196
|
+
|
197
|
+
|
198
|
+
def get_ocr_result_list(ocr_res, useful_list):
    """Convert OCR results on a cropped image into page-level result dicts.

    ``useful_list`` is ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
    new_height]``; only the first four values are used. Each item of
    ``ocr_res`` is ``[four_corner_points, (text, score)]``.

    When ``calculate_angle_degrees`` reports more than 0.5 for a box, the box
    is replaced by an axis-aligned rectangle around its geometric centre.
    All points are then shifted by ``-paste_x + xmin`` / ``-paste_y + ymin``.

    Returns:
        A list of dicts with ``category_id`` 15, the flattened ``poly``, the
        ``score`` rounded to two decimals and the recognised ``text``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _crop_width, _crop_height = useful_list

    def to_page_coords(point):
        # Undo the paste offset and re-apply the crop origin.
        return [point[0] - paste_x + xmin, point[1] - paste_y + ymin]

    results = []
    for box_ocr_res in ocr_res:
        corners = box_ocr_res[0]
        p1, p2, p3, p4 = corners
        text, score = box_ocr_res[1]

        if calculate_angle_degrees(corners) > 0.5:
            # The box is tilted by more than 0.5 degrees relative to the
            # x-axis: rebuild it as an upright rectangle around its centre.
            x_center = sum(point[0] for point in corners) / 4
            y_center = sum(point[1] for point in corners) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            p1 = [x_center - box_width / 2, y_center - box_height / 2]
            p2 = [x_center + box_width / 2, y_center - box_height / 2]
            p3 = [x_center + box_width / 2, y_center + box_height / 2]
            p4 = [x_center - box_width / 2, y_center + box_height / 2]

        # Convert the coordinates back to the original coordinate system.
        p1 = to_page_coords(p1)
        p2 = to_page_coords(p2)
        p3 = to_page_coords(p3)
        p4 = to_page_coords(p4)

        results.append({
            'category_id': 15,
            'poly': p1 + p2 + p3 + p4,
            'score': float(round(score, 2)),
            'text': text,
        })

    return results
|
233
|
+
|
234
|
+
|
235
|
+
def calculate_angle_degrees(poly):
    """Return the absolute mean angle (degrees) of a quad's two diagonals.

    The diagonals run from ``poly[0]`` to ``poly[2]`` and from ``poly[1]`` to
    ``poly[3]``. For each, the angle to the x-axis is ``atan(slope)``; a
    vertical diagonal uses an infinite slope. The two angles are averaged and
    the absolute value of that average is returned.
    """

    def diagonal_angle(start, end):
        # Slope of the segment; a vertical segment gets an infinite slope.
        if end[0] != start[0]:
            gradient = (end[1] - start[1]) / (end[0] - start[0])
        else:
            gradient = float('inf')
        # Angle between the segment and the x-axis, converted to degrees.
        return math.degrees(math.atan(gradient))

    first_angle = diagonal_angle(poly[0], poly[2])
    second_angle = diagonal_angle(poly[1], poly[3])

    # Mean of both diagonal angles, sign removed.
    return abs((first_angle + second_angle) / 2)
|
259
|
+
|
@@ -0,0 +1,168 @@
|
|
1
|
+
import copy
|
2
|
+
import time
|
3
|
+
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from paddleocr import PaddleOCR
|
7
|
+
from paddleocr.paddleocr import check_img, logger
|
8
|
+
from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
|
9
|
+
from paddleocr.tools.infer.predict_system import sorted_boxes
|
10
|
+
from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
|
11
|
+
|
12
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
|
13
|
+
|
14
|
+
|
15
|
+
class ModifiedPaddleOCR(PaddleOCR):
    """``PaddleOCR`` subclass whose pipeline accepts ``mfd_res`` boxes.

    When ``mfd_res`` is given, the detected text boxes are passed through
    ``update_det_boxes`` before cropping and recognition.
    """

    def ocr(self,
            img,
            det=True,
            rec=True,
            cls=True,
            bin=False,
            inv=False,
            alpha_color=(255, 255, 255),
            mfd_res=None,
            ):
        """
        OCR with PaddleOCR
        args:
            img: img for OCR, support ndarray, img_path and list or ndarray
            det: use text detection or not. If False, only rec will be exec. Default is True
            rec: use text recognition or not. If False, only det will be exec. Default is True
            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
            bin: binarize image to black and white. Default is False.
            inv: invert image colors. Default is False.
            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
            mfd_res: optional list of dicts with a 'bbox' entry; only used when
                both det and rec are enabled, where it is forwarded to __call__.
        returns:
            A list with one entry per processed image. With det and rec: a
            list of [box, rec_result] pairs, or None when nothing was found.
            With det only: a list of boxes, or None. Otherwise the recognizer
            output (or the classifier output when rec is False).
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            # NOTE(review): exits with status 0 on an error path; kept as is
            # because callers may depend on it.
            exit(0)
        if cls == True and self.use_angle_cls == False:
            pass
            # logger.warning(
            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
            # )

        img = check_img(img)
        # for infer pdf file
        if isinstance(img, list):
            if self.page_num > len(img) or self.page_num == 0:
                self.page_num = len(img)
            imgs = img[:self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            # Replace transparency, then optionally invert and binarize.
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image

        if det and rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res]
                           for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
                # ``not dt_boxes`` is ambiguous for a NumPy array with more
                # than one element and raises ValueError, so test for
                # "no boxes" explicitly.
                if dt_boxes is None or len(dt_boxes) == 0:
                    ocr_res.append(None)
                    continue
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for idx, img in enumerate(imgs):
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res

    def __call__(self, img, cls=True, mfd_res=None):
        """Run detection, optional classification and recognition on one image.

        Returns ``(boxes, rec_results, time_dict)``. ``boxes`` and
        ``rec_results`` only contain entries whose score is at least
        ``self.drop_score``; both are ``None`` when ``img`` is ``None`` or the
        detector returns ``None``. ``time_dict`` holds the elapsed times under
        the keys 'det', 'rec', 'cls' and 'all'.
        """
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict
        else:
            logger.debug("dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), elapse))
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # TODO: merging currently works on bboxes, which handles tilted text
        # lines poorly; it needs to be changed to a poly-based merge.
        # dt_boxes = merge_det_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            # Crop from a copy of the box so the stored box stays untouched.
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapsed : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(
            len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)
        # Keep only the results whose score reaches the drop threshold.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict
|