yomitoku 0.4.0.post1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. yomitoku/__init__.py +20 -0
  2. yomitoku/base.py +136 -0
  3. yomitoku/cli/__init__.py +0 -0
  4. yomitoku/cli/main.py +230 -0
  5. yomitoku/configs/__init__.py +13 -0
  6. yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
  7. yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
  8. yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
  9. yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
  10. yomitoku/constants.py +32 -0
  11. yomitoku/data/__init__.py +3 -0
  12. yomitoku/data/dataset.py +40 -0
  13. yomitoku/data/functions.py +279 -0
  14. yomitoku/document_analyzer.py +315 -0
  15. yomitoku/export/__init__.py +6 -0
  16. yomitoku/export/export_csv.py +71 -0
  17. yomitoku/export/export_html.py +188 -0
  18. yomitoku/export/export_json.py +34 -0
  19. yomitoku/export/export_markdown.py +145 -0
  20. yomitoku/layout_analyzer.py +66 -0
  21. yomitoku/layout_parser.py +189 -0
  22. yomitoku/models/__init__.py +9 -0
  23. yomitoku/models/dbnet_plus.py +272 -0
  24. yomitoku/models/layers/__init__.py +0 -0
  25. yomitoku/models/layers/activate.py +38 -0
  26. yomitoku/models/layers/dbnet_feature_attention.py +160 -0
  27. yomitoku/models/layers/parseq_transformer.py +218 -0
  28. yomitoku/models/layers/rtdetr_backbone.py +333 -0
  29. yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
  30. yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
  31. yomitoku/models/parseq.py +243 -0
  32. yomitoku/models/rtdetr.py +22 -0
  33. yomitoku/ocr.py +87 -0
  34. yomitoku/postprocessor/__init__.py +9 -0
  35. yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
  36. yomitoku/postprocessor/parseq_tokenizer.py +128 -0
  37. yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
  38. yomitoku/reading_order.py +214 -0
  39. yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
  40. yomitoku/resource/charset.txt +1 -0
  41. yomitoku/table_structure_recognizer.py +244 -0
  42. yomitoku/text_detector.py +103 -0
  43. yomitoku/text_recognizer.py +128 -0
  44. yomitoku/utils/__init__.py +0 -0
  45. yomitoku/utils/graph.py +20 -0
  46. yomitoku/utils/logger.py +15 -0
  47. yomitoku/utils/misc.py +102 -0
  48. yomitoku/utils/visualizer.py +179 -0
  49. yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
  50. yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
  51. yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
  52. yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,71 @@
1
+ import csv
2
+
3
+
4
def table_to_csv(table, ignore_line_break):
    """Lay the cells of *table* out on a 2-D grid suitable for ``csv.writer``.

    Args:
        table: Object exposing ``n_row``, ``n_col`` and ``cells``; every cell
            exposes 1-based ``row``/``col`` plus ``row_span``, ``col_span``
            and ``contents``.
        ignore_line_break: When true, newline characters are removed from the
            cell text.

    Returns:
        A list of ``n_row`` rows, each a list of ``n_col`` strings. A cell's
        text is stored only at its top-left position; every other position,
        including the remainder of a spanned area, stays ``""``.
    """
    grid = [["" for _ in range(table.n_col)] for _ in range(table.n_row)]

    for cell in table.cells:
        # A cell whose span is not positive covers no grid position, so
        # there is nothing to write for it.
        if cell.row_span <= 0 or cell.col_span <= 0:
            continue

        # Cell text may be missing; treat it as an empty cell instead of
        # failing on ``None.replace``.
        contents = "" if cell.contents is None else cell.contents
        if ignore_line_break:
            contents = contents.replace("\n", "")

        # Only the top-left position of a (possibly spanning) cell carries
        # the text; row/col are 1-based.
        grid[cell.row - 1][cell.col - 1] = contents

    return grid
25
+
26
+
27
def paragraph_to_csv(paragraph, ignore_line_break):
    """Return the paragraph text, with newlines removed when requested.

    Args:
        paragraph: Object exposing a ``contents`` string.
        ignore_line_break: When true, ``"\\n"`` characters are stripped.
    """
    text = paragraph.contents
    return text.replace("\n", "") if ignore_line_break else text
