yomitoku 0.4.0.post1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/__init__.py +20 -0
- yomitoku/base.py +136 -0
- yomitoku/cli/__init__.py +0 -0
- yomitoku/cli/main.py +230 -0
- yomitoku/configs/__init__.py +13 -0
- yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
- yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
- yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
- yomitoku/constants.py +32 -0
- yomitoku/data/__init__.py +3 -0
- yomitoku/data/dataset.py +40 -0
- yomitoku/data/functions.py +279 -0
- yomitoku/document_analyzer.py +315 -0
- yomitoku/export/__init__.py +6 -0
- yomitoku/export/export_csv.py +71 -0
- yomitoku/export/export_html.py +188 -0
- yomitoku/export/export_json.py +34 -0
- yomitoku/export/export_markdown.py +145 -0
- yomitoku/layout_analyzer.py +66 -0
- yomitoku/layout_parser.py +189 -0
- yomitoku/models/__init__.py +9 -0
- yomitoku/models/dbnet_plus.py +272 -0
- yomitoku/models/layers/__init__.py +0 -0
- yomitoku/models/layers/activate.py +38 -0
- yomitoku/models/layers/dbnet_feature_attention.py +160 -0
- yomitoku/models/layers/parseq_transformer.py +218 -0
- yomitoku/models/layers/rtdetr_backbone.py +333 -0
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
- yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
- yomitoku/models/parseq.py +243 -0
- yomitoku/models/rtdetr.py +22 -0
- yomitoku/ocr.py +87 -0
- yomitoku/postprocessor/__init__.py +9 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
- yomitoku/postprocessor/parseq_tokenizer.py +128 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
- yomitoku/reading_order.py +214 -0
- yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
- yomitoku/resource/charset.txt +1 -0
- yomitoku/table_structure_recognizer.py +244 -0
- yomitoku/text_detector.py +103 -0
- yomitoku/text_recognizer.py +128 -0
- yomitoku/utils/__init__.py +0 -0
- yomitoku/utils/graph.py +20 -0
- yomitoku/utils/logger.py +15 -0
- yomitoku/utils/misc.py +102 -0
- yomitoku/utils/visualizer.py +179 -0
- yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
- yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
- yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
- yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,243 @@
|
|
1
|
+
# Scene Text Recognition Model Hub
|
2
|
+
# Copyright 2022 Darwin Bautista
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from functools import partial
|
17
|
+
from typing import Optional, Sequence
|
18
|
+
|
19
|
+
import torch
|
20
|
+
import torch.nn as nn
|
21
|
+
from huggingface_hub import PyTorchModelHubMixin
|
22
|
+
from timm.models.helpers import named_apply
|
23
|
+
from torch import Tensor
|
24
|
+
|
25
|
+
from ..postprocessor import ParseqTokenizer as Tokenizer
|
26
|
+
from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding
|
27
|
+
|
28
|
+
|
29
|
+
def init_weights(
    module: nn.Module, name: str = "", exclude: Sequence[str] = ()
):
    """Initialize the weights using the typical initialization schemes used in SOTA models.

    Designed to be applied recursively with timm's ``named_apply`` (as
    ``PARSeq.__init__`` does), which passes each sub-module and its
    qualified ``name``.

    Args:
        module: module whose parameters are (re)initialized in place.
        name: qualified name of ``module``; only used for the ``exclude`` check.
        exclude: name prefixes to skip entirely (e.g. sub-modules that have
            their own initialization).

    Scheme:
        * ``nn.Linear``: truncated normal weights (std 0.02), zero bias.
        * ``nn.Embedding``: truncated normal (std 0.02), padding row zeroed.
        * ``nn.Conv2d``: Kaiming normal weights (fan_out, ReLU), zero bias.
        * ``nn.LayerNorm`` / ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight 1
          and bias 0, for whichever affine parameters exist.
    """
    if any(map(name.startswith, exclude)):
        return
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(
            module.weight, mode="fan_out", nonlinearity="relu"
        )
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers built with elementwise_affine=False / affine=False
        # (or bias=False) have ``None`` parameters; initializing them
        # unconditionally would raise.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
