yomitoku 0.4.0.post1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/__init__.py +20 -0
- yomitoku/base.py +136 -0
- yomitoku/cli/__init__.py +0 -0
- yomitoku/cli/main.py +230 -0
- yomitoku/configs/__init__.py +13 -0
- yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
- yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
- yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
- yomitoku/constants.py +32 -0
- yomitoku/data/__init__.py +3 -0
- yomitoku/data/dataset.py +40 -0
- yomitoku/data/functions.py +279 -0
- yomitoku/document_analyzer.py +315 -0
- yomitoku/export/__init__.py +6 -0
- yomitoku/export/export_csv.py +71 -0
- yomitoku/export/export_html.py +188 -0
- yomitoku/export/export_json.py +34 -0
- yomitoku/export/export_markdown.py +145 -0
- yomitoku/layout_analyzer.py +66 -0
- yomitoku/layout_parser.py +189 -0
- yomitoku/models/__init__.py +9 -0
- yomitoku/models/dbnet_plus.py +272 -0
- yomitoku/models/layers/__init__.py +0 -0
- yomitoku/models/layers/activate.py +38 -0
- yomitoku/models/layers/dbnet_feature_attention.py +160 -0
- yomitoku/models/layers/parseq_transformer.py +218 -0
- yomitoku/models/layers/rtdetr_backbone.py +333 -0
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
- yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
- yomitoku/models/parseq.py +243 -0
- yomitoku/models/rtdetr.py +22 -0
- yomitoku/ocr.py +87 -0
- yomitoku/postprocessor/__init__.py +9 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
- yomitoku/postprocessor/parseq_tokenizer.py +128 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
- yomitoku/reading_order.py +214 -0
- yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
- yomitoku/resource/charset.txt +1 -0
- yomitoku/table_structure_recognizer.py +244 -0
- yomitoku/text_detector.py +103 -0
- yomitoku/text_recognizer.py +128 -0
- yomitoku/utils/__init__.py +0 -0
- yomitoku/utils/graph.py +20 -0
- yomitoku/utils/logger.py +15 -0
- yomitoku/utils/misc.py +102 -0
- yomitoku/utils/visualizer.py +179 -0
- yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
- yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
- yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
- yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/data/dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from torchvision import transforms as T
|
3
|
+
|
4
|
+
from .functions import (
|
5
|
+
extract_roi_with_perspective,
|
6
|
+
resize_with_padding,
|
7
|
+
rotate_text_image,
|
8
|
+
validate_quads,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
class ParseqDataset(Dataset):
    """Dataset that yields one normalized text crop per quadrilateral.

    Each item is the region of ``img`` bounded by one entry of ``quads``:
    perspective-rectified, rotated when it is much taller than wide, resized
    and padded to ``cfg.data.img_size`` and finally converted to a tensor.
    """

    def __init__(self, cfg, img, quads):
        """
        Store the inputs, build the tensor transform and validate the quads.

        Args:
            cfg: configuration object; ``cfg.data.img_size`` is read per item
            img (np.ndarray): image the crops are extracted from
            quads: sequence of quadrilaterals (4 points of 2 coordinates each)

        Raises:
            ValueError: if ``validate_quads`` rejects one of the quadrilaterals
        """
        # NOTE(review): this method used to assign ``img[:, :, ::-1]`` to
        # ``self.img`` and then immediately overwrite it with ``img``. The
        # dead store is removed; the effective behaviour (channels used
        # exactly as passed in) is kept. Confirm which channel order the
        # recognizer weights expect before re-introducing the flip.
        self.img = img
        self.quads = quads
        self.cfg = cfg
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(0.5, 0.5),
            ]
        )

        validate_quads(self.img, self.quads)

    def __len__(self):
        """Return the number of quadrilaterals, i.e. the number of crops."""
        return len(self.quads)

    def __getitem__(self, index):
        """
        Build the model input for the quadrilateral at ``index``.

        Returns:
            The transformed crop, or ``None`` when no region could be extracted.
        """
        polygon = self.quads[index]
        roi_img = extract_roi_with_perspective(self.img, polygon)
        if roi_img is None:
            return None

        roi_img = rotate_text_image(roi_img, thresh_aspect=2)
        resized = resize_with_padding(roi_img, self.cfg.data.img_size)
        return self.transform(resized)