34
+
35
+
36
def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
    """Write the tables and paragraphs of *inputs* to a CSV file.

    Elements are emitted in ascending ``order``. A table contributes one CSV
    row per table row, a paragraph contributes a single one-column row, and
    each element is followed by a row holding one empty field.

    Args:
        inputs: Object exposing ``tables`` and ``paragraphs``; every item
            exposes ``box`` and ``order``.
        out_path: Destination file, written as UTF-8.
        ignore_line_break: Forwarded to the per-element converters.
    """
    table_entries = [
        {
            "type": "table",
            "box": table.box,
            "element": table_to_csv(table, ignore_line_break),
            "order": table.order,
        }
        for table in inputs.tables
    ]
    paragraph_entries = [
        {
            "type": "paragraph",
            "box": paragraph.box,
            "element": paragraph_to_csv(paragraph, ignore_line_break),
            "order": paragraph.order,
        }
        for paragraph in inputs.paragraphs
    ]

    # Stable sort: on equal order, tables stay ahead of paragraphs.
    ordered = sorted(
        table_entries + paragraph_entries,
        key=lambda entry: entry["order"],
    )

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        for entry in ordered:
            if entry["type"] == "table":
                rows = entry["element"]
            else:
                rows = [[entry["element"]]]
            writer.writerows(rows)

            # Separator row between consecutive elements.
            writer.writerow([""])
@@ -0,0 +1,188 @@
1
+ import re
2
+ import os
3
+ import cv2
4
+
5
+ from html import escape
6
+
7
+ from lxml import etree, html
8
+
9
+
10
def convert_text_to_html(text):
    """Convert plain text into HTML-safe text.

    URLs are not turned into links: they are emitted as ordinary text, and
    the whole input is HTML-escaped exactly once.

    Args:
        text: Plain text to embed in an HTML document.

    Returns:
        The escaped text.
    """
    # The earlier implementation escaped the whole text and then escaped
    # every ``https?://`` match again. Its pattern lacked a quantifier, so it
    # covered the scheme plus a single character; when that character had
    # already been escaped (e.g. ``&`` or a quote) it was double-encoded, as
    # in ``http://&x`` -> ``http://&amp;amp;x``. A single escape pass gives
    # the intended "show as-is, no link" behaviour.
    return escape(text)
22
+
23
+
24
def add_td_tag(contents, row_span, col_span):
    """Wrap *contents* in a ``<td>`` carrying the given row/column spans."""
    template = '<td rowspan="{}" colspan="{}">{}</td>'
    return template.format(row_span, col_span, contents)
26
+
27
+
28
def add_table_tag(contents):
    """Wrap *contents* in a bordered, border-collapsed ``<table>`` element."""
    opening = '<table border="1" style="border-collapse: collapse">'
    return "{}{}</table>".format(opening, contents)
30
+
31
+
32
def add_tr_tag(contents):
    """Wrap *contents* in a ``<tr>`` element."""
    return "<tr>{}</tr>".format(contents)
34
+
35
+
36
def add_p_tag(contents):
    """Wrap *contents* in a ``<p>`` element."""
    return "<p>{}</p>".format(contents)
38
+
39
+
40
def add_html_tag(text):
    """Wrap *text* in a minimal ``<html><body>`` document skeleton."""
    return "<html><body>{}</body></html>".format(text)
42
+
43
+
44
def add_h1_tag(contents):
    """Wrap *contents* in an ``<h1>`` element."""
    return "<h1>{}</h1>".format(contents)
46
+
47
+
48
def table_to_html(table, ignore_line_break):
    """Render *table* as an HTML ``<table>`` string.

    Cells are consumed in the order given by ``table.cells``; a new ``<tr>``
    is started whenever a cell's ``row`` differs from the previous cell's
    (the first cell is compared against row 1).

    Args:
        table: Object exposing ``cells``, ``box`` and ``order``; every cell
            exposes ``row``, ``row_span``, ``col_span`` and ``contents``.
        ignore_line_break: When true, newlines in the cell text are dropped;
            otherwise they become ``<br>``.

    Returns:
        A dict with the table's ``box``, its ``order`` and the rendered
        ``html``.
    """
    line_break = "" if ignore_line_break else "<br>"

    finished_rows = []
    pending_cells = []
    current_row = 1
    for cell in table.cells:
        if cell.row != current_row:
            # Row changed: close the row collected so far.
            finished_rows.append(add_tr_tag("".join(pending_cells)))
            pending_cells = []
        current_row = cell.row

        text = "" if cell.contents is None else cell.contents
        text = convert_text_to_html(text).replace("\n", line_break)
        pending_cells.append(add_td_tag(text, cell.row_span, cell.col_span))

    # Flush whatever is still pending; this also runs for a table without
    # cells, yielding a single empty row.
    finished_rows.append(add_tr_tag("".join(pending_cells)))

    return {
        "box": table.box,
        "order": table.order,
        "html": add_table_tag("".join(finished_rows)),
    }
