yomitoku-0.4.0.post1.dev0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/__init__.py +20 -0
- yomitoku/base.py +136 -0
- yomitoku/cli/__init__.py +0 -0
- yomitoku/cli/main.py +230 -0
- yomitoku/configs/__init__.py +13 -0
- yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
- yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
- yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
- yomitoku/constants.py +32 -0
- yomitoku/data/__init__.py +3 -0
- yomitoku/data/dataset.py +40 -0
- yomitoku/data/functions.py +279 -0
- yomitoku/document_analyzer.py +315 -0
- yomitoku/export/__init__.py +6 -0
- yomitoku/export/export_csv.py +71 -0
- yomitoku/export/export_html.py +188 -0
- yomitoku/export/export_json.py +34 -0
- yomitoku/export/export_markdown.py +145 -0
- yomitoku/layout_analyzer.py +66 -0
- yomitoku/layout_parser.py +189 -0
- yomitoku/models/__init__.py +9 -0
- yomitoku/models/dbnet_plus.py +272 -0
- yomitoku/models/layers/__init__.py +0 -0
- yomitoku/models/layers/activate.py +38 -0
- yomitoku/models/layers/dbnet_feature_attention.py +160 -0
- yomitoku/models/layers/parseq_transformer.py +218 -0
- yomitoku/models/layers/rtdetr_backbone.py +333 -0
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
- yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
- yomitoku/models/parseq.py +243 -0
- yomitoku/models/rtdetr.py +22 -0
- yomitoku/ocr.py +87 -0
- yomitoku/postprocessor/__init__.py +9 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
- yomitoku/postprocessor/parseq_tokenizer.py +128 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
- yomitoku/reading_order.py +214 -0
- yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
- yomitoku/resource/charset.txt +1 -0
- yomitoku/table_structure_recognizer.py +244 -0
- yomitoku/text_detector.py +103 -0
- yomitoku/text_recognizer.py +128 -0
- yomitoku/utils/__init__.py +0 -0
- yomitoku/utils/graph.py +20 -0
- yomitoku/utils/logger.py +15 -0
- yomitoku/utils/misc.py +102 -0
- yomitoku/utils/visualizer.py +179 -0
- yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
- yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
- yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
- yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/export/export_csv.py
@@ -0,0 +1,71 @@
import csv


def table_to_csv(table, ignore_line_break):
    num_rows = table.n_row
    num_cols = table.n_col

    table_array = [["" for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in table.cells:
        row = cell.row - 1
        col = cell.col - 1
        row_span = cell.row_span
        col_span = cell.col_span
        contents = cell.contents

        if ignore_line_break:
            contents = contents.replace("\n", "")

        for i in range(row, row + row_span):
            for j in range(col, col + col_span):
                if i == row and j == col:
                    table_array[i][j] = contents
    return table_array


def paragraph_to_csv(paragraph, ignore_line_break):
    contents = paragraph.contents

    if ignore_line_break:
        contents = contents.replace("\n", "")

    return contents


def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
    elements = []
    for table in inputs.tables:
        table_csv = table_to_csv(table, ignore_line_break)

        elements.append(
            {
                "type": "table",
                "box": table.box,
                "element": table_csv,
                "order": table.order,
            }
        )

    for paragraph in inputs.paragraphs:
        contents = paragraph_to_csv(paragraph, ignore_line_break)
        elements.append(
            {
                "type": "paragraph",
                "box": paragraph.box,
                "element": contents,
                "order": paragraph.order,
            }
        )

    elements = sorted(elements, key=lambda x: x["order"])

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        for element in elements:
            if element["type"] == "table":
                writer.writerows(element["element"])
            else:
                writer.writerow([element["element"]])

            writer.writerow([""])
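A minimal usage sketch for the exporter above (not part of the package diff). It assumes DocumentAnalyzer from yomitoku/document_analyzer.py is re-exported by yomitoku/__init__.py and that calling it on an image returns its result schema first, followed by visualization images; adjust the import and the unpacking if the actual API differs.

import cv2

from yomitoku import DocumentAnalyzer  # assumed re-export; otherwise import from yomitoku.document_analyzer
from yomitoku.export.export_csv import export_csv

# Load a BGR image, which is what the modules in this package slice and convert.
img = cv2.imread("sample.jpg")

analyzer = DocumentAnalyzer(device="cpu", visualize=False)
results, *_ = analyzer(img)  # assumed: result schema first, visualizations after

# results.tables / results.paragraphs carry the order, box and contents fields
# that export_csv reads; line breaks inside cells are dropped here.
export_csv(results, "output.csv", ignore_line_break=True)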
yomitoku/export/export_html.py
@@ -0,0 +1,188 @@
import re
import os
import cv2

from html import escape

from lxml import etree, html


def convert_text_to_html(text):
    """
    Convert the input text to HTML.
    URLs are detected and kept as-is (not turned into links); everything else is HTML-escaped.
    """
    url_regex = re.compile(r"https?://[^\s<>]")

    def replace_url(match):
        url = match.group(0)
        return escape(url)

    return url_regex.sub(replace_url, escape(text))


def add_td_tag(contents, row_span, col_span):
    return f'<td rowspan="{row_span}" colspan="{col_span}">{contents}</td>'


def add_table_tag(contents):
    return f'<table border="1" style="border-collapse: collapse">{contents}</table>'


def add_tr_tag(contents):
    return f"<tr>{contents}</tr>"


def add_p_tag(contents):
    return f"<p>{contents}</p>"


def add_html_tag(text):
    return f"<html><body>{text}</body></html>"


def add_h1_tag(contents):
    return f"<h1>{contents}</h1>"


def table_to_html(table, ignore_line_break):
    pre_row = 1
    rows = []
    row = []
    for cell in table.cells:
        if cell.row != pre_row:
            rows.append(add_tr_tag("".join(row)))
            row = []

        row_span = cell.row_span
        col_span = cell.col_span
        contents = cell.contents

        if contents is None:
            contents = ""

        contents = convert_text_to_html(contents)

        if ignore_line_break:
            contents = contents.replace("\n", "")
        else:
            contents = contents.replace("\n", "<br>")

        row.append(add_td_tag(contents, row_span, col_span))
        pre_row = cell.row
    else:
        rows.append(add_tr_tag("".join(row)))

    table_html = add_table_tag("".join(rows))

    return {
        "box": table.box,
        "order": table.order,
        "html": table_html,
    }


def paragraph_to_html(paragraph, ignore_line_break):
    contents = paragraph.contents
    contents = convert_text_to_html(contents)

    if ignore_line_break:
        contents = contents.replace("\n", "")
    else:
        contents = contents.replace("\n", "<br>")

    if paragraph.role == "section_headings":
        contents = add_h1_tag(contents)

    return {
        "box": paragraph.box,
        "order": paragraph.order,
        "html": add_p_tag(contents),
    }


def figure_to_html(
    figures,
    img,
    out_path,
    export_figure_letter=False,
    ignore_line_break=False,
    figure_dir="figures",
    width=200,
):
    elements = []
    for i, figure in enumerate(figures):
        x1, y1, x2, y2 = map(int, figure.box)
        figure_img = img[y1:y2, x1:x2, :]
        save_dir = os.path.dirname(out_path)
        save_dir = os.path.join(save_dir, figure_dir)
        os.makedirs(save_dir, exist_ok=True)

        filename = os.path.splitext(os.path.basename(out_path))[0]
        figure_name = f"{filename}_figure_{i}.png"
        figure_path = os.path.join(save_dir, figure_name)
        cv2.imwrite(figure_path, figure_img)

        elements.append(
            {
                "order": figure.order,
                "html": f'<img src="{figure_dir}/{figure_name}" width="{width}"><br>',
            }
        )

        if export_figure_letter:
            paragraphs = sorted(figure.paragraphs, key=lambda x: x.order)
            for paragraph in paragraphs:
                contents = paragraph_to_html(paragraph, ignore_line_break)
                html = contents["html"]
                elements.append(
                    {
                        "order": figure.order,
                        "html": html,
                    }
                )

    return elements


def export_html(
    inputs,
    out_path: str,
    ignore_line_break: bool = False,
    export_figure: bool = True,
    export_figure_letter: bool = False,
    img=None,
    figure_width=200,
    figure_dir="figures",
):
    html_string = ""
    elements = []
    for table in inputs.tables:
        elements.append(table_to_html(table, ignore_line_break))

    for paragraph in inputs.paragraphs:
        elements.append(paragraph_to_html(paragraph, ignore_line_break))

    if export_figure:
        elements.extend(
            figure_to_html(
                inputs.figures,
                img,
                out_path,
                export_figure_letter,
                ignore_line_break,
                width=figure_width,
                figure_dir=figure_dir,
            ),
        )

    elements = sorted(elements, key=lambda x: x["order"])

    html_string = "".join([element["html"] for element in elements])
    html_string = add_html_tag(html_string)

    parsed_html = html.fromstring(html_string)
    formatted_html = etree.tostring(parsed_html, pretty_print=True, encoding="unicode")

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(formatted_html)
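To illustrate the table rendering above, a small stand-in example (not from the package): the cell and table objects are faked with types.SimpleNamespace, carrying only the attributes table_to_html reads, to show the rowspan/colspan output and the for/else that flushes the final row.

from types import SimpleNamespace

from yomitoku.export.export_html import table_to_html

cells = [
    SimpleNamespace(row=1, col=1, row_span=1, col_span=2, contents="Header"),
    SimpleNamespace(row=2, col=1, row_span=1, col_span=1, contents="A"),
    SimpleNamespace(row=2, col=2, row_span=1, col_span=1, contents="B"),
]
table = SimpleNamespace(cells=cells, box=[0, 0, 100, 50], order=1)

print(table_to_html(table, ignore_line_break=True)["html"])
# Single string; wrapped here for readability:
# <table border="1" style="border-collapse: collapse">
#   <tr><td rowspan="1" colspan="2">Header</td></tr>
#   <tr><td rowspan="1" colspan="1">A</td><td rowspan="1" colspan="1">B</td></tr>
# </table>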
yomitoku/export/export_json.py
@@ -0,0 +1,34 @@
import json


def paragraph_to_json(paragraph, ignore_line_break):
    if ignore_line_break:
        paragraph.contents = paragraph.contents.replace("\n", "")


def table_to_json(table, ignore_line_break):
    for cell in table.cells:
        if ignore_line_break:
            cell.contents = cell.contents.replace("\n", "")


def export_json(inputs, out_path, ignore_line_break=False):
    from yomitoku.document_analyzer import DocumentAnalyzerSchema

    if isinstance(inputs, DocumentAnalyzerSchema):
        for table in inputs.tables:
            table_to_json(table, ignore_line_break)

    if isinstance(inputs, DocumentAnalyzerSchema):
        for paragraph in inputs.paragraphs:
            paragraph_to_json(paragraph, ignore_line_break)

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(
            inputs.model_dump(),
            f,
            ensure_ascii=False,
            indent=4,
            sort_keys=True,
            separators=(",", ": "),
        )
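Because export_json writes with ensure_ascii=False, sorted keys, and UTF-8, the output round-trips directly with the standard json module. A short sketch, reusing the hypothetical results object from the export_csv example above:

import json

from yomitoku.export.export_json import export_json

export_json(results, "output.json", ignore_line_break=True)

with open("output.json", encoding="utf-8") as f:
    data = json.load(f)

# Top-level keys mirror the schema fields (paragraphs, tables, figures, ...);
# Japanese text is stored as-is rather than \u-escaped.
print(sorted(data.keys()))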
yomitoku/export/export_markdown.py
@@ -0,0 +1,145 @@
import re
import cv2
import os


def escape_markdown_special_chars(text):
    special_chars = r"([`*_{}[\]()#+.!|-])"
    return re.sub(special_chars, r"\\\1", text)


def paragraph_to_md(paragraph, ignore_line_break):
    contents = escape_markdown_special_chars(paragraph.contents)

    if ignore_line_break:
        contents = contents.replace("\n", "")
    else:
        contents = contents.replace("\n", "<br>")

    if paragraph.role == "section_headings":
        contents = "# " + contents

    return {
        "order": paragraph.order,
        "box": paragraph.box,
        "md": contents + "\n",
    }


def table_to_md(table, ignore_line_break):
    num_rows = table.n_row
    num_cols = table.n_col

    table_array = [["" for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in table.cells:
        row = cell.row - 1
        col = cell.col - 1
        row_span = cell.row_span
        col_span = cell.col_span
        contents = cell.contents

        for i in range(row, row + row_span):
            for j in range(col, col + col_span):
                contents = escape_markdown_special_chars(contents)
                if ignore_line_break:
                    contents = contents.replace("\n", "")
                else:
                    contents = contents.replace("\n", "<br>")

                if i == row and j == col:
                    table_array[i][j] = contents

    table_md = ""
    for i in range(num_rows):
        row = "|".join(table_array[i])
        table_md += f"|{row}|\n"

        if i == 0:
            header = "|".join(["-" for _ in range(num_cols)])
            table_md += f"|{header}|\n"

    return {
        "order": table.order,
        "box": table.box,
        "md": table_md,
    }


def figure_to_md(
    figures,
    img,
    out_path,
    export_figure_letter=False,
    ignore_line_break=False,
    width=200,
    figure_dir="figures",
):
    elements = []
    for i, figure in enumerate(figures):
        x1, y1, x2, y2 = map(int, figure.box)
        figure_img = img[y1:y2, x1:x2, :]
        save_dir = os.path.dirname(out_path)
        save_dir = os.path.join(save_dir, figure_dir)
        os.makedirs(save_dir, exist_ok=True)

        filename = os.path.splitext(os.path.basename(out_path))[0]
        figure_name = f"{filename}_figure_{i}.png"
        figure_path = os.path.join(save_dir, figure_name)
        cv2.imwrite(figure_path, figure_img)

        elements.append(
            {
                "order": figure.order,
                "md": f'<img src="{figure_dir}/{figure_name}" width="{width}px"><br>',
            }
        )

        if export_figure_letter:
            paragraphs = sorted(figure.paragraphs, key=lambda x: x.order)
            for paragraph in paragraphs:
                element = paragraph_to_md(paragraph, ignore_line_break)
                element = {
                    "order": figure.order,
                    "md": element["md"],
                }
                elements.append(element)

    return elements


def export_markdown(
    inputs,
    out_path: str,
    img=None,
    ignore_line_break: bool = False,
    export_figure_letter=False,
    export_figure=True,
    figure_width=200,
    figure_dir="figures",
):
    elements = []
    for table in inputs.tables:
        elements.append(table_to_md(table, ignore_line_break))

    for paragraph in inputs.paragraphs:
        elements.append(paragraph_to_md(paragraph, ignore_line_break))

    if export_figure:
        elements.extend(
            figure_to_md(
                inputs.figures,
                img,
                out_path,
                export_figure_letter,
                ignore_line_break,
                figure_width,
                figure_dir=figure_dir,
            )
        )

    elements = sorted(elements, key=lambda x: x["order"])
    markdown = "\n".join([element["md"] for element in elements])

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(markdown)
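A quick stand-in illustration of the Markdown escaping and heading handling above (not from the package; the paragraph object is faked with types.SimpleNamespace using only the attributes paragraph_to_md reads):

from types import SimpleNamespace

from yomitoku.export.export_markdown import (
    escape_markdown_special_chars,
    paragraph_to_md,
)

print(escape_markdown_special_chars("Price (USD) | Qty #1."))
# Price \(USD\) \| Qty \#1\.

heading = SimpleNamespace(
    contents="1. Introduction", role="section_headings", order=0, box=[0, 0, 10, 10]
)
print(paragraph_to_md(heading, ignore_line_break=True)["md"])
# # 1\. Introduction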
yomitoku/layout_analyzer.py
@@ -0,0 +1,66 @@
from typing import List

from .base import BaseSchema
from .layout_parser import Element, LayoutParser
from .table_structure_recognizer import (
    TableStructureRecognizer,
    TableStructureRecognizerSchema,
)


class LayoutAnalyzerSchema(BaseSchema):
    paragraphs: List[Element]
    tables: List[TableStructureRecognizerSchema]
    figures: List[Element]


class LayoutAnalyzer:
    def __init__(self, configs=None, device="cuda", visualize=False):
        layout_parser_kwargs = {
            "device": device,
            "visualize": visualize,
        }
        table_structure_recognizer_kwargs = {
            "device": device,
            "visualize": visualize,
        }

        if isinstance(configs, dict):
            assert (
                "layout_parser" in configs
                or "table_structure_recognizer" in configs
            ), "Invalid config key. Please check the config keys."

            if "layout_parser" in configs:
                layout_parser_kwargs.update(configs["layout_parser"])

            if "table_structure_recognizer" in configs:
                table_structure_recognizer_kwargs.update(
                    configs["table_structure_recognizer"]
                )
        else:
            raise ValueError(
                "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
            )

        self.layout_parser = LayoutParser(
            **layout_parser_kwargs,
        )
        self.table_structure_recognizer = TableStructureRecognizer(
            **table_structure_recognizer_kwargs,
        )

    def __call__(self, img):
        layout_results, vis = self.layout_parser(img)
        table_boxes = [table.box for table in layout_results.tables]
        table_results, vis = self.table_structure_recognizer(
            img, table_boxes, vis=vis
        )

        results = LayoutAnalyzerSchema(
            paragraphs=layout_results.paragraphs,
            tables=table_results,
            figures=layout_results.figures,
        )

        return results, vis
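A minimal usage sketch for LayoutAnalyzer (hypothetical file paths, not part of the diff). The configs dict shape follows the key check in __init__ above: it is keyed by "layout_parser" and/or "table_structure_recognizer", each mapping to keyword overrides for the corresponding module.

import cv2

from yomitoku.layout_analyzer import LayoutAnalyzer

img = cv2.imread("page.jpg")  # BGR image

analyzer = LayoutAnalyzer(
    configs={"layout_parser": {}},  # empty override; see the key check above
    device="cpu",
    visualize=True,
)

results, vis = analyzer(img)
print(len(results.paragraphs), len(results.tables), len(results.figures))
if vis is not None:
    cv2.imwrite("layout_vis.jpg", vis)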
yomitoku/layout_parser.py
@@ -0,0 +1,189 @@
from typing import List, Union

import cv2
import torch
import torchvision.transforms as T
from PIL import Image
from pydantic import conlist

from .base import BaseModelCatalog, BaseModule, BaseSchema
from .configs import LayoutParserRTDETRv2Config
from .models import RTDETRv2
from .postprocessor import RTDETRPostProcessor
from .utils.misc import filter_by_flag, is_contained
from .utils.visualizer import layout_visualizer


class Element(BaseSchema):
    box: conlist(int, min_length=4, max_length=4)
    score: float
    role: Union[str, None]


class LayoutParserSchema(BaseSchema):
    paragraphs: List[Element]
    tables: List[Element]
    figures: List[Element]


class LayoutParserModelCatalog(BaseModelCatalog):
    def __init__(self):
        super().__init__()
        self.register("rtdetrv2", LayoutParserRTDETRv2Config, RTDETRv2)


def filter_contained_rectangles_within_category(category_elements):
    """Within each category, exclude rectangles that are contained inside another rectangle of the same category."""

    for category, elements in category_elements.items():
        group_box = [element["box"] for element in elements]
        check_list = [True] * len(group_box)
        for i, box_i in enumerate(group_box):
            for j, box_j in enumerate(group_box):
                if i >= j:
                    continue

                ij = is_contained(box_i, box_j)
                ji = is_contained(box_j, box_i)

                box_i_area = (box_i[2] - box_i[0]) * (box_i[3] - box_i[1])
                box_j_area = (box_j[2] - box_j[0]) * (box_j[3] - box_j[1])

                # If the two boxes contain each other, keep the one with the larger area.
                if ij and ji:
                    if box_i_area > box_j_area:
                        check_list[j] = False
                    else:
                        check_list[i] = False
                elif ij:
                    check_list[j] = False
                elif ji:
                    check_list[i] = False

        category_elements[category] = filter_by_flag(elements, check_list)

    return category_elements


def filter_contained_rectangles_across_categories(category_elements, source, target):
    """Exclude rectangles in the target category that are contained inside a rectangle of the source category."""

    src_boxes = [element["box"] for element in category_elements[source]]
    tgt_boxes = [element["box"] for element in category_elements[target]]

    check_list = [True] * len(tgt_boxes)
    for i, src_box in enumerate(src_boxes):
        for j, tgt_box in enumerate(tgt_boxes):
            if is_contained(src_box, tgt_box):
                check_list[j] = False

    category_elements[target] = filter_by_flag(category_elements[target], check_list)
    return category_elements


class LayoutParser(BaseModule):
    model_catalog = LayoutParserModelCatalog()

    def __init__(
        self,
        model_name="rtdetrv2",
        path_cfg=None,
        device="cuda",
        visualize=False,
        from_pretrained=True,
    ):
        super().__init__()
        self.load_model(model_name, path_cfg, from_pretrained)
        self.device = device
        self.visualize = visualize

        self.model.eval()
        self.model.to(self.device)

        self.postprocessor = RTDETRPostProcessor(
            num_classes=self._cfg.RTDETRTransformerv2.num_classes,
            num_top_queries=self._cfg.RTDETRTransformerv2.num_queries,
        )

        self.transforms = T.Compose(
            [
                T.Resize(self._cfg.data.img_size),
                T.ToTensor(),
            ]
        )

        self.thresh_score = self._cfg.thresh_score

        self.label_mapper = {
            id: category for id, category in enumerate(self._cfg.category)
        }

        self.role = self._cfg.role

    def preprocess(self, img):
        cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(cv_img)
        img_tensor = self.transforms(img)[None].to(self.device)
        return img_tensor

    def postprocess(self, preds, image_size):
        h, w = image_size
        orig_size = torch.tensor([w, h])[None].to(self.device)
        outputs = self.postprocessor(preds, orig_size, self.thresh_score)
        outputs = self.filtering_elements(outputs[0])
        results = LayoutParserSchema(**outputs)
        return results

    def filtering_elements(self, preds):
        scores = preds["scores"]
        boxes = preds["boxes"]
        labels = preds["labels"]

        category_elements = {
            category: []
            for category in self.label_mapper.values()
            if category not in self.role
        }

        for box, score, label in zip(boxes, scores, labels):
            category = self.label_mapper[label.item()]

            role = None
            if category in self.role:
                role = category
                category = "paragraphs"

            category_elements[category].append(
                {
                    "box": box.astype(int).tolist(),
                    "score": float(score),
                    "role": role,
                }
            )

        category_elements = filter_contained_rectangles_within_category(
            category_elements
        )

        category_elements = filter_contained_rectangles_across_categories(
            category_elements, "tables", "paragraphs"
        )

        return category_elements

    def __call__(self, img):
        ori_h, ori_w = img.shape[:2]
        img_tensor = self.preprocess(img)

        with torch.inference_mode():
            preds = self.model(img_tensor)
            results = self.postprocess(preds, (ori_h, ori_w))

        vis = None
        if self.visualize:
            vis = layout_visualizer(
                results,
                img,
            )

        return results, vis
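And a corresponding sketch for running LayoutParser on its own (hypothetical image path; the element fields come from the Element schema above):

import cv2

from yomitoku.layout_parser import LayoutParser

img = cv2.imread("page.jpg")  # BGR; preprocess() converts to RGB

parser = LayoutParser(device="cpu", visualize=True)
results, vis = parser(img)

# Each entry is an Element with a 4-int box, a score, and an optional role
# (role categories are folded into "paragraphs" by filtering_elements above).
for element in results.paragraphs:
    print(element.box, round(element.score, 3), element.role)

if vis is not None:
    cv2.imwrite("layout_parser_vis.jpg", vis)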