52
|
+
|
53
|
+
|
54
|
+
class PARSeq(nn.Module, PyTorchModelHubMixin):
    """PARSeq text recognizer: an image encoder plus a query-based decoder.

    The decoder is driven by learned position queries (``self.pos_queries``).
    Decoding is either autoregressive (``cfg.decode_ar``) or a single pass
    from <bos>. Either way it is optionally followed by
    ``cfg.refine_iters`` rounds of iterative refinement. Hugging Face Hub
    helpers come from ``huggingface_hub.PyTorchModelHubMixin``.

    Config attributes read here: ``max_label_length``, ``decode_ar``,
    ``refine_iters``, ``num_tokens``, ``data.img_size``, ``encoder`` (unpacked
    as keyword arguments), and ``decoder`` (with ``embed_dim``).
    """

    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        self.cfg = cfg
        self.max_label_length = self.cfg.max_label_length
        self.decode_ar = self.cfg.decode_ar
        self.refine_iters = self.cfg.refine_iters
        embed_dim = self.cfg.decoder.embed_dim

        self.encoder = Encoder(
            self.cfg.data.img_size,
            **self.cfg.encoder,
        )

        self.decoder = Decoder(
            norm=nn.LayerNorm(self.cfg.decoder.embed_dim),
            cfg=self.cfg.decoder,
        )

        # We don't predict <bos> nor <pad>
        self.head = nn.Linear(embed_dim, self.cfg.num_tokens - 2)
        self.text_embed = TokenEmbedding(self.cfg.num_tokens, embed_dim)

        # +1 for <eos>
        # torch.Tensor(...) allocates uninitialized memory; it is filled by
        # the trunc_normal_ call at the end of __init__.
        self.pos_queries = nn.Parameter(
            torch.Tensor(1, self.max_label_length + 1, embed_dim)
        )
        # NOTE(review): nn.Dropout() uses PyTorch's default p=0.5 (no-op in
        # eval mode); confirm this is the intended training rate.
        self.dropout = nn.Dropout()
        # Encoder has its own init.
        named_apply(partial(init_weights, exclude=["encoder"]), self)
        nn.init.trunc_normal_(self.pos_queries, std=0.02)

    @property
    def _device(self) -> torch.device:
        """Device of the output head's parameters (used for new tensors)."""
        return next(self.head.parameters(recurse=False)).device

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay.

        Includes the token embedding, the position queries, and the names
        reported by ``self.encoder.no_weight_decay()`` (prefixed with
        ``"encoder."``).
        """
        param_names = {"text_embed.embedding.weight", "pos_queries"}
        enc_param_names = {
            "encoder." + n for n in self.encoder.no_weight_decay()
        }
        return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
        """Run the image encoder; the result is the ``memory`` for decode()."""
        return self.encoder(img)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_padding_mask: Optional[Tensor] = None,
        tgt_query: Optional[Tensor] = None,
        tgt_query_mask: Optional[Tensor] = None,
    ):
        """Run the decoder over a batch of token-id prefixes.

        Args:
            tgt: token ids of shape (N, L). Column 0 is treated as <bos> and is
                embedded without a position query.
            memory: encoder output from :meth:`encode`.
            tgt_mask: passed through to ``self.decoder`` (presumably the
                attention mask of the token/context stream — TODO confirm).
            tgt_padding_mask: passed through to ``self.decoder`` (presumably a
                key padding mask for the context tokens).
            tgt_query: position queries; defaults to the first L rows of
                ``self.pos_queries`` broadcast over the batch.
            tgt_query_mask: passed through to ``self.decoder`` (presumably the
                attention mask of the query stream).

        Returns:
            The output of ``self.decoder``. It is fed to ``self.head`` by
            forward(), so presumably one embedding per query.
        """
        N, L = tgt.shape
        # <bos> stands for the null context. We only supply position information for characters after <bos>.
        null_ctx = self.text_embed(tgt[:, :1])
        tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
        tgt_query = self.dropout(tgt_query)
        return self.decoder(
            tgt_query,
            tgt_emb,
            memory,
            tgt_query_mask,
            tgt_mask,
            tgt_padding_mask,
        )

    def forward(
        self,
        tokenizer: Tokenizer,
        images: Tensor,
        max_length: Optional[int] = None,
    ) -> Tensor:
        """Predict character logits for a batch of images.

        Args:
            tokenizer: supplies ``pad_id``, ``bos_id`` and ``eos_id``.
            images: image batch; ``images.shape[0]`` is the batch size.
            max_length: maximum number of characters to decode, clamped to
                ``self.max_label_length``. When None, ``max_label_length`` is
                used and autoregressive decoding may stop early once every
                sequence contains <eos>.

        Returns:
            Logits from ``self.head`` (last dimension ``cfg.num_tokens - 2``),
            with up to ``max_length + 1`` positions (the +1 is for <eos>).
        """
        # Early stopping is only allowed when no explicit max_length is given.
        testing = max_length is None
        max_length = (
            self.max_label_length
            if max_length is None
            else min(max_length, self.max_label_length)
        )
        bs = images.shape[0]
        # +1 for <eos> at end of sequence.
        num_steps = max_length + 1
        memory = self.encode(images)

        # Query positions up to `num_steps`
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)

        # Special case for the forward permutation. Faster than using `generate_attn_masks()`
        # True above the diagonal = future positions are masked.
        # NOTE: tgt_mask and query_mask are the SAME tensor object.
        tgt_mask = query_mask = torch.triu(
            torch.ones(
                (num_steps, num_steps), dtype=torch.bool, device=self._device
            ),
            1,
        )

        if self.decode_ar:
            # Context starts as <bos> followed by <pad>; filled in greedily.
            tgt_in = torch.full(
                (bs, num_steps),
                tokenizer.pad_id,
                dtype=torch.long,
                device=self._device,
            )
            tgt_in[:, 0] = tokenizer.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1  # next token index
                # Efficient decoding:
                # Input the context up to the ith token. We use only one query (at position = i) at a time.
                # This works because of the lookahead masking effect of the canonical (forward) AR context.
                # Past tokens have no access to future tokens, hence are fixed once computed.
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                )
                # the next token probability is in the output's ith token position
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # greedy decode. add the next token index to the target input
                    # (squeeze() drops every size-1 dim; for bs == 1 this gives
                    # a scalar that broadcasts into tgt_in[:, j]).
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if (
                        testing
                        and (tgt_in == tokenizer.eos_id).any(dim=-1).all()
                    ):
                        break

            logits = torch.cat(logits, dim=1)
        else:
            # No prior context, so input is just <bos>. We query all positions.
            tgt_in = torch.full(
                (bs, 1),
                tokenizer.bos_id,
                dtype=torch.long,
                device=self._device,
            )
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # For iterative refinement, we always use a 'cloze' mask.
            # We can derive it from the AR forward mask by unmasking the token context to the right.
            # NOTE(review): this in-place write also modifies tgt_mask, which
            # aliases query_mask (see above) and is passed to decode() below;
            # confirm this is intended.
            query_mask[
                torch.triu(
                    torch.ones(
                        num_steps,
                        num_steps,
                        dtype=torch.bool,
                        device=self._device,
                    ),
                    2,
                )
            ] = 0
            bos = torch.full(
                (bs, 1),
                tokenizer.bos_id,
                dtype=torch.long,
                device=self._device,
            )
            for i in range(self.refine_iters):
                # Prior context is the previous output.
                # NOTE(review): if AR decoding stopped early, tgt_in is shorter
                # than num_steps while tgt_mask stays (num_steps, num_steps);
                # presumably safe only if the decoder does not apply tgt_mask
                # to the context here — confirm against Decoder.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                # Mask tokens beyond the first EOS token.
                tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(
                    -1
                ) > 0
                tgt_out = self.decode(
                    tgt_in,
                    memory,
                    tgt_mask,
                    tgt_padding_mask,
                    pos_queries,
                    query_mask[:, : tgt_in.shape[1]],
                )
                logits = self.head(tgt_out)

        return logits
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
from huggingface_hub import PyTorchModelHubMixin
|
3
|
+
|
4
|
+
from .layers.rtdetr_backbone import PResNet
|
5
|
+
from .layers.rtdetr_hybrid_encoder import HybridEncoder
|
6
|
+
from .layers.rtdetrv2_decoder import RTDETRTransformerv2
|
7
|
+
|
8
|
+
|
9
|
+
class RTDETRv2(nn.Module, PyTorchModelHubMixin):
    """RT-DETRv2 detector assembled from three configurable stages.

    Pipeline: ``PResNet`` backbone -> ``HybridEncoder`` ->
    ``RTDETRTransformerv2`` decoder. Each stage is constructed from the
    keyword arguments stored under the attribute of ``cfg`` with the same
    name as the stage class.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Sub-modules are registered in pipeline order.
        self.backbone = PResNet(**cfg.PResNet)
        self.encoder = HybridEncoder(**cfg.HybridEncoder)
        self.decoder = RTDETRTransformerv2(**cfg.RTDETRTransformerv2)

    def forward(self, x, targets=None):
        """Run backbone, encoder and decoder; ``targets`` goes to the decoder only."""
        backbone_features = self.backbone(x)
        encoded_features = self.encoder(backbone_features)
        return self.decoder(encoded_features, targets)