83
+
84
+
85
def paragraph_to_html(paragraph, ignore_line_break):
    """Render *paragraph* as an HTML fragment.

    Args:
        paragraph: Object exposing ``contents``, ``role``, ``box`` and
            ``order``.
        ignore_line_break: When true, newlines in the text are dropped;
            otherwise they become ``<br>``.

    Returns:
        A dict with the paragraph's ``box``, its ``order`` and the rendered
        ``html``: an ``<h1>`` element when ``role`` is
        ``"section_headings"``, a ``<p>`` element otherwise.
    """
    contents = convert_text_to_html(paragraph.contents)
    contents = contents.replace("\n", "" if ignore_line_break else "<br>")

    # A heading must not be nested inside <p>: <h1> is not permitted as a
    # child of <p>, so the former "<p><h1>…</h1></p>" output was invalid
    # markup. Emit the heading on its own instead.
    if paragraph.role == "section_headings":
        markup = add_h1_tag(contents)
    else:
        markup = add_p_tag(contents)

    return {
        "box": paragraph.box,
        "order": paragraph.order,
        "html": markup,
    }
102
+
103
+
104
def figure_to_html(
    figures,
    img,
    out_path,
    export_figure_letter=False,
    ignore_line_break=False,
    figure_dir="figures",
    width=200,
):
    """Crop every figure out of *img*, save it as PNG and build ``<img>`` tags.

    The crops are written to ``<directory of out_path>/<figure_dir>/`` and
    named ``<out_path stem>_figure_<index>.png``.

    Args:
        figures: Iterable of objects exposing ``box`` (x1, y1, x2, y2),
            ``order`` and, when *export_figure_letter* is set, ``paragraphs``.
        img: Image indexed as ``img[y, x, channel]``.
        out_path: Path of the HTML file being exported.
        export_figure_letter: When true, the paragraphs attached to each
            figure are rendered right after its image, sorted by their own
            ``order`` but tagged with the figure's ``order``.
        ignore_line_break: Forwarded to ``paragraph_to_html``.
        figure_dir: Sub-directory (relative to the HTML file) for the crops.
        width: Value of the ``width`` attribute on the ``<img>`` tag.

    Returns:
        A list of dicts with ``order`` and ``html`` keys.
    """
    elements = []
    for index, figure in enumerate(figures):
        left, top, right, bottom = (int(value) for value in figure.box)
        crop = img[top:bottom, left:right, :]

        target_dir = os.path.join(os.path.dirname(out_path), figure_dir)
        os.makedirs(target_dir, exist_ok=True)

        stem = os.path.splitext(os.path.basename(out_path))[0]
        figure_name = f"{stem}_figure_{index}.png"
        cv2.imwrite(os.path.join(target_dir, figure_name), crop)

        elements.append(
            {
                "order": figure.order,
                "html": f'<img src="{figure_dir}/{figure_name}" width="{width}"><br>',
            }
        )

        if not export_figure_letter:
            continue

        for paragraph in sorted(figure.paragraphs, key=lambda item: item.order):
            rendered = paragraph_to_html(paragraph, ignore_line_break)
            elements.append(
                {
                    "order": figure.order,
                    "html": rendered["html"],
                }
            )

    return elements