|
@@ -0,0 +1,279 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
import cv2
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
from pdf2image import convert_from_path
|
7
|
+
|
8
|
+
from ..constants import (
|
9
|
+
MIN_IMAGE_SIZE,
|
10
|
+
SUPPORT_INPUT_FORMAT,
|
11
|
+
WARNING_IMAGE_SIZE,
|
12
|
+
)
|
13
|
+
from ..utils.logger import set_logger
|
14
|
+
|
15
|
+
logger = set_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def load_image(image_path: str) -> np.ndarray:
    """
    Open an image file.

    Args:
        image_path (str): path to the image file

    Returns:
        np.ndarray: image data(BGR)

    Raises:
        FileNotFoundError: if ``image_path`` does not exist
        ValueError: if the extension is unsupported or is ``pdf``, if the file
            cannot be decoded, or if either side is below ``MIN_IMAGE_SIZE``
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"File not found: {image_path}")

    ext = image_path.suffix[1:].lower()
    if ext not in SUPPORT_INPUT_FORMAT:
        raise ValueError(
            f"Unsupported image format. Supported formats are {SUPPORT_INPUT_FORMAT}"
        )

    if ext == "pdf":
        raise ValueError(
            "PDF file is not supported by load_image(). Use load_pdf() instead."
        )

    # Pass a plain string: not every OpenCV build accepts pathlib.Path here.
    img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)

    # cv2.imread signals an unreadable file by returning None, not by raising.
    if img is None:
        raise ValueError("Invalid image data.")

    h, w = img.shape[:2]
    if h < MIN_IMAGE_SIZE or w < MIN_IMAGE_SIZE:
        raise ValueError("Image size is too small.")

    if min(h, w) < WARNING_IMAGE_SIZE:
        logger.warning(
            "\n"
            "The image size is small, which may result in reduced OCR accuracy.\n"
            "The process will continue, but it is recommended to input images with a minimum size of 720 pixels on the shorter side.\n"
        )

    return img
|
61
|
+
|
62
|
+
|
63
|
+
def load_pdf(pdf_path: str, dpi=200) -> list[np.ndarray]:
    """
    Render every page of a PDF file to an image.

    Args:
        pdf_path (str): path to the PDF file
        dpi (int): resolution forwarded to ``convert_from_path``

    Returns:
        list[np.ndarray]: one array per page with the channel axis reversed
        (BGR, assuming the converter yields RGB pages)

    Raises:
        FileNotFoundError: if ``pdf_path`` does not exist
        ValueError: if the extension is unsupported or not ``pdf``, or if the
            conversion fails
    """
    source = Path(pdf_path)
    if not source.exists():
        raise FileNotFoundError(f"File not found: {source}")

    suffix = source.suffix[1:].lower()
    if suffix not in SUPPORT_INPUT_FORMAT:
        raise ValueError(
            f"Unsupported image format. Supported formats are {SUPPORT_INPUT_FORMAT}"
        )

    if suffix != "pdf":
        raise ValueError(
            "image file is not supported by load_pdf(). Use load_image() instead."
        )

    try:
        pages = convert_from_path(source, dpi=dpi)
    except Exception as e:
        raise ValueError(f"Failed to open the PDF file: {source}") from e

    bgr_pages = []
    for page in pages:
        # Reverse the last axis so the channel order matches load_image().
        bgr_pages.append(np.array(page)[:, :, ::-1])
    return bgr_pages
|
94
|
+
|
95
|
+
|
96
|
+
def resize_shortest_edge(
    img: np.ndarray, shortest_edge_length: int, max_length: int
) -> np.ndarray:
    """
    Scale the image so its shorter side equals `shortest_edge_length`, keeping the aspect ratio.
    When the longer side would then exceed `max_length`, scale down again so it fits.
    Both final sides are rounded down to a multiple of 32, with a minimum of 32.

    Args:
        img (np.ndarray): target image
        shortest_edge_length (int): pixel length of the shortest edge after resizing
        max_length (int): pixel length of maximum edge after resizing

    Returns:
        np.ndarray: resized image
    """
    src_h, src_w = img.shape[:2]
    ratio = shortest_edge_length / min(src_h, src_w)

    # Pin the shorter side exactly; derive the other one from the ratio.
    if src_h < src_w:
        out_h = shortest_edge_length
        out_w = int(src_w * ratio)
    else:
        out_h = int(src_h * ratio)
        out_w = shortest_edge_length

    longest = max(out_h, out_w)
    if longest > max_length:
        shrink = float(max_length) / longest
        out_h = int(out_h * shrink)
        out_w = int(out_w * shrink)

    # Snap both sides down to a multiple of 32, never below 32.
    snapped_w = max(int(out_w / 32) * 32, 32)
    snapped_h = max(int(out_h / 32) * 32, 32)

    return cv2.resize(img, (snapped_w, snapped_h))
|
128
|
+
|
129
|
+
|
130
|
+
def standardization_image(
    img: np.ndarray, rgb=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> np.ndarray:
    """
    Standardize the image: reverse the channel axis, scale to [0, 1],
    subtract the per-channel mean and divide by the per-channel deviation.

    Args:
        img (np.ndarray): target image (H, W, 3)
        rgb (tuple): per-channel mean, applied after the channel axis is reversed
        std (tuple): per-channel standard deviation, applied the same way

    Returns:
        np.ndarray: normalized image as float32
    """
    channel_mean = np.array(rgb)
    channel_std = np.array(std)

    flipped = img[:, :, ::-1]
    unit_range = flipped / 255.0
    standardized = (unit_range - channel_mean) / channel_std

    return standardized.astype(np.float32)
|
148
|
+
|
149
|
+
|
150
|
+
def array_to_tensor(img: np.ndarray) -> torch.Tensor:
    """
    Turn a channel-last image array into a batched channel-first float tensor.
    (H, W, C) -> (1, C, H, W)

    Args:
        img (np.ndarray): target image(H, W, C)

    Returns:
        torch.Tensor: (N, C, H, W) tensor with N == 1
    """
    channel_first = np.transpose(img, (2, 0, 1))
    # unsqueeze(0) prepends the batch axis.
    return torch.as_tensor(channel_first, dtype=torch.float).unsqueeze(0)
|
165
|
+
|
166
|
+
|
167
|
+
def validate_quads(img: np.ndarray, quads: list[list[list[int]]]):
    """
    Check that every quadrilateral has four 2-D vertices lying inside the image.

    Args:
        img (np.ndarray): target image
        quads (list[list[list[int]]]): list of quadrilateral

    Returns:
        bool: always ``True`` when no error was raised

    Raises:
        ValueError: if a quadrilateral does not have 4 vertices, a vertex does
            not have 2 coordinates, or a vertex lies outside the image
    """
    img_h, img_w = img.shape[:2]

    for vertices in quads:
        if len(vertices) != 4:
            raise ValueError("The number of vertices must be 4.")

        if any(len(vertex) != 2 for vertex in vertices):
            raise ValueError("The number of coordinates must be 2.")

        pts = np.array(vertices, dtype=int)
        xs = pts[:, 0]
        ys = pts[:, 1]

        # Coordinates equal to the image width/height are still accepted.
        inside = (
            np.min(xs) >= 0
            and np.max(xs) <= img_w
            and np.min(ys) >= 0
            and np.max(ys) <= img_h
        )
        if not inside:
            raise ValueError(f"The vertices are out of the image. {pts.tolist()}")

    return True
|
201
|
+
|
202
|
+
|
203
|
+
def extract_roi_with_perspective(img, quad):
    """
    Cut the region bounded by ``quad`` out of the image and warp it to an
    axis-aligned rectangle with a perspective transformation.

    Args:
        img (np.ndarray): target image
        quad: four vertices of the region; vertex 0 -> 1 is taken as the
            width edge and vertex 1 -> 2 as the height edge

    Returns:
        np.ndarray: extracted image
    """
    source = img.copy()
    corners = np.array(quad, dtype=np.float32)

    # Output size comes from the lengths of the first two edges, truncated.
    out_w = int(np.linalg.norm(corners[0] - corners[1]))
    out_h = int(np.linalg.norm(corners[1] - corners[2]))

    src_pts = np.float32(corners)
    dst_pts = np.float32([[0, 0], [out_w, 0], [out_w, out_h], [0, out_h]])

    transform = cv2.getPerspectiveTransform(src_pts, dst_pts)
    return cv2.warpPerspective(source, transform, (out_w, out_h))
|
229
|
+
|
230
|
+
|
231
|
+
def rotate_text_image(img, thresh_aspect=2):
    """
    Rotate the image 90 degrees counter-clockwise when its height exceeds
    ``thresh_aspect`` times its width; otherwise return it unchanged.

    Args:
        img (np.ndarray): target image
        thresh_aspect (int): threshold of the height / width ratio

    Returns:
        np.ndarray: rotated image, or the input itself when no rotation is needed
    """
    height, width = img.shape[:2]
    is_tall = height > thresh_aspect * width
    return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) if is_tall else img
|
246
|
+
|
247
|
+
|
248
|
+
def resize_with_padding(img, target_size, background_color=(0, 0, 0)):
    """
    Shrink the image to fit inside ``target_size`` (never enlarging it) and
    paste it into the top-left corner of a canvas of exactly that size.

    Args:
        img (np.ndarray): target image
        target_size (int, int): target size as (height, width)
        background_color (Tuple[int, int, int]): background color

    Returns:
        np.ndarray: uint8 canvas of shape (target_size[0], target_size[1], 3)
    """
    h, w = img.shape[:2]
    scale_w = 1.0
    scale_h = 1.0
    if w > target_size[1]:
        scale_w = target_size[1] / w
    if h > target_size[0]:
        scale_h = target_size[0] / h

    # One common factor keeps the aspect ratio.
    scale = min(scale_w, scale_h)

    # int() truncation can yield 0 for very elongated crops, which makes
    # cv2.resize fail; keep every side at least one pixel long.
    new_w = max(int(w * scale), 1)
    new_h = max(int(h * scale), 1)

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    canvas = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
    canvas[:, :] = background_color

    resized_size = resized.shape[:2]
    canvas[: resized_size[0], : resized_size[1], :] = resized

    return canvas
|
@@ -0,0 +1,315 @@
|
|
1
|
+
import asyncio
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
3
|
+
from typing import List, Union
|
4
|
+
|
5
|
+
from pydantic import conlist
|
6
|
+
|
7
|
+
from .base import BaseSchema
|
8
|
+
from .export import export_csv, export_html, export_markdown
|
9
|
+
from .layout_analyzer import LayoutAnalyzer
|
10
|
+
from .ocr import OCR, WordPrediction
|
11
|
+
from .table_structure_recognizer import TableStructureRecognizerSchema
|
12
|
+
from .utils.misc import is_contained, quad_to_xyxy
|
13
|
+
from .reading_order import prediction_reading_order
|
14
|
+
|
15
|
+
from .utils.visualizer import reading_order_visualizer
|
16
|
+
|
17
|
+
|
18
|
+
class ParagraphSchema(BaseSchema):
    # One paragraph of recognized text on a page.

    # Exactly four integers; unpacked elsewhere in this file as x1, y1, x2, y2.
    box: conlist(int, min_length=4, max_length=4)
    # Text of the paragraph (words joined with newlines by the analyzer).
    contents: Union[str, None]
    # "horizontal" or "vertical" as assigned by the analyzer; may be None.
    direction: Union[str, None]
    # Position in the reading order.
    order: Union[int, None]
    # Layout role, e.g. "page_header", "page_footer", "section_headings"; None for plain text.
    role: Union[str, None]
|
24
|
+
|
25
|
+
|
26
|
+
class FigureSchema(BaseSchema):
    # A figure region together with the paragraphs found inside it.
    # ``order`` used to be declared twice; the redundant second declaration is
    # removed, which leaves the field set and the field order unchanged.

    # Exactly four integers describing the figure's bounding box.
    box: conlist(int, min_length=4, max_length=4)
    # Position of the figure in the page reading order.
    order: Union[int, None]
    # Paragraphs contained in the figure.
    paragraphs: List[ParagraphSchema]
    # Dominant text direction of the contained paragraphs.
    direction: Union[str, None]
|
32
|
+
|
33
|
+
|
34
|
+
class DocumentAnalyzerSchema(BaseSchema):
    # Aggregated analysis result for one page image.

    # Paragraphs outside figures.
    paragraphs: List[ParagraphSchema]
    # Tables reported by the layout analysis, with cell contents filled in.
    tables: List[TableStructureRecognizerSchema]
    # Raw word-level OCR predictions.
    words: List[WordPrediction]
    # Figures, each carrying the paragraphs located inside it.
    figures: List[FigureSchema]

    def to_html(self, out_path: str, **kwargs):
        """Export this result via ``export_html``; extra options are forwarded."""
        export_html(self, out_path, **kwargs)

    def to_markdown(self, out_path: str, **kwargs):
        """Export this result via ``export_markdown``; extra options are forwarded."""
        export_markdown(self, out_path, **kwargs)

    def to_csv(self, out_path: str, **kwargs):
        """Export this result via ``export_csv``; extra options are forwarded."""
        export_csv(self, out_path, **kwargs)
|
48
|
+
|
49
|
+
|
50
|
+
def combine_flags(flag1, flag2):
    """
    Combine two flag sequences element-wise with ``or``.

    The result is as long as the shorter input.
    """
    merged = []
    for left, right in zip(flag1, flag2):
        merged.append(left or right)
    return merged
|
52
|
+
|
53
|
+
|
54
|
+
def judge_page_direction(paragraphs):
    """
    Decide the dominant text direction by comparing total bounding-box area.

    Paragraphs whose direction is not "horizontal" count towards the vertical
    total. Returns "vertical" only when the vertical total is strictly larger,
    so ties and empty input give "horizontal".
    """
    area_by_direction = {"horizontal": 0, "vertical": 0}

    for item in paragraphs:
        left, top, right, bottom = item.box
        area = (right - left) * (bottom - top)
        bucket = "horizontal" if item.direction == "horizontal" else "vertical"
        area_by_direction[bucket] += area

    if area_by_direction["vertical"] > area_by_direction["horizontal"]:
        return "vertical"
    return "horizontal"
|
72
|
+
|
73
|
+
|
74
|
+
def extract_paragraph_within_figure(paragraphs, figures):
    """
    Attach to each figure the paragraphs that lie inside it.

    A paragraph belongs to a figure when ``is_contained`` accepts it with a
    threshold of 0.7. Each figure's paragraphs get a reading order and are
    stored sorted by it.

    Returns:
        tuple: (list of FigureSchema, list of bool) where the flags mark, per
        input paragraph, whether it was assigned to at least one figure
    """
    assigned = [False] * len(paragraphs)
    figure_schemas = []

    for fig in figures:
        fig_box = fig.box

        members = []
        for idx, candidate in enumerate(paragraphs):
            if is_contained(fig_box, candidate.box, threshold=0.7):
                members.append(candidate)
                assigned[idx] = True

        fig_direction = judge_page_direction(members)
        ordered = prediction_reading_order(members, fig_direction)

        figure_schemas.append(
            FigureSchema(
                box=fig_box,
                order=0,
                direction=fig_direction,
                paragraphs=sorted(ordered, key=lambda member: member.order),
            )
        )

    return figure_schemas, assigned
|
94
|
+
|
95
|
+
|
96
|
+
def extract_words_within_element(pred_words, element):
    """
    Collect the words lying inside ``element`` and join them in reading order.

    A word belongs to the element when ``is_contained`` accepts its bounding
    box with a threshold of 0.5. The element direction is the majority
    direction of its words (ties give "vertical").

    Args:
        pred_words: word predictions exposing ``points``, ``direction`` and ``content``
        element: layout element exposing ``box``

    Returns:
        tuple: ``(contents, direction, check_list)``; ``contents`` is the words
        joined with newlines and ``check_list`` flags, per input word, whether
        it was contained. ``contents`` and ``direction`` are ``None`` when no
        word is contained.
    """
    contained_words = []
    word_sum_width = 0
    word_sum_height = 0
    check_list = [False] * len(pred_words)
    for i, word in enumerate(pred_words):
        word_box = quad_to_xyxy(word.points)
        if is_contained(element.box, word_box, threshold=0.5):
            contained_words.append(word)
            word_sum_width += word_box[2] - word_box[0]
            word_sum_height += word_box[3] - word_box[1]
            check_list[i] = True

    if not contained_words:
        return None, None, check_list

    mean_width = word_sum_width / len(contained_words)
    mean_height = word_sum_height / len(contained_words)

    word_direction = [word.direction for word in contained_words]
    cnt_horizontal = word_direction.count("horizontal")
    cnt_vertical = word_direction.count("vertical")

    element_direction = "horizontal" if cnt_horizontal > cnt_vertical else "vertical"
    if element_direction == "horizontal":
        # Bucket words into rows by the mean word height. The divisor is
        # clamped to 1 so a mean below one pixel cannot divide by zero.
        row_height = max(int(mean_height), 1)
        contained_words = sorted(
            contained_words,
            key=lambda x: (
                x.points[0][1] // row_height,
                x.points[0][0],
            ),
        )
    else:
        # Bucket words into columns by the mean word width, with the same clamp.
        # NOTE(review): reverse=True also reverses the order of words that
        # share a column bucket; confirm that is the intended order.
        col_width = max(int(mean_width), 1)
        contained_words = sorted(
            contained_words,
            key=lambda x: (
                x.points[1][0] // col_width,
                x.points[1][1],
            ),
            reverse=True,
        )

    contained_words = "\n".join([content.content for content in contained_words])
    return (contained_words, element_direction, check_list)
|
140
|
+
|
141
|
+
|
142
|
+
def recursive_update(original, new_data):
    """
    Merge ``new_data`` into ``original`` in place and return ``original``.

    Where both sides hold a dict under the same key the merge recurses;
    in every other case the value from ``new_data`` replaces the old one.
    """
    for name, incoming in new_data.items():
        mergeable = (
            isinstance(incoming, dict)
            and name in original
            and isinstance(original[name], dict)
        )
        if not mergeable:
            # Not a dict on both sides, or the key is new: overwrite.
            original[name] = incoming
            continue
        # Both sides are dicts: update the nested mapping recursively.
        recursive_update(original[name], incoming)
    return original
|
155
|
+
|
156
|
+
|
157
|
+
class DocumentAnalyzer:
    """Run OCR and layout analysis on an image and merge both results."""

    def __init__(self, configs=None, device="cuda", visualize=False):
        """
        Build the OCR and layout analysis modules.

        Args:
            configs (dict | None): overrides merged recursively into the
                default configuration; ``None`` keeps the defaults
            device (str): device name written into every module's config
            visualize (bool): visualization flag written into every module's
                config and also used by ``__call__``

        Raises:
            ValueError: if ``configs`` is neither ``None`` nor a dict
        """
        default_configs = {
            "ocr": {
                "text_detector": {
                    "device": device,
                    "visualize": visualize,
                },
                "text_recognizer": {
                    "device": device,
                    "visualize": visualize,
                },
            },
            "layout_analyzer": {
                "layout_parser": {
                    "device": device,
                    "visualize": visualize,
                },
                "table_structure_recognizer": {
                    "device": device,
                    "visualize": visualize,
                },
            },
        }

        # ``None`` is the declared default and must mean "no overrides";
        # previously it fell into the error branch, so DocumentAnalyzer()
        # could never be constructed without passing a dict.
        if configs is not None:
            if not isinstance(configs, dict):
                raise ValueError(
                    "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
                )
            recursive_update(default_configs, configs)

        self.ocr = OCR(configs=default_configs["ocr"])
        self.layout = LayoutAnalyzer(configs=default_configs["layout_analyzer"])
        self.visualize = visualize

    def aggregate(self, ocr_res, layout_res):
        """
        Merge word-level OCR output with the layout analysis output.

        Words are assigned to table cells and layout paragraphs; every word
        left over becomes its own paragraph. Paragraphs inside figures move
        into those figures, and reading order is assigned to headers, body
        elements and footers.

        Reads ``self.img``, which ``__call__`` sets before invoking ``run``.

        Returns:
            dict: keys "paragraphs", "tables", "figures" and "words"
        """
        paragraphs = []
        check_list = [False] * len(ocr_res.words)

        # Fill every table cell with the words it contains (empty if none).
        for table in layout_res.tables:
            for cell in table.cells:
                words, direction, flags = extract_words_within_element(
                    ocr_res.words, cell
                )

                if words is None:
                    words = ""

                cell.contents = words
                check_list = combine_flags(check_list, flags)

        # Layout paragraphs that contain no word are dropped.
        for paragraph in layout_res.paragraphs:
            words, direction, flags = extract_words_within_element(
                ocr_res.words, paragraph
            )

            if words is None:
                continue

            check_list = combine_flags(check_list, flags)
            paragraphs.append(
                ParagraphSchema(
                    contents=words,
                    box=paragraph.box,
                    direction=direction,
                    order=0,
                    role=paragraph.role,
                )
            )

        # Words claimed by no cell or paragraph become one-word paragraphs.
        for word, claimed in zip(ocr_res.words, check_list):
            if claimed:
                continue
            paragraphs.append(
                ParagraphSchema(
                    contents=word.content,
                    box=quad_to_xyxy(word.points),
                    direction=word.direction,
                    order=0,
                    role=None,
                )
            )

        figures, check_list = extract_paragraph_within_figure(
            paragraphs, layout_res.figures
        )

        # Paragraphs absorbed by a figure no longer count as page paragraphs.
        paragraphs = [
            paragraph for paragraph, flag in zip(paragraphs, check_list) if not flag
        ]

        page_direction = judge_page_direction(paragraphs)

        headers = [
            paragraph for paragraph in paragraphs if paragraph.role == "page_header"
        ]

        footers = [
            paragraph for paragraph in paragraphs if paragraph.role == "page_footer"
        ]

        page_contents = [
            paragraph
            for paragraph in paragraphs
            if paragraph.role is None or paragraph.role == "section_headings"
        ]

        elements = page_contents + layout_res.tables + figures

        # Return values are ignored here; the results are read back from the
        # ``order`` attribute of the passed objects below.
        prediction_reading_order(headers, page_direction)
        prediction_reading_order(footers, page_direction)
        prediction_reading_order(elements, page_direction, self.img)

        # Offset the groups so the order runs headers -> body -> footers.
        for element in elements:
            element.order += len(headers)
        for footer in footers:
            footer.order += len(elements) + len(headers)

        paragraphs = headers + page_contents + footers
        paragraphs = sorted(paragraphs, key=lambda x: x.order)
        figures = sorted(figures, key=lambda x: x.order)
        tables = sorted(layout_res.tables, key=lambda x: x.order)

        outputs = {
            "paragraphs": paragraphs,
            "tables": tables,
            "figures": figures,
            "words": ocr_res.words,
        }

        return outputs

    async def run(self, img):
        """
        Run OCR and layout analysis concurrently, then aggregate the results.

        Both modules are submitted to a two-worker thread pool and awaited
        together. Each is expected to return a pair; the first items are
        aggregated and the second items are passed through unchanged.

        Returns:
            tuple: (DocumentAnalyzerSchema, second OCR item, second layout item)
        """
        with ThreadPoolExecutor(max_workers=2) as executor:
            loop = asyncio.get_running_loop()
            tasks = [
                loop.run_in_executor(executor, self.ocr, img),
                loop.run_in_executor(executor, self.layout, img),
            ]

            results = await asyncio.gather(*tasks)

            results_ocr, ocr = results[0]
            results_layout, layout = results[1]

        outputs = self.aggregate(results_ocr, results_layout)
        results = DocumentAnalyzerSchema(**outputs)
        return results, ocr, layout

    def __call__(self, img):
        """
        Analyze one image synchronously.

        Uses ``asyncio.run``, so it cannot be called from a thread that
        already has a running event loop.

        Returns:
            tuple: (DocumentAnalyzerSchema, second OCR item, layout item); when
            ``visualize`` is set, the layout item is replaced by the output of
            ``reading_order_visualizer``
        """
        # aggregate() reads the image from the instance.
        self.img = img
        results, ocr, layout = asyncio.run(self.run(img))

        if self.visualize:
            layout = reading_order_visualizer(layout, results)

        return results, ocr, layout
|