|
yomitoku/ocr.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
from pydantic import conlist
|
4
|
+
|
5
|
+
from yomitoku.text_detector import TextDetector
|
6
|
+
from yomitoku.text_recognizer import TextRecognizer
|
7
|
+
|
8
|
+
from .base import BaseSchema
|
9
|
+
|
10
|
+
|
11
|
+
class WordPrediction(BaseSchema):
    """One detected and recognized word, as assembled by ``OCR.aggregate``.

    Field constraints are enforced with pydantic ``conlist``.
    """

    # Exactly four points, each a pair of ints (the detector's quadrilateral).
    points: conlist(
        conlist(int, max_length=2, min_length=2),
        max_length=4,
        min_length=4,
    )
    # Recognized text.
    content: str
    # Text direction as reported by the recognizer.
    direction: str
    # Detector score for this box.
    det_score: float
    # Recognizer score for this text.
    rec_score: float
|
21
|
+
|
22
|
+
|
23
|
+
class OCRSchema(BaseSchema):
    """Full OCR result for one image: the list of recognized words."""

    # Order follows the detector's output order (see OCR.aggregate).
    words: List[WordPrediction]
|
25
|
+
|
26
|
+
|
27
|
+
class OCR:
    """Text detection followed by text recognition on a single image.

    Args:
        configs: optional dict of per-module overrides with keys
            ``"text_detector"`` and/or ``"text_recognizer"``. Each value is
            merged into that module's constructor keyword arguments. ``None``
            (the default) or ``{}`` means "no overrides".
        device: passed to both modules as ``device``.
        visualize: passed to both modules as ``visualize``.

    Raises:
        ValueError: if ``configs`` is not a dict, or is a non-empty dict that
            contains neither supported key.
    """

    def __init__(self, configs=None, device="cuda", visualize=False):
        text_detector_kwargs, text_recognizer_kwargs = (
            self._build_module_kwargs(configs, device, visualize)
        )
        self.detector = TextDetector(**text_detector_kwargs)
        self.recognizer = TextRecognizer(**text_recognizer_kwargs)

    @staticmethod
    def _build_module_kwargs(configs, device, visualize):
        """Return ``(detector_kwargs, recognizer_kwargs)`` with overrides merged."""
        # The default ``None`` means "no overrides". Previously it fell into
        # the non-dict branch and raised.
        if configs is None:
            configs = {}
        if not isinstance(configs, dict):
            raise ValueError(
                "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
            )
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; an empty dict is accepted as "no overrides".
        if configs and not (
            "text_detector" in configs or "text_recognizer" in configs
        ):
            raise ValueError(
                "Invalid config key. Please check the config keys."
            )

        # Separate dicts so overrides for one module never leak into the other.
        text_detector_kwargs = {
            "device": device,
            "visualize": visualize,
        }
        text_recognizer_kwargs = {
            "device": device,
            "visualize": visualize,
        }
        if "text_detector" in configs:
            text_detector_kwargs.update(configs["text_detector"])
        if "text_recognizer" in configs:
            text_recognizer_kwargs.update(configs["text_recognizer"])
        return text_detector_kwargs, text_recognizer_kwargs

    def aggregate(self, det_outputs, rec_outputs):
        """Combine detector and recognizer outputs into one dict per word.

        Reads ``det_outputs.points`` / ``.scores`` and
        ``rec_outputs.contents`` / ``.scores`` / ``.directions``. These are
        combined positionally with ``zip``, so the result is as long as the
        shortest of them.
        """
        return [
            {
                "points": points,
                "content": content,
                "direction": direction,
                "det_score": det_score,
                "rec_score": rec_score,
            }
            for points, det_score, content, rec_score, direction in zip(
                det_outputs.points,
                det_outputs.scores,
                rec_outputs.contents,
                rec_outputs.scores,
                rec_outputs.directions,
            )
        ]

    def __call__(self, img):
        """Run detection and recognition on one image.

        Args:
            img (np.ndarray): cv2 image (BGR).

        Returns:
            ``(results, vis)``. ``results`` is an :class:`OCRSchema`; ``vis`` is
            the visualization returned by the recognizer, after it was given
            the detector's visualization.
        """
        det_outputs, vis = self.detector(img)
        rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)

        outputs = {"words": self.aggregate(det_outputs, rec_outputs)}
        results = OCRSchema(**outputs)
        return results, vis