146
+
147
+
148
def export_html(
    inputs,
    out_path: str,
    ignore_line_break: bool = False,
    export_figure: bool = True,
    export_figure_letter: bool = False,
    img=None,
    figure_width=200,
    figure_dir="figures",
):
    """Export the analysed document as a pretty-printed HTML file.

    Tables, paragraphs and (optionally) figures are rendered individually,
    sorted by their ``order`` and concatenated inside ``<html><body>``.

    Args:
        inputs: Object exposing ``tables``, ``paragraphs`` and ``figures``.
        out_path: Destination HTML file, written as UTF-8.
        ignore_line_break: Forwarded to the per-element renderers.
        export_figure: When true, figures are cropped from *img* and linked.
        export_figure_letter: Forwarded to ``figure_to_html``.
        img: Source image the figures are cropped from.
        figure_width: ``width`` attribute for the generated ``<img>`` tags.
        figure_dir: Sub-directory, next to *out_path*, for the figure files.
    """
    elements = [table_to_html(table, ignore_line_break) for table in inputs.tables]
    elements += [
        paragraph_to_html(paragraph, ignore_line_break)
        for paragraph in inputs.paragraphs
    ]

    if export_figure:
        elements += figure_to_html(
            inputs.figures,
            img,
            out_path,
            export_figure_letter,
            ignore_line_break,
            width=figure_width,
            figure_dir=figure_dir,
        )

    # Stable sort keeps the tables/paragraphs/figures insertion order for
    # elements sharing the same order value.
    elements.sort(key=lambda element: element["order"])

    document = add_html_tag("".join(element["html"] for element in elements))

    # Round-trip through lxml to obtain indented output.
    tree = html.fromstring(document)
    pretty = etree.tostring(tree, pretty_print=True, encoding="unicode")

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(pretty)
@@ -0,0 +1,34 @@
1
+ import json
2
+
3
+
4
def paragraph_to_json(paragraph, ignore_line_break):
    """Strip newlines from ``paragraph.contents`` in place when requested.

    The paragraph object itself is modified; nothing is returned.
    """
    if not ignore_line_break:
        return
    paragraph.contents = paragraph.contents.replace("\n", "")
7
+
8
+
9
def table_to_json(table, ignore_line_break):
    """Strip newlines from every cell of *table* in place when requested.

    Args:
        table: Object exposing ``cells``; every cell exposes ``contents``.
        ignore_line_break: When true, ``"\\n"`` is removed from each cell's
            text. When false the table is left untouched.

    The cell objects themselves are modified; nothing is returned.
    """
    if not ignore_line_break:
        return

    for cell in table.cells:
        # Cell text may be missing; leave such cells alone instead of
        # failing on ``None.replace``.
        if cell.contents is not None:
            cell.contents = cell.contents.replace("\n", "")
