yomitoku 0.5.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
- yomitoku/cli/main.py +47 -1
- yomitoku/configs/__init__.py +2 -0
- yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
- yomitoku/document_analyzer.py +229 -26
- yomitoku/export/export_csv.py +39 -2
- yomitoku/export/export_html.py +2 -1
- yomitoku/export/export_json.py +40 -2
- yomitoku/export/export_markdown.py +2 -1
- yomitoku/layout_analyzer.py +1 -5
- yomitoku/layout_parser.py +58 -4
- yomitoku/models/layers/rtdetr_backbone.py +5 -15
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +6 -18
- yomitoku/models/layers/rtdetrv2_decoder.py +17 -42
- yomitoku/models/parseq.py +9 -9
- yomitoku/ocr.py +24 -27
- yomitoku/onnx/.gitkeep +0 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +4 -13
- yomitoku/table_structure_recognizer.py +79 -9
- yomitoku/text_detector.py +57 -7
- yomitoku/text_recognizer.py +80 -16
- yomitoku/utils/misc.py +20 -13
- yomitoku/utils/visualizer.py +5 -5
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/METADATA +21 -9
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/RECORD +26 -24
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/WHEEL +1 -1
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/entry_points.txt +0 -0
yomitoku/cli/main.py
CHANGED
@@ -13,6 +13,18 @@ from ..utils.logger import set_logger
 logger = set_logger(__name__, "INFO")
 
 
+def validate_encoding(encoding):
+    if encoding not in [
+        "utf-8",
+        "utf-8-sig",
+        "shift-jis",
+        "euc-jp",
+        "cp932",
+    ]:
+        raise ValueError(f"Invalid encoding: {encoding}")
+    return True
+
+
 def process_single_file(args, analyzer, path, format):
     if path.suffix[1:].lower() in ["pdf"]:
         imgs = load_pdf(path)
@@ -21,7 +33,6 @@ def process_single_file(args, analyzer, path, format):
 
     for page, img in enumerate(imgs):
         results, ocr, layout = analyzer(img)
-
         dirname = path.parent.name
         filename = path.stem
 
@@ -47,11 +58,19 @@ def process_single_file(args, analyzer, path, format):
         results.to_json(
             out_path,
             ignore_line_break=args.ignore_line_break,
+            encoding=args.encoding,
+            img=img,
+            export_figure=args.figure,
+            figure_dir=args.figure_dir,
         )
     elif format == "csv":
         results.to_csv(
             out_path,
             ignore_line_break=args.ignore_line_break,
+            encoding=args.encoding,
+            img=img,
+            export_figure=args.figure,
+            figure_dir=args.figure_dir,
         )
     elif format == "html":
         results.to_html(
@@ -62,6 +81,7 @@ def process_single_file(args, analyzer, path, format):
             export_figure_letter=args.figure_letter,
             figure_width=args.figure_width,
             figure_dir=args.figure_dir,
+            encoding=args.encoding,
         )
     elif format == "md":
         results.to_markdown(
@@ -72,6 +92,7 @@ def process_single_file(args, analyzer, path, format):
             export_figure_letter=args.figure_letter,
             figure_width=args.figure_width,
             figure_dir=args.figure_dir,
+            encoding=args.encoding,
         )
 
     logger.info(f"Output file: {out_path}")
@@ -104,6 +125,12 @@ def main():
         default="results",
         help="output directory",
     )
+    parser.add_argument(
+        "-l",
+        "--lite",
+        action="store_true",
+        help="if set, use lite model",
+    )
     parser.add_argument(
         "-d",
         "--device",
@@ -162,6 +189,12 @@ def main():
         default="figures",
         help="directory to save figure images",
     )
+    parser.add_argument(
+        "--encoding",
+        type=str,
+        default="utf-8",
+        help="Specifies the character encoding for the output file to be exported. If unsupported characters are included, they will be ignored.",
+    )
 
     args = parser.parse_args()
 
@@ -175,6 +208,8 @@ def main():
             f"Invalid output format: {args.format}. Supported formats are {SUPPORT_OUTPUT_FORMAT}"
         )
 
+    validate_encoding(args.encoding)
+
    if format == "markdown":
        format = "md"
 
@@ -197,6 +232,17 @@ def main():
         },
     }
 
+    if args.lite:
+        configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
+
+        if args.device == "cpu":
+            configs["ocr"]["text_detector"]["infer_onnx"] = True
+
+        # Note: for everything other than the Text Detector, PyTorch inference is faster than ONNX inference, so ONNX inference is not used
+        # configs["ocr"]["text_recognizer"]["infer_onnx"] = True
+        # configs["layout_analyzer"]["table_structure_recognizer"]["infer_onnx"] = True
+        # configs["layout_analyzer"]["layout_parser"]["infer_onnx"] = True
+
     analyzer = DocumentAnalyzer(
         configs=configs,
         visualize=args.vis,
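Taken together, these hunks add two user-facing switches: `--lite` swaps in the `parseq-small` recognizer (and, on CPU, ONNX inference for the text detector), while `--encoding` is validated once and then threaded through every exporter. A minimal programmatic sketch of the same configuration, assuming the top-level `DocumentAnalyzer` export and a hypothetical local `sample.png`:

```python
import cv2

from yomitoku import DocumentAnalyzer

# Mirror of what `--lite` on a CPU device does in main() above:
# swap the recognizer to parseq-small and run the detector via ONNX.
configs = {
    "ocr": {
        "text_detector": {"infer_onnx": True},
        "text_recognizer": {"model_name": "parseq-small"},
    },
}

analyzer = DocumentAnalyzer(configs=configs, device="cpu", visualize=False)

img = cv2.imread("sample.png")  # hypothetical input image
results, ocr_vis, layout_vis = analyzer(img)

# The exporters now accept the encoding that `--encoding` validates.
results.to_csv("result.csv", encoding="shift-jis", img=img)
```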
yomitoku/configs/__init__.py
CHANGED
@@ -4,10 +4,12 @@ from .cfg_table_structure_recognizer_rtdtrv2 import (
 )
 from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
 from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig
+from .cfg_text_recognizer_parseq_small import TextRecognizerPARSeqSmallConfig
 
 __all__ = [
     "TextDetectorDBNetConfig",
     "TextRecognizerPARSeqConfig",
     "LayoutParserRTDETRv2Config",
     "TableStructureRecognizerRTDETRv2Config",
+    "TextRecognizerPARSeqSmallConfig",
 ]
yomitoku/configs/cfg_text_recognizer_parseq_small.py
ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from ..constants import ROOT_DIR
+
+
+@dataclass
+class Data:
+    num_workers: int = 4
+    batch_size: int = 128
+    img_size: List[int] = field(default_factory=lambda: [32, 800])
+
+
+@dataclass
+class Encoder:
+    patch_size: List[int] = field(default_factory=lambda: [16, 16])
+    num_heads: int = 8
+    embed_dim: int = 384
+    mlp_ratio: int = 4
+    depth: int = 9
+
+
+@dataclass
+class Decoder:
+    embed_dim: int = 384
+    num_heads: int = 8
+    mlp_ratio: int = 4
+    depth: int = 1
+
+
+@dataclass
+class Visualize:
+    font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
+    color: List[int] = field(default_factory=lambda: [0, 0, 255])  # RGB
+    font_size: int = 18
+
+
+@dataclass
+class TextRecognizerPARSeqSmallConfig:
+    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-small-open-beta"
+    charset: str = str(ROOT_DIR + "/resource/charset.txt")
+    num_tokens: int = 7312
+    max_label_length: int = 100
+    decode_ar: int = 1
+    refine_iters: int = 1
+
+    data: Data = field(default_factory=Data)
+    encoder: Encoder = field(default_factory=Encoder)
+    decoder: Decoder = field(default_factory=Decoder)
+
+    visualize: Visualize = field(default_factory=Visualize)
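The new file defines the lite recognizer's hyperparameters: a 9-block encoder and a single-block decoder at embed dim 384, over a 7312-token charset. A quick sketch of how the nested dataclasses compose, using the `yomitoku.configs` export added in the `__init__.py` hunk above:

```python
from dataclasses import asdict

from yomitoku.configs import TextRecognizerPARSeqSmallConfig

cfg = TextRecognizerPARSeqSmallConfig()

# Nested dataclasses are built via default_factory, so each instance
# gets its own Data/Encoder/Decoder/Visualize objects.
print(cfg.encoder.depth, cfg.decoder.depth)  # 9 1
print(cfg.data.img_size)                     # [32, 800]
print(asdict(cfg)["hf_hub_repo"])
```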
yomitoku/document_analyzer.py
CHANGED
@@ -2,17 +2,26 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Union
 
+import numpy as np
+
 from pydantic import conlist
 
 from .base import BaseSchema
 from .export import export_csv, export_html, export_markdown
 from .layout_analyzer import LayoutAnalyzer
-from .ocr import
-from .table_structure_recognizer import TableStructureRecognizerSchema
-from .utils.misc import is_contained, quad_to_xyxy
+from .ocr import OCRSchema, WordPrediction, ocr_aggregate
 from .reading_order import prediction_reading_order
-
+from .table_structure_recognizer import TableStructureRecognizerSchema
+from .utils.misc import (
+    is_contained,
+    quad_to_xyxy,
+    calc_overlap_ratio,
+)
 from .utils.visualizer import reading_order_visualizer
+from yomitoku.text_detector import TextDetector
+from yomitoku.text_recognizer import TextRecognizer
+
+from .utils.visualizer import det_visualizer
 
 
 class ParagraphSchema(BaseSchema):
@@ -98,41 +107,57 @@ def extract_words_within_element(pred_words, element):
     word_sum_width = 0
     word_sum_height = 0
     check_list = [False] * len(pred_words)
+
     for i, word in enumerate(pred_words):
         word_box = quad_to_xyxy(word.points)
         if is_contained(element.box, word_box, threshold=0.5):
-            contained_words.append(word)
             word_sum_width += word_box[2] - word_box[0]
             word_sum_height += word_box[3] - word_box[1]
             check_list[i] = True
 
+            word_element = ParagraphSchema(
+                box=word_box,
+                contents=word.content,
+                direction=word.direction,
+                order=0,
+                role=None,
+            )
+            contained_words.append(word_element)
+
     if len(contained_words) == 0:
         return None, None, check_list
 
-
-    # mean_height = word_sum_height / len(contained_words)
-
+    element_direction = "horizontal"
     word_direction = [word.direction for word in contained_words]
     cnt_horizontal = word_direction.count("horizontal")
     cnt_vertical = word_direction.count("vertical")
 
     element_direction = "horizontal" if cnt_horizontal > cnt_vertical else "vertical"
-    if element_direction == "horizontal":
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (sum([p[1] for p in x.points]) / 4),
-        )
-    else:
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (sum([p[0] for p in x.points]) / 4),
-            reverse=True,
-        )
 
-    contained_words
+    prediction_reading_order(contained_words, element_direction)
+    contained_words = sorted(contained_words, key=lambda x: x.order)
+
+    contained_words = "\n".join([content.contents for content in contained_words])
+
     return (contained_words, element_direction, check_list)
 
 
+def is_vertical(quad, thresh_aspect=2):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return height > width * thresh_aspect
+
+
+def is_noise(quad, thresh=15):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return width < thresh or height < thresh
+
+
 def recursive_update(original, new_data):
     for key, value in new_data.items():
         # If `value` is a dict, update it recursively
@@ -148,8 +173,163 @@ def recursive_update(original, new_data):
     return original
 
 
+def _extract_words_within_table(words, table, check_list):
+    horizontal_words = []
+    vertical_words = []
+
+    for i, (points, score) in enumerate(zip(words.points, words.scores)):
+        word_box = quad_to_xyxy(points)
+        if is_contained(table.box, word_box, threshold=0.5):
+            if is_vertical(points):
+                vertical_words.append({"points": points, "score": score})
+            else:
+                horizontal_words.append({"points": points, "score": score})
+
+            check_list[i] = True
+
+    return (horizontal_words, vertical_words, check_list)
+
+
+def _calc_overlap_words_on_lines(lines, words):
+    overlap_ratios = [[0 for _ in lines] for _ in words]
+
+    for i, word in enumerate(words):
+        word_box = quad_to_xyxy(word["points"])
+        for j, row in enumerate(lines):
+            overlap_ratio, _ = calc_overlap_ratio(
+                row.box,
+                word_box,
+            )
+            overlap_ratios[i][j] = overlap_ratio
+
+    return overlap_ratios
+
+
+def _correct_vertical_word_boxes(overlap_ratios_vertical, table, table_words_vertical):
+    allocated_cols = [cols.index(max(cols)) for cols in overlap_ratios_vertical]
+
+    new_points = []
+    new_scores = []
+    for i, col_index in enumerate(allocated_cols):
+        col_cells = []
+        for cell in table.cells:
+            if cell.col <= (col_index + 1) < (cell.col + cell.col_span):
+                col_cells.append(cell)
+
+        word_point = table_words_vertical[i]["points"]
+        word_score = table_words_vertical[i]["score"]
+
+        for cell in col_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                _, y1, _, y2 = intersection
+
+                new_point = [
+                    [word_point[0][0], max(word_point[0][1], y1)],
+                    [word_point[1][0], max(word_point[1][1], y1)],
+                    [word_point[2][0], min(word_point[2][1], y2)],
+                    [word_point[3][0], min(word_point[3][1], y2)],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _correct_horizontal_word_boxes(
+    overlap_ratios_horizontal, table, table_words_horizontal
+):
+    allocated_rows = [rows.index(max(rows)) for rows in overlap_ratios_horizontal]
+
+    new_points = []
+    new_scores = []
+    for i, row_index in enumerate(allocated_rows):
+        row_cells = []
+        for cell in table.cells:
+            if cell.row <= (row_index + 1) < (cell.row + cell.row_span):
+                row_cells.append(cell)
+
+        word_point = table_words_horizontal[i]["points"]
+        word_score = table_words_horizontal[i]["score"]
+
+        for cell in row_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                x1, _, x2, _ = intersection
+
+                new_point = [
+                    [max(word_point[0][0], x1), word_point[0][1]],
+                    [min(word_point[1][0], x2), word_point[1][1]],
+                    [min(word_point[2][0], x2), word_point[2][1]],
+                    [max(word_point[3][0], x1), word_point[3][1]],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _split_text_across_cells(results_det, results_layout):
+    check_list = [False] * len(results_det.points)
+    new_points = []
+    new_scores = []
+    for table in results_layout.tables:
+        table_words_horizontal, table_words_vertical, check_list = (
+            _extract_words_within_table(results_det, table, check_list)
+        )
+
+        overlap_ratios_horizontal = _calc_overlap_words_on_lines(
+            table.rows,
+            table_words_horizontal,
+        )
+
+        overlap_ratios_vertical = _calc_overlap_words_on_lines(
+            table.cols,
+            table_words_vertical,
+        )
+
+        new_points_horizontal, new_scores_horizontal = _correct_horizontal_word_boxes(
+            overlap_ratios_horizontal, table, table_words_horizontal
+        )
+
+        new_points_vertical, new_scores_vertical = _correct_vertical_word_boxes(
+            overlap_ratios_vertical, table, table_words_vertical
+        )
+
+        new_points.extend(new_points_horizontal)
+        new_scores.extend(new_scores_horizontal)
+        new_points.extend(new_points_vertical)
+        new_scores.extend(new_scores_vertical)
+
+    for i, flag in enumerate(check_list):
+        if not flag:
+            new_points.append(results_det.points[i])
+            new_scores.append(results_det.scores[i])
+
+    results_det.points = new_points
+    results_det.scores = new_scores
+
+    return results_det
+
+
 class DocumentAnalyzer:
-    def __init__(self, configs=
+    def __init__(self, configs={}, device="cuda", visualize=False):
         default_configs = {
             "ocr": {
                 "text_detector": {
@@ -180,8 +360,16 @@ class DocumentAnalyzer:
                 "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
             )
 
-        self.
-
+        self.text_detector = TextDetector(
+            **default_configs["ocr"]["text_detector"],
+        )
+        self.text_recognizer = TextRecognizer(
+            **default_configs["ocr"]["text_recognizer"]
+        )
+
+        self.layout = LayoutAnalyzer(
+            configs=default_configs["layout_analyzer"],
+        )
         self.visualize = visualize
 
     def aggregate(self, ocr_res, layout_res):
@@ -286,16 +474,31 @@ class DocumentAnalyzer:
         with ThreadPoolExecutor(max_workers=2) as executor:
             loop = asyncio.get_running_loop()
             tasks = [
-                loop.run_in_executor(executor, self.ocr, img),
+                # loop.run_in_executor(executor, self.ocr, img),
+                loop.run_in_executor(executor, self.text_detector, img),
                 loop.run_in_executor(executor, self.layout, img),
             ]
 
             results = await asyncio.gather(*tasks)
 
-
+            results_det, _ = results[0]
             results_layout, layout = results[1]
 
-
+            results_det = _split_text_across_cells(results_det, results_layout)
+
+            vis_det = None
+            if self.visualize:
+                vis_det = det_visualizer(
+                    img,
+                    results_det.points,
+                )
+
+            results_rec, ocr = self.text_recognizer(img, results_det.points, vis_det)
+
+            outputs = {"words": ocr_aggregate(results_det, results_rec)}
+            results_ocr = OCRSchema(**outputs)
+            outputs = self.aggregate(results_ocr, results_layout)
+
         results = DocumentAnalyzerSchema(**outputs)
         return results, ocr, layout
 
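The analyzer pipeline changes shape here: detection and layout analysis run concurrently, detected boxes that span multiple table cells are split by `_split_text_across_cells`, and recognition runs once on the corrected boxes. To make the new geometry helpers concrete, here is a standalone trace of the `is_vertical` and `is_noise` logic on a synthetic quad (clockwise from top-left):

```python
import numpy as np

# Synthetic detected quad, clockwise from top-left: 10px wide, 40px tall.
quad = np.array([[0, 0], [10, 0], [10, 40], [0, 40]])

width = np.linalg.norm(quad[0] - quad[1])   # 10.0
height = np.linalg.norm(quad[1] - quad[2])  # 40.0

print(height > width * 2)         # True -> is_vertical: treated as vertical text
print(width < 15 or height < 15)  # True -> is_noise: a clipped box this thin is discarded
```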
yomitoku/export/export_csv.py
CHANGED
@@ -1,4 +1,6 @@
 import csv
+import cv2
+import os
 
 
 def table_to_csv(table, ignore_line_break):
@@ -33,7 +35,34 @@ def paragraph_to_csv(paragraph, ignore_line_break):
     return contents
 
 
-def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def export_csv(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure: bool = True,
+    figure_dir="figures",
+):
     elements = []
     for table in inputs.tables:
         table_csv = table_to_csv(table, ignore_line_break)
@@ -58,9 +87,17 @@ def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
         }
     )
 
+    if export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
     elements = sorted(elements, key=lambda x: x["order"])
 
-    with open(out_path, "w", newline="", encoding="utf-8") as f:
+    with open(out_path, "w", newline="", encoding=encoding, errors="ignore") as f:
         writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
         for element in elements:
             if element["type"] == "table":
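`save_figure` crops each figure's bounding box out of the page image and writes it under a `figure_dir` subdirectory next to the output file, named after that file. The path arithmetic, traced by hand:

```python
import os

out_path = "results/demo.csv"
figure_dir = "figures"

# Same steps as save_figure above: sibling directory, stem-based filename.
save_dir = os.path.join(os.path.dirname(out_path), figure_dir)
filename = os.path.splitext(os.path.basename(out_path))[0]

print(os.path.join(save_dir, f"{filename}_figure_0.png"))
# results/figures/demo_figure_0.png
```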
yomitoku/export/export_html.py
CHANGED
@@ -154,6 +154,7 @@ def export_html(
     img=None,
     figure_width=200,
     figure_dir="figures",
+    encoding: str = "utf-8",
 ):
     html_string = ""
     elements = []
@@ -184,5 +185,5 @@ def export_html(
     parsed_html = html.fromstring(html_string)
     formatted_html = etree.tostring(parsed_html, pretty_print=True, encoding="unicode")
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(formatted_html)
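All of the export writers now pass `errors="ignore"`, matching the CLI help text: characters the chosen codec cannot represent are silently dropped instead of raising `UnicodeEncodeError`. A quick illustration:

```python
text = "カタカナ 🙂"

# shift-jis has no mapping for the emoji; with errors="ignore" it is
# dropped rather than raising UnicodeEncodeError on write.
print(text.encode("shift-jis", errors="ignore").decode("shift-jis"))
# -> "カタカナ "
```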
yomitoku/export/export_json.py
CHANGED
@@ -1,5 +1,8 @@
 import json
 
+import cv2
+import os
+
 
 def paragraph_to_json(paragraph, ignore_line_break):
     if ignore_line_break:
@@ -12,7 +15,34 @@ def table_to_json(table, ignore_line_break):
             cell.contents = cell.contents.replace("\n", "")
 
 
-def export_json(inputs, out_path, ignore_line_break=False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def export_json(
+    inputs,
+    out_path,
+    ignore_line_break=False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure=False,
+    figure_dir="figures",
+):
     from yomitoku.document_analyzer import DocumentAnalyzerSchema
 
     if isinstance(inputs, DocumentAnalyzerSchema):
@@ -23,7 +53,15 @@ def export_json(inputs, out_path, ignore_line_break=False):
     for paragraph in inputs.paragraphs:
         paragraph_to_json(paragraph, ignore_line_break)
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    if export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         json.dump(
             inputs.model_dump(),
             f,
yomitoku/export/export_markdown.py
CHANGED
@@ -117,6 +117,7 @@ def export_markdown(
     export_figure=True,
     figure_width=200,
     figure_dir="figures",
+    encoding: str = "utf-8",
 ):
     elements = []
     for table in inputs.tables:
@@ -141,5 +142,5 @@ def export_markdown(
     elements = sorted(elements, key=lambda x: x["order"])
     markdown = "\n".join([element["md"] for element in elements])
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(markdown)
yomitoku/layout_analyzer.py
CHANGED
@@ -15,7 +15,7 @@ class LayoutAnalyzerSchema(BaseSchema):
 
 
 class LayoutAnalyzer:
-    def __init__(self, configs=
+    def __init__(self, configs={}, device="cuda", visualize=False):
         layout_parser_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -26,10 +26,6 @@ class LayoutAnalyzer:
         }
 
         if isinstance(configs, dict):
-            assert (
-                "layout_parser" in configs or "table_structure_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "layout_parser" in configs:
                 layout_parser_kwargs.update(configs["layout_parser"])
 
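Dropping the assert has one behavioral consequence: an empty config dict, which is now also the default, is accepted without complaint. A minimal sketch (constructing the analyzer will still fetch or load its default model weights):

```python
from yomitoku.layout_analyzer import LayoutAnalyzer

# Before 0.7.0 this asserted unless "layout_parser" or
# "table_structure_recognizer" was present in configs.
analyzer = LayoutAnalyzer(configs={}, device="cpu", visualize=False)
```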