|
@@ -0,0 +1,137 @@
|
|
1
|
+
import cv2
|
2
|
+
import numpy as np
|
3
|
+
import pyclipper
|
4
|
+
from shapely.geometry import Polygon
|
5
|
+
|
6
|
+
|
7
|
+
class DBnetPostProcessor:
    """Turn a DBNet text-probability map into scored quadrilateral boxes.

    Args:
        min_size: a candidate is dropped when the shorter side of its
            minimum-area rectangle (in probability-map pixels) is below this.
            After unclipping, the limit is ``min_size + 2``.
        thresh: probability threshold used by :meth:`binarize`.
        box_thresh: minimum mean probability inside a contour to keep it.
        max_candidates: at most this many contours are examined.
        unclip_ratio: expansion ratio passed to :meth:`unclip`.
    """

    def __init__(
        self, min_size, thresh, box_thresh, max_candidates, unclip_ratio
    ):
        self.min_size = min_size
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio

    def __call__(self, preds, image_size):
        """Extract boxes from the first map of the batch.

        Args:
            preds: mapping whose ``"binary"`` entry is the text-region
                probability map. Only ``preds["binary"][0]`` is used.
                ``boxes_from_bitmap`` asserts that ``preds["binary"][0][0]`` is
                2-D, so this is presumably (N, 1, H, W) — TODO confirm.
            image_size: ``(height, width)`` of the destination image; boxes
                are rescaled from map coordinates to it.

        Returns:
            ``(quads, scores)``: lists of 4-point integer boxes
            ``[[x, y], ...]`` and their mean probabilities.
        """
        pred = preds["binary"][0]
        segmentation = self.binarize(pred)[0]
        height, width = image_size
        quads, scores = self.boxes_from_bitmap(
            pred, segmentation, width, height
        )
        return quads, scores

    def binarize(self, pred):
        """Return the boolean mask ``pred > self.thresh``."""
        return pred > self.thresh

    @staticmethod
    def _as_python_scalar(value):
        """Unwrap 0-d tensors / numpy scalars via ``.item()``; pass others through."""
        return value.item() if hasattr(value, "item") else value

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """Find contours in the binary map and convert them to scored boxes.

        Args:
            pred: probability map; ``pred.cpu().detach().numpy()[0]`` is used.
            _bitmap: single map with shape (H, W), whose values are
                binarized as {0, 1}.
            dest_width, dest_height: destination size. These may be ints,
                numpy scalars or 0-d tensors.

        Returns:
            ``(boxes, scores)``: lists of ``[[x, y]] * 4`` int boxes in
            destination coordinates and their mean probabilities.
        """
        assert len(_bitmap.shape) == 2
        bitmap = _bitmap.cpu().numpy()

        pred = pred.cpu().detach().numpy()[0]
        height, width = bitmap.shape
        # Convert once, up front. Each dimension is handled on its own.
        dest_width = self._as_python_scalar(dest_width)
        dest_height = self._as_python_scalar(dest_height)

        contours, _ = cv2.findContours(
            (bitmap * 255).astype(np.uint8),
            cv2.RETR_LIST,
            cv2.CHAIN_APPROX_SIMPLE,
        )

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index].squeeze(1)
            points, sside = self.get_mini_boxes(contour)

            if sside < self.min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, contour)

            if self.box_thresh > score:
                continue

            expanded = self.unclip(points, unclip_ratio=self.unclip_ratio)
            # pyclipper may return no polygon for degenerate input, and an
            # empty array would make cv2.minAreaRect fail.
            if expanded.size == 0:
                continue
            box, sside = self.get_mini_boxes(expanded.reshape(-1, 1, 2))
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width
            )
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )

            # int32, not int16: int16 silently wraps coordinates above 32767.
            boxes.append(box.astype(np.int32).tolist())
            scores.append(score)

        return boxes, scores

    def unclip(self, box, unclip_ratio=1.5):
        """Expand polygon ``box`` outward.

        The offset distance is ``area * unclip_ratio / perimeter``, computed
        with shapely. The offset itself is done by pyclipper with round
        joins.

        Returns:
            ``np.array`` of the polygon(s) produced by pyclipper. It may be
            empty.
        """
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        """Return the minimum-area rectangle's corners and its shorter side.

        Corners are sorted by x. The left pair and the right pair are then
        each ordered by y. With image coordinates (y down), the returned
        order is top-left, top-right, bottom-right, bottom-left.
        """
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1],
            points[index_2],
            points[index_3],
            points[index_4],
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """Mean of ``bitmap`` inside polygon ``_box``.

        Only the polygon's bounding window is rasterized, which keeps this
        cheap.
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
|
@@ -0,0 +1,128 @@
|
|
1
|
+
# Scene Text Recognition Model Hub
|
2
|
+
# Copyright 2022 Darwin Bautista
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from abc import ABC, abstractmethod
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from torch import Tensor
|
21
|
+
from torch.nn.utils.rnn import pad_sequence
|
22
|
+
|
23
|
+
|
24
|
+
class BaseTokenizer(ABC):
    """Character-level tokenizer base class.

    The vocabulary is ``specials_first + tuple(charset) + specials_last``;
    a token's id is its index in that tuple.

    Args:
        charset: vocabulary characters (each character becomes one token).
        specials_first: special tokens placed before the charset.
        specials_last: special tokens placed after the charset.
    """

    def __init__(
        self,
        charset: str,
        specials_first: tuple = (),
        specials_last: tuple = (),
    ) -> None:
        self._itos = specials_first + tuple(charset) + specials_last
        # If a token appears more than once, its last index wins here.
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> list[int]:
        """Map each character of ``tokens`` to its id (KeyError if unknown)."""
        return [self._stoi[s] for s in tokens]

    def _ids2tok(
        self, token_ids: list[int], join: bool = True
    ) -> "str | list[str]":
        """Map ids to tokens; concatenate them into a string unless ``join`` is False."""
        tokens = [self._itos[i] for i in token_ids]
        return "".join(tokens) if join else tokens

    @abstractmethod
    def encode(
        self, labels: list[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(
        self, token_dists: Tensor, raw: bool = False
    ) -> tuple[list, list[float]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            ``(labels, probs)``. ``labels`` is a list of strings (or, when
            ``raw`` is true, a list of token lists). ``probs`` holds, for each
            sequence, the product of the selected per-token probabilities as
            a Python float. When not raw, only the probabilities kept by
            ``_filter`` are multiplied.
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            # detach() so decoding also works on tensors that track gradients
            # (Tensor.numpy() refuses those).
            probs = probs.detach().cpu().numpy()
            probs = float(probs.prod())
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs
|
89
|
+
|
90
|
+
|
91
|
+
class ParseqTokenizer(BaseTokenizer):
    """Tokenizer used by PARSeq.

    Vocabulary layout: ``[E]`` first (id 0), then the charset, then ``[B]``
    and ``[P]``.
    """

    BOS = "[B]"
    EOS = "[E]"
    PAD = "[P]"

    def __init__(self, charset: str) -> None:
        super().__init__(charset, (self.EOS,), (self.BOS, self.PAD))
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(
        self, labels: list[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode labels as ``[B] chars... [E]`` and right-pad with ``[P]`` to (N, L)."""
        sequences = []
        for label in labels:
            id_seq = [self.bos_id, *self._tok2ids(label), self.eos_id]
            sequences.append(
                torch.as_tensor(id_seq, dtype=torch.long, device=device)
            )
        return pad_sequence(
            sequences, batch_first=True, padding_value=self.pad_id
        )

    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Cut ids at the first ``[E]``, keeping the ``[E]`` probability if present."""
        id_list = ids.tolist()
        cut = (
            id_list.index(self.eos_id)
            if self.eos_id in id_list
            else len(id_list)
        )
        return probs[: cut + 1], id_list[:cut]
|