13
+
14
+
15
def export_json(inputs, out_path, ignore_line_break=False):
    """Dump *inputs* to *out_path* as indented, key-sorted UTF-8 JSON.

    When *inputs* is a ``DocumentAnalyzerSchema`` its tables and paragraphs
    are first passed through the newline-stripping helpers, which modify
    them in place.

    Args:
        inputs: Object exposing ``model_dump()``.
        out_path: Destination JSON file.
        ignore_line_break: Forwarded to the newline-stripping helpers.
    """
    # Imported at call time rather than at module import time.
    from yomitoku.document_analyzer import DocumentAnalyzerSchema

    if isinstance(inputs, DocumentAnalyzerSchema):
        for table in inputs.tables:
            table_to_json(table, ignore_line_break)
        for paragraph in inputs.paragraphs:
            paragraph_to_json(paragraph, ignore_line_break)

    dump_options = {
        "ensure_ascii": False,
        "indent": 4,
        "sort_keys": True,
        "separators": (",", ": "),
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(inputs.model_dump(), f, **dump_options)
@@ -0,0 +1,145 @@
1
+ import re
2
+ import cv2
3
+ import os
4
+
5
+
6
def escape_markdown_special_chars(text):
    """Prefix every Markdown-significant character in *text* with a backslash.

    The characters covered are exactly those listed in the character class
    below.
    """
    pattern = re.compile(r"([`*_{}[\]()#+.!|-])")
    return pattern.sub(lambda found: "\\" + found.group(1), text)
9
+
10
+
11
def paragraph_to_md(paragraph, ignore_line_break):
    """Render *paragraph* as a Markdown fragment.

    Args:
        paragraph: Object exposing ``contents``, ``role``, ``order`` and
            ``box``.
        ignore_line_break: When true, newlines in the text are dropped;
            otherwise they become ``<br>``.

    Returns:
        A dict with the paragraph's ``order``, its ``box`` and the rendered
        ``md`` (terminated by a newline; prefixed with ``"# "`` when ``role``
        is ``"section_headings"``).
    """
    text = escape_markdown_special_chars(paragraph.contents)
    text = text.replace("\n", "" if ignore_line_break else "<br>")

    if paragraph.role == "section_headings":
        text = "# " + text

    return {
        "order": paragraph.order,
        "box": paragraph.box,
        "md": text + "\n",
    }
27
+
28
+
29
def table_to_md(table, ignore_line_break):
    """Render *table* as a Markdown pipe table.

    Args:
        table: Object exposing ``n_row``, ``n_col``, ``cells``, ``order`` and
            ``box``; every cell exposes 1-based ``row``/``col`` plus
            ``row_span``, ``col_span`` and ``contents``.
        ignore_line_break: When true, newlines in the cell text are dropped;
            otherwise they become ``<br>``.

    Returns:
        A dict with the table's ``order``, its ``box`` and the rendered
        ``md``. A cell's text appears only at its top-left position; the rest
        of a spanned area is left empty. A ``|-|-|…`` separator line follows
        the first row.
    """
    num_rows = table.n_row
    num_cols = table.n_col
    line_break = "" if ignore_line_break else "<br>"

    grid = [["" for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in table.cells:
        # A cell whose span is not positive covers no grid position.
        if cell.row_span <= 0 or cell.col_span <= 0:
            continue

        # Cell text may be missing; treat it as an empty cell instead of
        # failing inside the regex substitution.
        contents = "" if cell.contents is None else cell.contents

        # Escape exactly once per cell. The span loop used before re-escaped
        # the text for every covered position although only the top-left
        # one was ever stored.
        contents = escape_markdown_special_chars(contents)
        contents = contents.replace("\n", line_break)

        grid[cell.row - 1][cell.col - 1] = contents

    lines = []
    for index, cells in enumerate(grid):
        lines.append("|" + "|".join(cells) + "|\n")
        if index == 0:
            # Markdown needs a delimiter row after the header row.
            lines.append("|" + "|".join(["-"] * num_cols) + "|\n")

    return {
        "order": table.order,
        "box": table.box,
        "md": "".join(lines),
    }
67
+
68
+
69
def figure_to_md(
    figures,
    img,
    out_path,
    export_figure_letter=False,
    ignore_line_break=False,
    width=200,
    figure_dir="figures",
):
    """Crop every figure out of *img*, save it as PNG and build ``<img>`` tags.

    The crops are written to ``<directory of out_path>/<figure_dir>/`` and
    named ``<out_path stem>_figure_<index>.png``.

    Args:
        figures: Iterable of objects exposing ``box`` (x1, y1, x2, y2),
            ``order`` and, when *export_figure_letter* is set, ``paragraphs``.
        img: Image indexed as ``img[y, x, channel]``.
        out_path: Path of the Markdown file being exported.
        export_figure_letter: When true, the paragraphs attached to each
            figure are rendered right after its image, sorted by their own
            ``order`` but tagged with the figure's ``order``.
        ignore_line_break: Forwarded to ``paragraph_to_md``.
        width: Pixel width placed in the ``<img>`` tag.
        figure_dir: Sub-directory (relative to the output file) for the crops.

    Returns:
        A list of dicts with ``order`` and ``md`` keys.
    """
    elements = []
    for index, figure in enumerate(figures):
        left, top, right, bottom = (int(value) for value in figure.box)
        crop = img[top:bottom, left:right, :]

        target_dir = os.path.join(os.path.dirname(out_path), figure_dir)
        os.makedirs(target_dir, exist_ok=True)

        stem = os.path.splitext(os.path.basename(out_path))[0]
        figure_name = f"{stem}_figure_{index}.png"
        cv2.imwrite(os.path.join(target_dir, figure_name), crop)

        elements.append(
            {
                "order": figure.order,
                "md": f'<img src="{figure_dir}/{figure_name}" width="{width}px"><br>',
            }
        )

        if not export_figure_letter:
            continue

        for paragraph in sorted(figure.paragraphs, key=lambda item: item.order):
            rendered = paragraph_to_md(paragraph, ignore_line_break)
            elements.append(
                {
                    "order": figure.order,
                    "md": rendered["md"],
                }
            )

    return elements
109
+
110
+
111
def export_markdown(
    inputs,
    out_path: str,
    img=None,
    ignore_line_break: bool = False,
    export_figure_letter=False,
    export_figure=True,
    figure_width=200,
    figure_dir="figures",
):
    """Export the analysed document as a Markdown file.

    Tables, paragraphs and (optionally) figures are rendered individually,
    sorted by their ``order`` and joined with newlines.

    Args:
        inputs: Object exposing ``tables``, ``paragraphs`` and ``figures``.
        out_path: Destination Markdown file, written as UTF-8.
        img: Source image the figures are cropped from.
        ignore_line_break: Forwarded to the per-element renderers.
        export_figure_letter: Forwarded to ``figure_to_md``.
        export_figure: When true, figures are cropped from *img* and linked.
        figure_width: Pixel width for the generated ``<img>`` tags.
        figure_dir: Sub-directory, next to *out_path*, for the figure files.
    """
    elements = [table_to_md(table, ignore_line_break) for table in inputs.tables]
    elements += [
        paragraph_to_md(paragraph, ignore_line_break)
        for paragraph in inputs.paragraphs
    ]

    if export_figure:
        elements += figure_to_md(
            inputs.figures,
            img,
            out_path,
            export_figure_letter,
            ignore_line_break,
            figure_width,
            figure_dir=figure_dir,
        )

    # Stable sort keeps the tables/paragraphs/figures insertion order for
    # elements sharing the same order value.
    elements.sort(key=lambda element: element["order"])
    markdown = "\n".join(element["md"] for element in elements)

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(markdown)
@@ -0,0 +1,66 @@
1
+ from typing import List
2
+
3
+ from .base import BaseSchema
4
+ from .layout_parser import Element, LayoutParser
5
+ from .table_structure_recognizer import (
6
+ TableStructureRecognizer,
7
+ TableStructureRecognizerSchema,
8
+ )
9
+
10
+
11
class LayoutAnalyzerSchema(BaseSchema):
    # Result container returned by LayoutAnalyzer.__call__.

    # Paragraph regions, taken from the layout parser result.
    paragraphs: List[Element]
    # One table-structure result per table region found by the layout parser.
    tables: List[TableStructureRecognizerSchema]
    # Figure regions, taken from the layout parser result.
    figures: List[Element]
15
+
16
+
17
class LayoutAnalyzer:
    """Run layout parsing followed by table-structure recognition on an image."""

    def __init__(self, configs=None, device="cuda", visualize=False):
        """Build the layout parser and the table-structure recognizer.

        Args:
            configs: Optional dict of per-module keyword overrides. The keys
                ``"layout_parser"`` and ``"table_structure_recognizer"`` each
                map to a dict merged over the default keyword arguments of
                the corresponding module. ``None`` or an empty dict means
                "use the defaults".
            device: Device name passed to both modules.
            visualize: Visualisation flag passed to both modules.

        Raises:
            ValueError: If *configs* is neither ``None`` nor a dict, or if it
                is a non-empty dict containing neither recognised key.
        """
        layout_parser_kwargs = {
            "device": device,
            "visualize": visualize,
        }
        table_structure_recognizer_kwargs = {
            "device": device,
            "visualize": visualize,
        }

        # The default (None) used to fall into the "must be a dict" error,
        # so the class could not be constructed without arguments.
        if configs is None:
            configs = {}

        if not isinstance(configs, dict):
            raise ValueError(
                "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
            )

        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O``. An empty dict is accepted as "no overrides".
        known_keys = ("layout_parser", "table_structure_recognizer")
        if configs and not any(key in configs for key in known_keys):
            raise ValueError("Invalid config key. Please check the config keys.")

        if "layout_parser" in configs:
            layout_parser_kwargs.update(configs["layout_parser"])

        if "table_structure_recognizer" in configs:
            table_structure_recognizer_kwargs.update(
                configs["table_structure_recognizer"]
            )

        self.layout_parser = LayoutParser(
            **layout_parser_kwargs,
        )
        self.table_structure_recognizer = TableStructureRecognizer(
            **table_structure_recognizer_kwargs,
        )

    def __call__(self, img):
        """Analyse *img* and return ``(results, vis)``.

        The layout parser runs first; the boxes of the tables it reports are
        handed to the table-structure recognizer together with the layout
        parser's visualisation output.

        Returns:
            A ``LayoutAnalyzerSchema`` and the visualisation value returned
            by the table-structure recognizer.
        """
        layout_results, vis = self.layout_parser(img)
        table_boxes = [table.box for table in layout_results.tables]
        table_results, vis = self.table_structure_recognizer(
            img, table_boxes, vis=vis
        )

        results = LayoutAnalyzerSchema(
            paragraphs=layout_results.paragraphs,
            tables=table_results,
            figures=layout_results.figures,
        )

        return results, vis
@@ -0,0 +1,189 @@
1
+ from typing import List, Union
2
+
3
+ import cv2
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from PIL import Image
7
+ from pydantic import conlist
8
+
9
+ from .base import BaseModelCatalog, BaseModule, BaseSchema
10
+ from .configs import LayoutParserRTDETRv2Config
11
+ from .models import RTDETRv2
12
+ from .postprocessor import RTDETRPostProcessor
13
+ from .utils.misc import filter_by_flag, is_contained
14
+ from .utils.visualizer import layout_visualizer
15
+
16
+
17
class Element(BaseSchema):
    # A single detected layout region.

    # Exactly four integers describing the region's box.
    box: conlist(int, min_length=4, max_length=4)
    # Detection score of the region.
    score: float
    # Role label of the region, or None when it has none.
    role: Union[str, None]
21
+
22
+
23
class LayoutParserSchema(BaseSchema):
    # Layout parsing result: detected regions grouped by category.

    paragraphs: List[Element]
    tables: List[Element]
    figures: List[Element]
27
+
28
+
29
class LayoutParserModelCatalog(BaseModelCatalog):
    """Catalog of the models selectable for layout parsing."""

    def __init__(self):
        super().__init__()
        # Registers the config class and model class under the name "rtdetrv2".
        self.register("rtdetrv2", LayoutParserRTDETRv2Config, RTDETRv2)
33
+
34
+
35
def filter_contained_rectangles_within_category(category_elements):
    """Within each category, drop rectangles that are contained in another one.

    Every unordered pair of boxes of a category is tested with
    ``is_contained`` in both directions. When both tests succeed, the box
    with the larger area survives (the first of the pair is dropped on a
    tie). When only one test succeeds, the box passed as the second argument
    of the succeeding call is dropped.

    NOTE(review): ``is_contained(a, b)`` is presumably "b lies inside a" —
    confirm against ``utils.misc``.

    Args:
        category_elements: Dict mapping a category name to a list of element
            dicts, each carrying a ``"box"`` of (x1, y1, x2, y2).

    Returns:
        The same dict object, with every list replaced by its filtered
        version.
    """

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    for category, elements in category_elements.items():
        boxes = [element["box"] for element in elements]
        keep = [True] * len(boxes)

        for i, first in enumerate(boxes):
            for j in range(i + 1, len(boxes)):
                second = boxes[j]

                first_vs_second = is_contained(first, second)
                second_vs_first = is_contained(second, first)

                if first_vs_second and second_vs_first:
                    # Mutual containment: keep the box with the larger area.
                    loser = j if area(first) > area(second) else i
                elif first_vs_second:
                    loser = j
                elif second_vs_first:
                    loser = i
                else:
                    continue

                keep[loser] = False

        category_elements[category] = filter_by_flag(elements, keep)

    return category_elements
66
+
67
+
68
def filter_contained_rectangles_across_categories(category_elements, source, target):
    """Drop *target* elements that a *source* rectangle flags as contained.

    A target element is removed when ``is_contained(src_box, tgt_box)`` is
    truthy for at least one box of the source category. Elements of the
    source category are never removed here.

    NOTE(review): ``is_contained(a, b)`` is presumably "b lies inside a" —
    confirm against ``utils.misc``.

    Args:
        category_elements: Dict mapping a category name to a list of element
            dicts, each carrying a ``"box"``.
        source: Category whose boxes act as the containers to test against.
        target: Category whose elements are filtered.

    Returns:
        The same dict object, with the *target* list replaced by its filtered
        version.
    """
    src_boxes = [element["box"] for element in category_elements[source]]
    candidates = category_elements[target]

    keep = [
        not any(is_contained(src_box, candidate["box"]) for src_box in src_boxes)
        for candidate in candidates
    ]

    category_elements[target] = filter_by_flag(candidates, keep)
    return category_elements
82
+
83
+
84
class LayoutParser(BaseModule):
    """Detect paragraphs, tables and figures in a page image."""

    model_catalog = LayoutParserModelCatalog()

    def __init__(
        self,
        model_name="rtdetrv2",
        path_cfg=None,
        device="cuda",
        visualize=False,
        from_pretrained=True,
    ):
        """Load the model and prepare pre-/post-processing.

        Args:
            model_name: Name of the model to load.
            path_cfg: Forwarded to ``load_model``.
            device: Device the model and input tensors are moved to.
            visualize: When true, ``__call__`` also returns a visualisation.
            from_pretrained: Forwarded to ``load_model``.
        """
        super().__init__()
        self.load_model(model_name, path_cfg, from_pretrained)
        self.device = device
        self.visualize = visualize

        self.model.eval()
        self.model.to(self.device)

        decoder_cfg = self._cfg.RTDETRTransformerv2
        self.postprocessor = RTDETRPostProcessor(
            num_classes=decoder_cfg.num_classes,
            num_top_queries=decoder_cfg.num_queries,
        )

        resize = T.Resize(self._cfg.data.img_size)
        self.transforms = T.Compose([resize, T.ToTensor()])

        self.thresh_score = self._cfg.thresh_score

        # Maps a class index to its category name.
        self.label_mapper = dict(enumerate(self._cfg.category))

        # Categories listed here are reported as paragraphs with a role.
        self.role = self._cfg.role

    def preprocess(self, img):
        """Convert a BGR image array into a batched tensor on ``self.device``."""
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb)
        return self.transforms(pil_image)[None].to(self.device)

    def postprocess(self, preds, image_size):
        """Turn raw model predictions into a ``LayoutParserSchema``.

        Args:
            preds: Output of the model.
            image_size: ``(height, width)`` of the original image.
        """
        height, width = image_size
        # The post-processor receives the size as (width, height).
        orig_size = torch.tensor([width, height])[None].to(self.device)
        outputs = self.postprocessor(preds, orig_size, self.thresh_score)
        filtered = self.filtering_elements(outputs[0])
        return LayoutParserSchema(**filtered)

    def filtering_elements(self, preds):
        """Group detections by category and remove contained rectangles.

        Detections whose category appears in ``self.role`` are stored under
        ``"paragraphs"`` with that category as their ``role``; all other
        detections keep their category and get ``role=None``.

        Args:
            preds: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``.

        Returns:
            Dict mapping a category name to a list of element dicts.
        """
        category_elements = {
            name: [] for name in self.label_mapper.values() if name not in self.role
        }

        detections = zip(preds["boxes"], preds["scores"], preds["labels"])
        for box, score, label in detections:
            category = self.label_mapper[label.item()]

            if category in self.role:
                role, category = category, "paragraphs"
            else:
                role = None

            category_elements[category].append(
                {
                    "box": box.astype(int).tolist(),
                    "score": float(score),
                    "role": role,
                }
            )

        category_elements = filter_contained_rectangles_within_category(
            category_elements
        )

        # Called with tables as the source and paragraphs as the target.
        return filter_contained_rectangles_across_categories(
            category_elements, "tables", "paragraphs"
        )

    def __call__(self, img):
        """Run layout parsing on *img*.

        Returns:
            ``(results, vis)`` where ``results`` is a ``LayoutParserSchema``
            and ``vis`` is the output of ``layout_visualizer`` when
            ``self.visualize`` is set, otherwise ``None``.
        """
        original_size = img.shape[:2]
        img_tensor = self.preprocess(img)

        with torch.inference_mode():
            preds = self.model(img_tensor)
            results = self.postprocess(preds, original_size)

        if not self.visualize:
            return results, None

        return results, layout_visualizer(results, img)
@@ -0,0 +1,9 @@
1
+ from .dbnet_plus import DBNet
2
+ from .parseq import PARSeq
3
+ from .rtdetr import RTDETRv2
4
+
5
+ __all__ = [
6
+ "DBNet",
7
+ "PARSeq",
8
+ "RTDETRv2",
9
+ ]