yomitoku 0.5.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- yomitoku/cli/main.py +47 -1
- yomitoku/configs/__init__.py +2 -0
- yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
- yomitoku/document_analyzer.py +229 -26
- yomitoku/export/export_csv.py +39 -2
- yomitoku/export/export_html.py +2 -1
- yomitoku/export/export_json.py +40 -2
- yomitoku/export/export_markdown.py +2 -1
- yomitoku/layout_analyzer.py +1 -5
- yomitoku/layout_parser.py +58 -4
- yomitoku/models/layers/rtdetr_backbone.py +5 -15
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +6 -18
- yomitoku/models/layers/rtdetrv2_decoder.py +17 -42
- yomitoku/models/parseq.py +9 -9
- yomitoku/ocr.py +24 -27
- yomitoku/onnx/.gitkeep +0 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +4 -13
- yomitoku/table_structure_recognizer.py +79 -9
- yomitoku/text_detector.py +57 -7
- yomitoku/text_recognizer.py +80 -16
- yomitoku/utils/misc.py +20 -13
- yomitoku/utils/visualizer.py +5 -5
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/METADATA +21 -9
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/RECORD +26 -24
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/WHEEL +1 -1
- {yomitoku-0.5.3.dist-info → yomitoku-0.7.0.dist-info}/entry_points.txt +0 -0
yomitoku/layout_parser.py
CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union

 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist

+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import LayoutParserRTDETRv2Config
 from .models import RTDETRv2
@@ -91,6 +96,7 @@ class LayoutParser(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(model_name, path_cfg, from_pretrained)
@@ -98,7 +104,6 @@ class LayoutParser(BaseModule):
         self.visualize = visualize

         self.model.eval()
-        self.model.to(self.device)

         self.postprocessor = RTDETRPostProcessor(
             num_classes=self._cfg.RTDETRTransformerv2.num_classes,
@@ -119,11 +124,49 @@ class LayoutParser(BaseModule):
         }

         self.role = self._cfg.role
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )

     def preprocess(self, img):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img = Image.fromarray(cv_img)
-        img_tensor = self.transforms(img)[None]
+        img_tensor = self.transforms(img)[None]
         return img_tensor

     def postprocess(self, preds, image_size):
@@ -175,8 +218,19 @@ class LayoutParser(BaseModule):
         ori_h, ori_w = img.shape[:2]
         img_tensor = self.preprocess(img)

-
-
+        if self.infer_onnx:
+            input = img_tensor.numpy()
+            results = self.sess.run(None, {"input": input})
+            preds = {
+                "pred_logits": torch.tensor(results[0]).to(self.device),
+                "pred_boxes": torch.tensor(results[1]).to(self.device),
+            }
+
+        else:
+            with torch.inference_mode():
+                img_tensor = img_tensor.to(self.device)
+                preds = self.model(img_tensor)
+
         results = self.postprocess(preds, (ori_h, ori_w))

         vis = None
yomitoku/models/layers/rtdetr_backbone.py
CHANGED
@@ -59,9 +59,7 @@ class ConvNormLayer(nn.Module):
 class BasicBlock(nn.Module):
     expansion = 1

-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()

         self.shortcut = shortcut
@@ -100,9 +98,7 @@ class BasicBlock(nn.Module):
 class BottleNeck(nn.Module):
     expansion = 4

-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()

         if variant == "a":
@@ -125,17 +121,13 @@ class BottleNeck(nn.Module):
                             ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                             (
                                 "conv",
-                                ConvNormLayer(
-                                    ch_in, ch_out * self.expansion, 1, 1
-                                ),
+                                ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1),
                             ),
                         ]
                     )
                 )
             else:
-                self.short = ConvNormLayer(
-                    ch_in, ch_out * self.expansion, 1, stride
-                )
+                self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)

         self.act = nn.Identity() if act is None else get_activation(act)

@@ -156,9 +148,7 @@ class BottleNeck(nn.Module):


 class Blocks(nn.Module):
-    def __init__(
-        self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"
-    ):
+    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
         super().__init__()

         self.blocks = nn.ModuleList()
yomitoku/models/layers/rtdetr_hybrid_encoder.py
CHANGED
@@ -252,9 +252,7 @@ class HybridEncoder(nn.Module):
         for in_channel in in_channels:
             if version == "v1":
                 proj = nn.Sequential(
-                    nn.Conv2d(
-                        in_channel, hidden_dim, kernel_size=1, bias=False
-                    ),
+                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                     nn.BatchNorm2d(hidden_dim),
                 )
             elif version == "v2":
@@ -290,9 +288,7 @@ class HybridEncoder(nn.Module):

         self.encoder = nn.ModuleList(
             [
-                TransformerEncoder(
-                    copy.deepcopy(encoder_layer), num_encoder_layers
-                )
+                TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
                 for _ in range(len(use_encoder_idx))
             ]
         )
@@ -347,9 +343,7 @@ class HybridEncoder(nn.Module):
            # self.register_buffer(f'pos_embed{idx}', pos_embed)

    @staticmethod
-    def build_2d_sincos_position_embedding(
-        w, h, embed_dim=256, temperature=10000.0
-    ):
+    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
        """ """
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
@@ -387,9 +381,7 @@ class HybridEncoder(nn.Module):
                    src_flatten.device
                )

-                memory: torch.Tensor = self.encoder[i](
-                    src_flatten, pos_embed=pos_embed
-                )
+                memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
                proj_feats[enc_ind] = (
                    memory.permute(0, 2, 1)
                    .reshape(-1, self.hidden_dim, h, w)
@@ -401,13 +393,9 @@ class HybridEncoder(nn.Module):
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_heigh = inner_outs[0]
            feat_low = proj_feats[idx - 1]
-            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
-                feat_heigh
-            )
+            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
            inner_outs[0] = feat_heigh
-            upsample_feat = F.interpolate(
-                feat_heigh, scale_factor=2.0, mode="nearest"
-            )
+            upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
            inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                torch.concat([upsample_feat, feat_low], dim=1)
            )
yomitoku/models/layers/rtdetrv2_decoder.py
CHANGED
@@ -40,9 +40,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:


 class MLP(nn.Module):
-    def __init__(
-        self, input_dim, hidden_dim, output_dim, num_layers, act="relu"
-    ):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"):
         super().__init__()
         self.num_layers = num_layers
         h = [hidden_dim] * (num_layers - 1)
@@ -193,9 +191,7 @@ class MSDeformableAttention(nn.Module):
         elif reference_points.shape[-1] == 4:
             # reference_points [8, 480, None, 1, 4]
             # sampling_offsets [8, 480, 8, 12, 2]
-            num_points_scale = self.num_points_scale.to(
-                dtype=query.dtype
-            ).unsqueeze(-1)
+            num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1)
             offset = (
                 sampling_offsets
                 * num_points_scale
@@ -330,9 +326,7 @@ def deformable_attention_core_func_v2(
     _, Len_q, _, _, _ = sampling_locations.shape

     split_shape = [h * w for h, w in value_spatial_shapes]
-    value_list = (
-        value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
-    )
+    value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)

     # sampling_offsets [8, 480, 8, 12, 2]
     if method == "default":
@@ -361,8 +355,7 @@ def deformable_attention_core_func_v2(
     elif method == "discrete":
         # n * m, seq, n, 2
         sampling_coord = (
-            sampling_grid_l * torch.tensor([[w, h]], device=value.device)
-            + 0.5
+            sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5
         ).to(torch.int64)

         # FIX ME? for rectangle input
@@ -389,9 +382,7 @@ def deformable_attention_core_func_v2(
     attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
         bs * n_head, 1, Len_q, sum(num_points_list)
     )
-    weighted_sample_locs = (
-        torch.concat(sampling_value_list, dim=-1) * attn_weights
-    )
+    weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights
     output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)

     return output.permute(0, 2, 1)
@@ -606,9 +597,7 @@ class RTDETRTransformerv2(nn.Module):
                 [
                     (
                         "conv",
-                        nn.Conv2d(
-                            in_channels, self.hidden_dim, 1, bias=False
-                        ),
+                        nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False),
                     ),
                     (
                         "norm",
@@ -689,13 +678,9 @@ class RTDETRTransformerv2(nn.Module):
                 torch.arange(h), torch.arange(w), indexing="ij"
             )
             grid_xy = torch.stack([grid_x, grid_y], dim=-1)
-            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor(
-                [w, h], dtype=dtype
-            )
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype)
             wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
-            lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(
-                -1, h * w, 4
-            )
+            lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4)
             anchors.append(lvl_anchors)

         anchors = torch.concat(anchors, dim=1).to(device)
@@ -729,22 +714,18 @@ class RTDETRTransformerv2(nn.Module):
         )

         enc_topk_bboxes_list, enc_topk_logits_list = [], []
-        enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = (
-
-
-
-
-            self.num_queries,
-        )
+        enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk(
+            output_memory,
+            enc_outputs_logits,
+            enc_outputs_coord_unact,
+            self.num_queries,
         )

         # if self.num_select_queries != self.num_queries:
         #     raise NotImplementedError('')

         if self.learn_query_content:
-            content = self.tgt_embed.weight.unsqueeze(0).tile(
-                [memory.shape[0], 1, 1]
-            )
+            content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1])
         else:
             content = enc_topk_memory.detach()

@@ -771,9 +752,7 @@ class RTDETRTransformerv2(nn.Module):
         topk: int,
     ):
         if self.query_select_method == "default":
-            _, topk_ind = torch.topk(
-                outputs_logits.max(-1).values, topk, dim=-1
-            )
+            _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1)

         elif self.query_select_method == "one2many":
             _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1)
@@ -786,16 +765,12 @@ class RTDETRTransformerv2(nn.Module):

         topk_coords = outputs_coords_unact.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(
-                1, 1, outputs_coords_unact.shape[-1]
-            ),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]),
         )

         topk_logits = outputs_logits.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(
-                1, 1, outputs_logits.shape[-1]
-            ),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]),
         )

         topk_memory = memory.gather(
yomitoku/models/parseq.py
CHANGED
@@ -22,7 +22,6 @@ from huggingface_hub import PyTorchModelHubMixin
 from timm.models.helpers import named_apply
 from torch import Tensor

-from ..postprocessor import ParseqTokenizer as Tokenizer
 from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding


@@ -123,7 +122,6 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):

     def forward(
         self,
-        tokenizer: Tokenizer,
         images: Tensor,
         max_length: Optional[int] = None,
     ) -> Tensor:
@@ -150,11 +148,11 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
         if self.decode_ar:
             tgt_in = torch.full(
                 (bs, num_steps),
-                tokenizer.pad_id,
+                self.tokenizer.pad_id,
                 dtype=torch.long,
                 device=self._device,
             )
-            tgt_in[:, 0] = tokenizer.bos_id
+            tgt_in[:, 0] = self.tokenizer.bos_id

             logits = []
             for i in range(num_steps):
@@ -177,7 +175,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                     # greedy decode. add the next token index to the target input
                     tgt_in[:, j] = p_i.squeeze().argmax(-1)
                     # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
-                    if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
+                    if testing and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all():
                         break

             logits = torch.cat(logits, dim=1)
@@ -185,7 +183,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # No prior context, so input is just <bos>. We query all positions.
             tgt_in = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -200,7 +198,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 torch.ones(
                     num_steps,
                     num_steps,
-                    dtype=torch.
+                    dtype=torch.int64,
                     device=self._device,
                 ),
                 2,
@@ -208,7 +206,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             ] = 0
             bos = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -216,7 +214,9 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # Prior context is the previous output.
             tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
             # Mask tokens beyond the first EOS token.
-            tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(
+            tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
+                -1
+            ) > 0
             tgt_out = self.decode(
                 tgt_in,
                 memory,
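The forward() change above drops the explicit tokenizer argument in favor of the model's own self.tokenizer. A minimal sketch of the call-site difference, assuming model is an already-constructed PARSeq carrying its tokenizer and images is a preprocessed batch tensor (both names are illustrative):

import torch

def run_parseq(model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    # 0.5.3 call sites passed the tokenizer explicitly: model(tokenizer, images)
    # 0.7.0 reads model.tokenizer internally, so only the image batch is passed.
    return model(images)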
yomitoku/ocr.py
CHANGED
@@ -16,16 +16,37 @@ class WordPrediction(BaseSchema):
     )
     content: str
     direction: str
-    det_score: float
     rec_score: float
+    det_score: float


 class OCRSchema(BaseSchema):
     words: List[WordPrediction]


+def ocr_aggregate(det_outputs, rec_outputs):
+    words = []
+    for points, det_score, pred, rec_score, direction in zip(
+        det_outputs.points,
+        det_outputs.scores,
+        rec_outputs.contents,
+        rec_outputs.scores,
+        rec_outputs.directions,
+    ):
+        words.append(
+            {
+                "points": points,
+                "content": pred,
+                "direction": direction,
+                "det_score": det_score,
+                "rec_score": rec_score,
+            }
+        )
+    return words
+
+
 class OCR:
-    def __init__(self, configs=
+    def __init__(self, configs={}, device="cuda", visualize=False):
         text_detector_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -36,10 +57,6 @@ class OCR:
         }

         if isinstance(configs, dict):
-            assert (
-                "text_detector" in configs or "text_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "text_detector" in configs:
                 text_detector_kwargs.update(configs["text_detector"])
             if "text_recognizer" in configs:
@@ -52,26 +69,6 @@ class OCR:
         self.detector = TextDetector(**text_detector_kwargs)
         self.recognizer = TextRecognizer(**text_recognizer_kwargs)

-    def aggregate(self, det_outputs, rec_outputs):
-        words = []
-        for points, det_score, pred, rec_score, direction in zip(
-            det_outputs.points,
-            det_outputs.scores,
-            rec_outputs.contents,
-            rec_outputs.scores,
-            rec_outputs.directions,
-        ):
-            words.append(
-                {
-                    "points": points,
-                    "content": pred,
-                    "direction": direction,
-                    "det_score": det_score,
-                    "rec_score": rec_score,
-                }
-            )
-        return words
-
     def __call__(self, img):
         """_summary_

@@ -82,6 +79,6 @@ class OCR:
         det_outputs, vis = self.detector(img)
         rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)

-        outputs = {"words":
+        outputs = {"words": ocr_aggregate(det_outputs, rec_outputs)}
         results = OCRSchema(**outputs)
         return results, vis
yomitoku/onnx/.gitkeep
ADDED
File without changes
yomitoku/postprocessor/rtdetr_postprocessor.py
CHANGED
@@ -54,16 +54,12 @@ class RTDETRPostProcessor(nn.Module):
         logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
         # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)

-        bbox_pred = torchvision.ops.box_convert(
-            boxes, in_fmt="cxcywh", out_fmt="xyxy"
-        )
+        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
         bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

         if self.use_focal_loss:
             scores = F.sigmoid(logits)
-            scores, index = torch.topk(
-                scores.flatten(1), self.num_top_queries, dim=-1
-            )
+            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
             # TODO for older tensorrt
             # labels = index % self.num_classes
             labels = mod(index, self.num_classes)
@@ -77,9 +73,7 @@ class RTDETRPostProcessor(nn.Module):
             scores = F.softmax(logits)[:, :, :-1]
             scores, labels = scores.max(dim=-1)
             if scores.shape[1] > self.num_top_queries:
-                scores, index = torch.topk(
-                    scores, self.num_top_queries, dim=-1
-                )
+                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                 labels = torch.gather(labels, dim=1, index=index)
                 boxes = torch.gather(
                     boxes,
@@ -97,10 +91,7 @@ class RTDETRPostProcessor(nn.Module):

         labels = (
             torch.tensor(
-                [
-                    mscoco_label2category[int(x.item())]
-                    for x in labels.flatten()
-                ]
+                [mscoco_label2category[int(x.item())] for x in labels.flatten()]
             )
             .to(boxes.device)
             .reshape(labels.shape)