yomitoku 0.4.1__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/base.py +1 -1
- yomitoku/cli/main.py +219 -27
- yomitoku/configs/__init__.py +2 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
- yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
- yomitoku/data/functions.py +48 -23
- yomitoku/document_analyzer.py +243 -41
- yomitoku/export/__init__.py +18 -5
- yomitoku/export/export_csv.py +71 -2
- yomitoku/export/export_html.py +46 -12
- yomitoku/export/export_json.py +66 -3
- yomitoku/export/export_markdown.py +42 -6
- yomitoku/layout_analyzer.py +2 -9
- yomitoku/layout_parser.py +58 -4
- yomitoku/models/dbnet_plus.py +13 -39
- yomitoku/models/layers/activate.py +13 -0
- yomitoku/models/layers/rtdetr_backbone.py +18 -17
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
- yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
- yomitoku/models/parseq.py +15 -22
- yomitoku/ocr.py +24 -27
- yomitoku/onnx/.gitkeep +0 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
- yomitoku/postprocessor/parseq_tokenizer.py +1 -3
- yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
- yomitoku/table_structure_recognizer.py +82 -9
- yomitoku/text_detector.py +57 -7
- yomitoku/text_recognizer.py +84 -16
- yomitoku/utils/misc.py +21 -14
- yomitoku/utils/visualizer.py +15 -8
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
- yomitoku-0.7.4.dist-info/RECORD +54 -0
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
- yomitoku-0.4.1.dist-info/RECORD +0 -52
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
yomitoku/ocr.py
CHANGED
@@ -16,16 +16,37 @@ class WordPrediction(BaseSchema):
     )
     content: str
     direction: str
-    det_score: float
     rec_score: float
+    det_score: float
 
 
 class OCRSchema(BaseSchema):
     words: List[WordPrediction]
 
 
+def ocr_aggregate(det_outputs, rec_outputs):
+    words = []
+    for points, det_score, pred, rec_score, direction in zip(
+        det_outputs.points,
+        det_outputs.scores,
+        rec_outputs.contents,
+        rec_outputs.scores,
+        rec_outputs.directions,
+    ):
+        words.append(
+            {
+                "points": points,
+                "content": pred,
+                "direction": direction,
+                "det_score": det_score,
+                "rec_score": rec_score,
+            }
+        )
+    return words
+
+
 class OCR:
-    def __init__(self, configs=
+    def __init__(self, configs={}, device="cuda", visualize=False):
         text_detector_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -36,10 +57,6 @@ class OCR:
         }
 
         if isinstance(configs, dict):
-            assert (
-                "text_detector" in configs or "text_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "text_detector" in configs:
                 text_detector_kwargs.update(configs["text_detector"])
             if "text_recognizer" in configs:
@@ -52,26 +69,6 @@ class OCR:
         self.detector = TextDetector(**text_detector_kwargs)
         self.recognizer = TextRecognizer(**text_recognizer_kwargs)
 
-    def aggregate(self, det_outputs, rec_outputs):
-        words = []
-        for points, det_score, pred, rec_score, direction in zip(
-            det_outputs.points,
-            det_outputs.scores,
-            rec_outputs.contents,
-            rec_outputs.scores,
-            rec_outputs.directions,
-        ):
-            words.append(
-                {
-                    "points": points,
-                    "content": pred,
-                    "direction": direction,
-                    "det_score": det_score,
-                    "rec_score": rec_score,
-                }
-            )
-        return words
-
     def __call__(self, img):
         """_summary_
 
@@ -82,6 +79,6 @@ class OCR:
         det_outputs, vis = self.detector(img)
         rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)
 
-        outputs = {"words":
+        outputs = {"words": ocr_aggregate(det_outputs, rec_outputs)}
         results = OCRSchema(**outputs)
         return results, vis
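The aggregation step that used to be the `OCR.aggregate` method is now the module-level function `ocr_aggregate`, so detection and recognition outputs can be combined without constructing an `OCR` instance, and the old config-key assertion is gone. A minimal usage sketch, not part of the diff; the image path is a placeholder and a CPU device is used for illustration:

import cv2

from yomitoku.ocr import OCR, ocr_aggregate  # ocr_aggregate is new in 0.7.4

ocr = OCR(device="cpu", visualize=False)  # configs now defaults to {}

img = cv2.imread("sample.jpg")  # BGR image, as used throughout the library
results, _ = ocr(img)           # words are aggregated internally via ocr_aggregate

# The same aggregation can also be done from raw detector/recognizer outputs:
det_outputs, _ = ocr.detector(img)
rec_outputs, _ = ocr.recognizer(img, det_outputs.points)
words = ocr_aggregate(det_outputs, rec_outputs)
print(words[0]["content"], words[0]["rec_score"])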
yomitoku/onnx/.gitkeep
ADDED
File without changes
yomitoku/postprocessor/dbnet_postporcessor.py
CHANGED
@@ -1,13 +1,12 @@
 import cv2
+import math
 import numpy as np
 import pyclipper
 from shapely.geometry import Polygon
 
 
 class DBnetPostProcessor:
-    def __init__(
-        self, min_size, thresh, box_thresh, max_candidates, unclip_ratio
-    ):
+    def __init__(self, min_size, thresh, box_thresh, max_candidates, unclip_ratio):
         self.min_size = min_size
         self.thresh = thresh
         self.box_thresh = box_thresh
@@ -24,9 +23,7 @@ class DBnetPostProcessor:
         pred = preds["binary"][0]
         segmentation = self.binarize(pred)[0]
         height, width = image_size
-        quads, scores = self.boxes_from_bitmap(
-            pred, segmentation, width, height
-        )
+        quads, scores = self.boxes_from_bitmap(pred, segmentation, width, height)
         return quads, scores
 
     def binarize(self, pred):
@@ -65,9 +62,7 @@ class DBnetPostProcessor:
             if self.box_thresh > score:
                 continue
 
-            box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(
-                -1, 1, 2
-            )
+            box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
             box, sside = self.get_mini_boxes(box)
             if sside < self.min_size + 2:
                 continue
@@ -76,9 +71,7 @@ class DBnetPostProcessor:
             dest_width = dest_width.item()
             dest_height = dest_height.item()
 
-            box[:, 0] = np.clip(
-                np.round(box[:, 0] / width * dest_width), 0, dest_width
-            )
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
             box[:, 1] = np.clip(
                 np.round(box[:, 1] / height * dest_height), 0, dest_height
             )
@@ -88,9 +81,17 @@ class DBnetPostProcessor:
 
         return boxes, scores
 
-    def unclip(self, box, unclip_ratio=
+    def unclip(self, box, unclip_ratio=7):
+        # 小さい文字が見切れやすい、大きい文字のマージンが過度に大きくなる等の課題がある
+        # 対応として、文字の大きさに応じて、拡大パラメータを動的に変更する
+        # Note: こののルールはヒューリスティックで理論的根拠はない
         poly = Polygon(box)
-
+        width = box[:, 0].max() - box[:, 0].min()
+        height = box[:, 1].max() - box[:, 1].min()
+        box_dist = min(width, height)
+        ratio = unclip_ratio / math.sqrt(box_dist)
+
+        distance = poly.area * ratio / poly.length
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         expanded = np.array(offset.Execute(distance))
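The added Japanese comments explain the change: with a fixed unclip ratio, small characters tend to get clipped while large characters get an excessive margin, so the expansion parameter is now scaled by box size; the comments note the rule is a heuristic with no theoretical basis. Below is a standalone sketch of that rule, not the library method itself; the function name is illustrative:

import math

import numpy as np
import pyclipper
from shapely.geometry import Polygon


def unclip_sketch(box, unclip_ratio=7):
    # box: (N, 2) array of polygon points for one detected text region
    poly = Polygon(box)
    width = box[:, 0].max() - box[:, 0].min()
    height = box[:, 1].max() - box[:, 1].min()
    box_dist = min(width, height)               # rough character size of the region
    ratio = unclip_ratio / math.sqrt(box_dist)  # smaller boxes expand relatively more

    distance = poly.area * ratio / poly.length  # DBNet-style offset distance
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return np.array(offset.Execute(distance))

# A 20 px-tall line expands with ratio 7 / sqrt(20) ≈ 1.57; a 100 px-tall line with 7 / 10 = 0.7.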
yomitoku/postprocessor/parseq_tokenizer.py
CHANGED
@@ -122,7 +122,5 @@ class ParseqTokenizer(BaseTokenizer):
             eos_idx = len(ids)  # Nothing to truncate.
         # Truncate after EOS
         ids = ids[:eos_idx]
-        probs = probs[
-            : eos_idx + 1
-        ]  # but include prob. for EOS (if it exists)
+        probs = probs[: eos_idx + 1]  # but include prob. for EOS (if it exists)
         return probs, ids
yomitoku/postprocessor/rtdetr_postprocessor.py
CHANGED
@@ -1,4 +1,17 @@
-
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import torch
 import torch.nn as nn
yomitoku/table_structure_recognizer.py
CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union
 
 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist
 
+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import TableStructureRecognizerRTDETRv2Config
 from .layout_parser import filter_contained_rectangles_within_category
@@ -30,11 +35,19 @@ class TableCellSchema(BaseSchema):
     contents: Union[str, None]
 
 
+class TableLineSchema(BaseSchema):
+    box: conlist(int, min_length=4, max_length=4)
+    score: float
+
+
 class TableStructureRecognizerSchema(BaseSchema):
     box: conlist(int, min_length=4, max_length=4)
     n_row: int
     n_col: int
+    rows: List[TableLineSchema]
+    cols: List[TableLineSchema]
     cells: List[TableCellSchema]
+    spans: List[TableLineSchema]
     order: int
 
 
@@ -109,12 +122,13 @@ class TableStructureRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=
+            from_pretrained=from_pretrained,
         )
         self.device = device
         self.visualize = visualize
@@ -140,6 +154,45 @@ class TableStructureRecognizer(BaseModule):
             id: category for id, category in enumerate(self._cfg.category)
         }
 
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
+
     def preprocess(self, img, boxes):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
@@ -149,7 +202,7 @@ class TableStructureRecognizer(BaseModule):
             table_img = cv_img[y1:y2, x1:x2, :]
             th, hw = table_img.shape[:2]
             table_img = Image.fromarray(table_img)
-            img_tensor = self.transforms(table_img)[None]
+            img_tensor = self.transforms(table_img)[None]
             table_imgs.append(
                 {
                     "tensor": img_tensor,
@@ -190,7 +243,7 @@ class TableStructureRecognizer(BaseModule):
            category_elements
         )
 
-        cells,
+        cells, rows, cols, spans = self.extract_cell_elements(category_elements)
 
         table_x, table_y = data["offset"]
         table_x2 = table_x + data["size"][1]
@@ -199,8 +252,11 @@ class TableStructureRecognizer(BaseModule):
 
         table = {
             "box": table_box,
-            "n_row":
-            "n_col":
+            "n_row": len(rows),
+            "n_col": len(cols),
+            "rows": rows,
+            "cols": cols,
+            "spans": spans,
             "cells": cells,
             "order": 0,
         }
@@ -220,16 +276,33 @@ class TableStructureRecognizer(BaseModule):
         cells = extract_cells(row_boxes, col_boxes)
         cells = filter_contained_cells_within_spancell(cells, span_boxes)
 
-
+        rows = sorted(elements["row"], key=lambda x: x["box"][1])
+        cols = sorted(elements["col"], key=lambda x: x["box"][0])
+        spans = sorted(elements["span"], key=lambda x: x["box"][1])
+
+        return cells, rows, cols, spans
 
     def __call__(self, img, table_boxes, vis=None):
         img_tensors = self.preprocess(img, table_boxes)
         outputs = []
         for data in img_tensors:
-
-
+            if self.infer_onnx:
+                input = data["tensor"].numpy()
+                results = self.sess.run(None, {"input": input})
+                pred = {
+                    "pred_logits": torch.tensor(results[0]).to(self.device),
+                    "pred_boxes": torch.tensor(results[1]).to(self.device),
+                }
+
+            else:
+                with torch.inference_mode():
+                    data["tensor"] = data["tensor"].to(self.device)
+                    pred = self.model(data["tensor"])
+
             table = self.postprocess(pred, data)
-
+
+            if table.n_row > 0 and table.n_col > 0:
+                outputs.append(table)
 
         if vis is None and self.visualize:
             vis = img.copy()
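The recognizer output now exposes the detected row, column, and span separators (`TableLineSchema`) alongside the cells, derives `n_row`/`n_col` from them, drops tables with no rows or columns, and can run through ONNX Runtime via `infer_onnx=True`. A usage sketch under assumptions: the table box coordinates and image path are placeholders, and the `(outputs, vis)` return shape is inferred from the other modules in this diff:

import cv2

from yomitoku.table_structure_recognizer import TableStructureRecognizer

recognizer = TableStructureRecognizer(device="cpu", infer_onnx=False)

img = cv2.imread("page_with_table.jpg")   # BGR image
table_boxes = [[100, 200, 900, 700]]      # hypothetical (x1, y1, x2, y2) table region
tables, _ = recognizer(img, table_boxes)

for table in tables:                      # tables with n_row == 0 or n_col == 0 are filtered out
    print(table.n_row, table.n_col)       # now len(rows) / len(cols)
    print(table.rows[0].box, table.rows[0].score)
    print(len(table.cols), len(table.spans), len(table.cells))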
yomitoku/text_detector.py
CHANGED
@@ -2,6 +2,7 @@ from typing import List
 
 import numpy as np
 import torch
+import os
 from pydantic import conlist
 
 from .base import BaseModelCatalog, BaseModule, BaseSchema
@@ -14,6 +15,10 @@ from .data.functions import (
 from .models import DBNet
 from .postprocessor import DBnetPostProcessor
 from .utils.visualizer import det_visualizer
+from .constants import ROOT_DIR
+
+import onnx
+import onnxruntime
 
 
 class TextDetectorModelCatalog(BaseModelCatalog):
@@ -43,21 +48,60 @@ class TextDetector(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=
+            from_pretrained=from_pretrained,
         )
 
         self.device = device
         self.visualize = visualize
 
         self.model.eval()
-        self.model.to(self.device)
-
         self.post_processor = DBnetPostProcessor(**self._cfg.post_process)
+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+            self.model = None
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size", 2: "height", 3: "width"},
+            "output": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+        dummy_input = torch.randn(1, 3, 256, 256, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes=dynamic_axes,
+        )
 
     def preprocess(self, img):
         img = img.copy()
@@ -81,9 +125,15 @@ class TextDetector(BaseModule):
 
         ori_h, ori_w = img.shape[:2]
         tensor = self.preprocess(img)
-
-
-
+
+        if self.infer_onnx:
+            input = tensor.numpy()
+            results = self.sess.run(["output"], {"input": input})
+            preds = {"binary": torch.tensor(results[0])}
+        else:
+            with torch.inference_mode():
+                tensor = tensor.to(self.device)
+                preds = self.model(tensor)
 
         quads, scores = self.postprocess(preds, (ori_h, ori_w))
         outputs = {"points": quads, "scores": scores}
@@ -93,9 +143,9 @@ class TextDetector(BaseModule):
         vis = None
         if self.visualize:
             vis = det_visualizer(
-                preds,
                 img,
                 quads,
+                preds=preds,
                 vis_heatmap=self._cfg.visualize.heatmap,
                 line_color=tuple(self._cfg.visualize.color[::-1]),
             )
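With `infer_onnx=True`, the detector exports the loaded PyTorch weights to `yomitoku/onnx/<repo name>.onnx` on first use (the new `onnx/.gitkeep` reserves that directory in the wheel) and then serves them through an onnxruntime session, preferring `CUDAExecutionProvider` when CUDA is available; the heatmap is also now passed to `det_visualizer` by keyword. A usage sketch, assuming a local image file:

import cv2

from yomitoku.text_detector import TextDetector

detector = TextDetector(device="cpu", visualize=True, infer_onnx=True)

img = cv2.imread("document.jpg")  # BGR image
results, vis = detector(img)      # first call converts and caches the ONNX model

print(len(results.points), len(results.scores))
cv2.imwrite("det_vis.jpg", vis)   # detection visualization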
yomitoku/text_recognizer.py
CHANGED
@@ -2,21 +2,28 @@ from typing import List
 
 import numpy as np
 import torch
+import os
+import unicodedata
 from pydantic import conlist
 
 from .base import BaseModelCatalog, BaseModule, BaseSchema
-from .configs import TextRecognizerPARSeqConfig
+from .configs import TextRecognizerPARSeqConfig, TextRecognizerPARSeqSmallConfig
 from .data.dataset import ParseqDataset
 from .models import PARSeq
 from .postprocessor import ParseqTokenizer as Tokenizer
 from .utils.misc import load_charset
 from .utils.visualizer import rec_visualizer
 
+from .constants import ROOT_DIR
+import onnx
+import onnxruntime
+
 
 class TextRecognizerModelCatalog(BaseModelCatalog):
     def __init__(self):
         super().__init__()
         self.register("parseq", TextRecognizerPARSeqConfig, PARSeq)
+        self.register("parseq-small", TextRecognizerPARSeqSmallConfig, PARSeq)
 
 
 class TextRecognizerSchema(BaseSchema):
@@ -42,36 +49,91 @@ class TextRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=
+            from_pretrained=from_pretrained,
         )
         self.charset = load_charset(self._cfg.charset)
         self.tokenizer = Tokenizer(self.charset)
 
         self.device = device
 
+        self.model.tokenizer = self.tokenizer
         self.model.eval()
-        self.model.to(self.device)
 
         self.visualize = visualize
 
+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
     def preprocess(self, img, polygons):
         dataset = ParseqDataset(self._cfg, img, polygons)
-        dataloader =
-            dataset,
-            batch_size=self._cfg.data.batch_size,
-            shuffle=False,
-            num_workers=self._cfg.data.num_workers,
-        )
+        dataloader = self._make_mini_batch(dataset)
 
         return dataloader
 
+    def _make_mini_batch(self, dataset):
+        mini_batches = []
+        mini_batch = []
+        for data in dataset:
+            data = torch.unsqueeze(data, 0)
+            mini_batch.append(data)
+
+            if len(mini_batch) == self._cfg.data.batch_size:
+                mini_batches.append(torch.cat(mini_batch, 0))
+                mini_batch = []
+        else:
+            if len(mini_batch) > 0:
+                mini_batches.append(torch.cat(mini_batch, 0))
+
+        return mini_batches
+
+    def convert_onnx(self, path_onnx):
+        img_size = self._cfg.data.img_size
+        input = torch.randn(1, 3, *img_size, requires_grad=True)
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        self.model.export_onnx = True
+        torch.onnx.export(
+            self.model,
+            input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            do_constant_folding=True,
+            dynamic_axes=dynamic_axes,
+        )
+
     def postprocess(self, p, points):
         pred, score = self.tokenizer.decode(p)
+        pred = [unicodedata.normalize("NFKC", x) for x in pred]
+
         directions = []
         for point in points:
             point = np.array(point)
@@ -98,13 +160,19 @@ class TextRecognizer(BaseModule):
         scores = []
         directions = []
         for data in dataloader:
-
-
-
-
-
-
-
+            if self.infer_onnx:
+                input = data.numpy()
+                results = self.sess.run(["output"], {"input": input})
+                p = torch.tensor(results[0])
+            else:
+                with torch.inference_mode():
+                    data = data.to(self.device)
+                    p = self.model(data).softmax(-1)
+
+            pred, score, direction = self.postprocess(p, points)
+            preds.extend(pred)
+            scores.extend(score)
+            directions.extend(direction)
 
         outputs = {
             "contents": preds,
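Besides the ONNX path, the recognizer registers a new `parseq-small` model, NFKC-normalizes every decoded string, and replaces the `DataLoader` with a plain `_make_mini_batch` loop over the dataset. A usage sketch under assumptions: the `model_name` keyword and the `(outputs, vis)` return shape follow the constructor and call pattern shown in this diff, and the image path is a placeholder:

import cv2

from yomitoku.text_detector import TextDetector
from yomitoku.text_recognizer import TextRecognizer

detector = TextDetector(device="cpu")
recognizer = TextRecognizer(model_name="parseq-small", device="cpu")  # newly registered config

img = cv2.imread("document.jpg")
det_outputs, _ = detector(img)
rec_outputs, _ = recognizer(img, det_outputs.points)

# contents are NFKC-normalized (e.g. full-width alphanumerics become half-width)
for content, score, direction in zip(
    rec_outputs.contents, rec_outputs.scores, rec_outputs.directions
):
    print(direction, f"{score:.2f}", content)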
yomitoku/utils/misc.py
CHANGED
@@ -1,5 +1,5 @@
 def load_charset(charset_path):
-    with open(charset_path, "r") as f:
+    with open(charset_path, "r", encoding="utf-8") as f:
         charset = f.read()
     return charset
 
@@ -9,6 +9,24 @@ def filter_by_flag(elements, flags):
     return [element for element, flag in zip(elements, flags) if flag]
 
 
+def calc_overlap_ratio(rect_a, rect_b):
+    intersection = calc_intersection(rect_a, rect_b)
+    if intersection is None:
+        return 0, None
+
+    ix1, iy1, ix2, iy2 = intersection
+
+    overlap_width = ix2 - ix1
+    overlap_height = iy2 - iy1
+    bx1, by1, bx2, by2 = rect_b
+
+    b_area = (bx2 - bx1) * (by2 - by1)
+    overlap_area = overlap_width * overlap_height
+
+    overlap_ratio = overlap_area / b_area
+    return overlap_ratio, intersection
+
+
 def is_contained(rect_a, rect_b, threshold=0.8):
     """二つの矩形A, Bが与えられたとき、矩形Bが矩形Aに含まれるかどうかを判定する。
     ずれを許容するため、重複率求め、thresholdを超える場合にTrueを返す。
@@ -23,20 +41,9 @@ def is_contained(rect_a, rect_b, threshold=0.8):
         bool: 矩形Bが矩形Aに含まれる場合True
     """
 
-
-    if intersection is None:
-        return False
-
-    ix1, iy1, ix2, iy2 = intersection
-
-    overlap_width = ix2 - ix1
-    overlap_height = iy2 - iy1
-    bx1, by1, bx2, by2 = rect_b
-
-    b_area = (bx2 - bx1) * (by2 - by1)
-    overlap_area = overlap_width * overlap_height
+    overlap_ratio, _ = calc_overlap_ratio(rect_a, rect_b)
 
-    if
+    if overlap_ratio > threshold:
         return True
 
     return False