yomitoku 0.4.1__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/base.py +1 -1
- yomitoku/cli/main.py +219 -27
- yomitoku/configs/__init__.py +2 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
- yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
- yomitoku/data/functions.py +48 -23
- yomitoku/document_analyzer.py +243 -41
- yomitoku/export/__init__.py +18 -5
- yomitoku/export/export_csv.py +71 -2
- yomitoku/export/export_html.py +46 -12
- yomitoku/export/export_json.py +66 -3
- yomitoku/export/export_markdown.py +42 -6
- yomitoku/layout_analyzer.py +2 -9
- yomitoku/layout_parser.py +58 -4
- yomitoku/models/dbnet_plus.py +13 -39
- yomitoku/models/layers/activate.py +13 -0
- yomitoku/models/layers/rtdetr_backbone.py +18 -17
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
- yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
- yomitoku/models/parseq.py +15 -22
- yomitoku/ocr.py +24 -27
- yomitoku/onnx/.gitkeep +0 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
- yomitoku/postprocessor/parseq_tokenizer.py +1 -3
- yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
- yomitoku/table_structure_recognizer.py +82 -9
- yomitoku/text_detector.py +57 -7
- yomitoku/text_recognizer.py +84 -16
- yomitoku/utils/misc.py +21 -14
- yomitoku/utils/visualizer.py +15 -8
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
- yomitoku-0.7.4.dist-info/RECORD +54 -0
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
- yomitoku-0.4.1.dist-info/RECORD +0 -52
- {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,11 @@
+import os
 import re
+
 import cv2
-import os
 
 
 def escape_markdown_special_chars(text):
-    special_chars = r"([`*
+    special_chars = r"([`*{}[\]()#+!~|-])"
     return re.sub(special_chars, r"\\\1", text)
 
 
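The only functional change in this hunk is the character class used for escaping. A quick standalone check of the new pattern (the sample string is ours, not from the package):

```python
import re

# New character class from escape_markdown_special_chars()
special_chars = r"([`*{}[\]()#+!~|-])"
print(re.sub(special_chars, r"\\\1", "a|b #1 (see [note])"))
# a\|b \#1 \(see \[note\]\)
```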
@@ -75,6 +76,8 @@ def figure_to_md(
     width=200,
     figure_dir="figures",
 ):
+    assert img is not None, "img is required for saving figures"
+
     elements = []
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
@@ -108,11 +111,11 @@ def figure_to_md(
     return elements
 
 
-def export_markdown(
+def convert_markdown(
     inputs,
-    out_path
+    out_path,
+    ignore_line_break=False,
     img=None,
-    ignore_line_break: bool = False,
     export_figure_letter=False,
     export_figure=True,
     figure_width=200,
@@ -140,6 +143,39 @@ def export_markdown(
 
     elements = sorted(elements, key=lambda x: x["order"])
     markdown = "\n".join([element["md"] for element in elements])
+    return markdown, elements
 
-
+
+
+def export_markdown(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    img=None,
+    export_figure_letter=False,
+    export_figure=True,
+    figure_width=200,
+    figure_dir="figures",
+    encoding: str = "utf-8",
+):
+    markdown, elements = convert_markdown(
+        inputs,
+        out_path,
+        ignore_line_break,
+        img,
+        export_figure_letter,
+        export_figure,
+        figure_width,
+        figure_dir,
+    )
+
+    save_markdown(markdown, out_path, encoding)
+    return markdown
+
+
+def save_markdown(
+    markdown,
+    out_path,
+    encoding,
+):
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(markdown)
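Net effect of this file: string building and file I/O are now separate steps. convert_markdown() returns the rendered markdown plus the ordered elements, save_markdown() owns the open() call with an explicit encoding and errors="ignore", and export_markdown() becomes a thin wrapper. A reduced sketch of that shape (stand-in converter and sample data are ours; the real convert_markdown() renders DocumentAnalyzer output):

```python
def convert_markdown(inputs):
    # Stand-in for the real renderer: order the elements, join their markdown.
    elements = sorted(inputs, key=lambda x: x["order"])
    return "\n".join(e["md"] for e in elements), elements


def save_markdown(markdown, out_path, encoding):
    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
        f.write(markdown)


def export_markdown(inputs, out_path, encoding="utf-8"):
    markdown, _ = convert_markdown(inputs)
    save_markdown(markdown, out_path, encoding)
    return markdown


print(export_markdown(
    [{"order": 2, "md": "body"}, {"order": 1, "md": "# title"}],
    "out.md",
    encoding="cp932",
))
```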
yomitoku/layout_analyzer.py
CHANGED
@@ -15,7 +15,7 @@ class LayoutAnalyzerSchema(BaseSchema):
 
 
 class LayoutAnalyzer:
-    def __init__(self, configs=
+    def __init__(self, configs={}, device="cuda", visualize=False):
         layout_parser_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -26,11 +26,6 @@ class LayoutAnalyzer:
         }
 
         if isinstance(configs, dict):
-            assert (
-                "layout_parser" in configs
-                or "table_structure_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "layout_parser" in configs:
                 layout_parser_kwargs.update(configs["layout_parser"])
 
@@ -53,9 +48,7 @@ class LayoutAnalyzer:
     def __call__(self, img):
         layout_results, vis = self.layout_parser(img)
         table_boxes = [table.box for table in layout_results.tables]
-        table_results, vis = self.table_structure_recognizer(
-            img, table_boxes, vis=vis
-        )
+        table_results, vis = self.table_structure_recognizer(img, table_boxes, vis=vis)
 
         results = LayoutAnalyzerSchema(
             paragraphs=layout_results.paragraphs,
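With the assert removed, a configs dict with unrecognised keys no longer raises; known sections are merged into the per-module kwargs and everything else is silently ignored. A toy illustration of that merge (helper name and sample dict are ours):

```python
def build_layout_parser_kwargs(configs, device="cuda", visualize=False):
    # Mirrors the merge in LayoutAnalyzer.__init__ for the "layout_parser" section.
    kwargs = {"device": device, "visualize": visualize}
    if isinstance(configs, dict) and "layout_parser" in configs:
        kwargs.update(configs["layout_parser"])
    return kwargs


print(build_layout_parser_kwargs(
    {"layout_parser": {"device": "cpu"}, "tabel_structure_recognizer": {}}
))
# {'device': 'cpu', 'visualize': False}  -- the misspelled key is ignored, not rejected
```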
yomitoku/layout_parser.py
CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union
 
 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist
 
+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import LayoutParserRTDETRv2Config
 from .models import RTDETRv2
@@ -91,6 +96,7 @@ class LayoutParser(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(model_name, path_cfg, from_pretrained)
@@ -98,7 +104,6 @@ class LayoutParser(BaseModule):
         self.visualize = visualize
 
         self.model.eval()
-        self.model.to(self.device)
 
         self.postprocessor = RTDETRPostProcessor(
             num_classes=self._cfg.RTDETRTransformerv2.num_classes,
@@ -119,11 +124,49 @@ class LayoutParser(BaseModule):
         }
 
         self.role = self._cfg.role
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
 
     def preprocess(self, img):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img = Image.fromarray(cv_img)
-        img_tensor = self.transforms(img)[None]
+        img_tensor = self.transforms(img)[None]
         return img_tensor
 
     def postprocess(self, preds, image_size):
@@ -175,8 +218,19 @@ class LayoutParser(BaseModule):
         ori_h, ori_w = img.shape[:2]
         img_tensor = self.preprocess(img)
 
-
-
+        if self.infer_onnx:
+            input = img_tensor.numpy()
+            results = self.sess.run(None, {"input": input})
+            preds = {
+                "pred_logits": torch.tensor(results[0]).to(self.device),
+                "pred_boxes": torch.tensor(results[1]).to(self.device),
+            }
+
+        else:
+            with torch.inference_mode():
+                img_tensor = img_tensor.to(self.device)
+                preds = self.model(img_tensor)
+
         results = self.postprocess(preds, (ori_h, ori_w))
 
         vis = None
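The new infer_onnx path exports the RT-DETR weights once with torch.onnx.export and then runs them through onnxruntime instead of the PyTorch module. A self-contained sketch of that export-then-infer pattern with a toy model (file name and model are ours, not the yomitoku weights, which are additionally loaded via onnx.load before the session is created):

```python
import torch
import onnxruntime

model = torch.nn.Conv2d(3, 8, 3, padding=1).eval()
dummy_input = torch.randn(1, 3, 64, 64)

torch.onnx.export(
    model,
    dummy_input,
    "toy.onnx",
    opset_version=16,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

sess = onnxruntime.InferenceSession("toy.onnx", providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {"input": dummy_input.numpy()})
print(out.shape)  # (1, 8, 64, 64)
```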
yomitoku/models/dbnet_plus.py
CHANGED
@@ -20,9 +20,7 @@ class BackboneBase(nn.Module):
             "layer4": "layer4",
         }
 
-        self.body = IntermediateLayerGetter(
-            backbone, return_layers=return_layers
-        )
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
 
     def forward(self, tensor):
         xs = self.body(tensor)
@@ -57,18 +55,10 @@ class DBNetDecoder(nn.Module):
         self.training = True
         self.input_proj = nn.ModuleDict(
             {
-                "layer1": nn.Conv2d(
-                    in_channels[0], self.d_model, 1, bias=False
-                ),
-                "layer2": nn.Conv2d(
-                    in_channels[1], self.d_model, 1, bias=False
-                ),
-                "layer3": nn.Conv2d(
-                    in_channels[2], self.d_model, 1, bias=False
-                ),
-                "layer4": nn.Conv2d(
-                    in_channels[3], self.d_model, 1, bias=False
-                ),
+                "layer1": nn.Conv2d(in_channels[0], self.d_model, 1, bias=False),
+                "layer2": nn.Conv2d(in_channels[1], self.d_model, 1, bias=False),
+                "layer3": nn.Conv2d(in_channels[2], self.d_model, 1, bias=False),
+                "layer4": nn.Conv2d(in_channels[3], self.d_model, 1, bias=False),
             }
         )
 
@@ -89,9 +79,7 @@ class DBNetDecoder(nn.Module):
                         padding=1,
                         bias=False,
                     ),
-                    nn.Upsample(
-                        scale_factor=2, mode="bilinear", align_corners=False
-                    ),
+                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                 ),
                 "layer3": nn.Sequential(
                     nn.Conv2d(
@@ -101,9 +89,7 @@ class DBNetDecoder(nn.Module):
                         padding=1,
                         bias=False,
                     ),
-                    nn.Upsample(
-                        scale_factor=4, mode="bilinear", align_corners=False
-                    ),
+                    nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False),
                 ),
                 "layer4": nn.Sequential(
                     nn.Conv2d(
@@ -113,17 +99,13 @@ class DBNetDecoder(nn.Module):
                         padding=1,
                         bias=False,
                     ),
-                    nn.Upsample(
-                        scale_factor=4, mode="bilinear", align_corners=False
-                    ),
+                    nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False),
                 ),
             }
         )
 
         self.binarize = nn.Sequential(
-            nn.Conv2d(
-                self.d_model, self.d_model // 4, 3, padding=1, bias=False
-            ),
+            nn.Conv2d(self.d_model, self.d_model // 4, 3, padding=1, bias=False),
             nn.BatchNorm2d(self.d_model // 4),
             nn.ReLU(inplace=True),
             nn.ConvTranspose2d(self.d_model // 4, self.d_model // 4, 2, 2),
@@ -166,16 +148,12 @@ class DBNetDecoder(nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.fill_(1e-4)
 
-    def _init_thresh(
-        self, inner_channels, serial=False, smooth=False, bias=False
-    ):
+    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
         in_channels = inner_channels
         if serial:
             in_channels += 1
         self.thresh = nn.Sequential(
-            nn.Conv2d(
-                in_channels, inner_channels // 4, 3, padding=1, bias=bias
-            ),
+            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
             nn.BatchNorm2d(inner_channels // 4),
             nn.ReLU(inplace=True),
             self._init_upsample(
@@ -186,16 +164,12 @@ class DBNetDecoder(nn.Module):
             ),
             nn.BatchNorm2d(inner_channels // 4),
             nn.ReLU(inplace=True),
-            self._init_upsample(
-                inner_channels // 4, 1, smooth=smooth, bias=bias
-            ),
+            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
             nn.Sigmoid(),
         )
         return self.thresh
 
-    def _init_upsample(
-        self, in_channels, out_channels, smooth=False, bias=False
-    ):
+    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        if smooth:
            inter_out_channels = out_channels
            if out_channels == 1:
yomitoku/models/layers/activate.py
CHANGED
@@ -1,3 +1,16 @@
+# Copyright(c) 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch.nn as nn
 
 
yomitoku/models/layers/rtdetr_backbone.py
CHANGED
@@ -1,5 +1,16 @@
-
-
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from collections import OrderedDict
 
@@ -48,9 +59,7 @@ class ConvNormLayer(nn.Module):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()
 
         self.shortcut = shortcut
@@ -89,9 +98,7 @@ class BasicBlock(nn.Module):
 class BottleNeck(nn.Module):
     expansion = 4
 
-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()
 
         if variant == "a":
@@ -114,17 +121,13 @@ class BottleNeck(nn.Module):
                         ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                         (
                             "conv",
-                            ConvNormLayer(
-                                ch_in, ch_out * self.expansion, 1, 1
-                            ),
+                            ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1),
                         ),
                     ]
                 )
             )
         else:
-            self.short = ConvNormLayer(
-                ch_in, ch_out * self.expansion, 1, stride
-            )
+            self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
 
         self.act = nn.Identity() if act is None else get_activation(act)
 
@@ -145,9 +148,7 @@ class BottleNeck(nn.Module):
 
 
 class Blocks(nn.Module):
-    def __init__(
-        self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"
-    ):
+    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
         super().__init__()
 
         self.blocks = nn.ModuleList()
yomitoku/models/layers/rtdetr_hybrid_encoder.py
CHANGED
@@ -1,5 +1,16 @@
-
-
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import copy
 from collections import OrderedDict
@@ -241,9 +252,7 @@ class HybridEncoder(nn.Module):
         for in_channel in in_channels:
             if version == "v1":
                 proj = nn.Sequential(
-                    nn.Conv2d(
-                        in_channel, hidden_dim, kernel_size=1, bias=False
-                    ),
+                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                     nn.BatchNorm2d(hidden_dim),
                 )
             elif version == "v2":
@@ -279,9 +288,7 @@ class HybridEncoder(nn.Module):
 
         self.encoder = nn.ModuleList(
             [
-                TransformerEncoder(
-                    copy.deepcopy(encoder_layer), num_encoder_layers
-                )
+                TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
                 for _ in range(len(use_encoder_idx))
             ]
         )
@@ -336,9 +343,7 @@ class HybridEncoder(nn.Module):
             # self.register_buffer(f'pos_embed{idx}', pos_embed)
 
     @staticmethod
-    def build_2d_sincos_position_embedding(
-        w, h, embed_dim=256, temperature=10000.0
-    ):
+    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
         """ """
         grid_w = torch.arange(int(w), dtype=torch.float32)
         grid_h = torch.arange(int(h), dtype=torch.float32)
@@ -376,9 +381,7 @@ class HybridEncoder(nn.Module):
                     src_flatten.device
                 )
 
-                memory: torch.Tensor = self.encoder[i](
-                    src_flatten, pos_embed=pos_embed
-                )
+                memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
                 proj_feats[enc_ind] = (
                     memory.permute(0, 2, 1)
                     .reshape(-1, self.hidden_dim, h, w)
@@ -390,13 +393,9 @@ class HybridEncoder(nn.Module):
         for idx in range(len(self.in_channels) - 1, 0, -1):
             feat_heigh = inner_outs[0]
             feat_low = proj_feats[idx - 1]
-            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
-                feat_heigh
-            )
+            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
             inner_outs[0] = feat_heigh
-            upsample_feat = F.interpolate(
-                feat_heigh, scale_factor=2.0, mode="nearest"
-            )
+            upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
             inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                 torch.concat([upsample_feat, feat_low], dim=1)
             )
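For reference, build_2d_sincos_position_embedding follows the usual RT-DETR/MAE-style 2D sin-cos construction; a standalone sketch of that construction (not copied line-for-line from the file):

```python
import torch


def sincos_2d(w, h, embed_dim=256, temperature=10000.0):
    grid_w = torch.arange(int(w), dtype=torch.float32)
    grid_h = torch.arange(int(h), dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    # shape: [1, w*h, embed_dim]
    return torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None]


print(sincos_2d(20, 20, 256).shape)  # torch.Size([1, 400, 256])
```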
@@ -1,4 +1,17 @@
-
+# Scene Text Recognition Model Hub
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import copy
 import functools
yomitoku/models/parseq.py
CHANGED
@@ -22,13 +22,10 @@ from huggingface_hub import PyTorchModelHubMixin
 from timm.models.helpers import named_apply
 from torch import Tensor
 
-from ..postprocessor import ParseqTokenizer as Tokenizer
 from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding
 
 
-def init_weights(
-    module: nn.Module, name: str = "", exclude: Sequence[str] = ()
-):
+def init_weights(module: nn.Module, name: str = "", exclude: Sequence[str] = ()):
     """Initialize the weights using the typical initialization schemes used in SOTA models."""
     if any(map(name.startswith, exclude)):
         return
@@ -41,9 +38,7 @@ def init_weights(
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.Conv2d):
-        nn.init.kaiming_normal_(
-            module.weight, mode="fan_out", nonlinearity="relu"
-        )
+        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
@@ -86,6 +81,8 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
         named_apply(partial(init_weights, exclude=["encoder"]), self)
         nn.init.trunc_normal_(self.pos_queries, std=0.02)
 
+        self.export_onnx = False
+
     @property
     def _device(self) -> torch.device:
         return next(self.head.parameters(recurse=False)).device
@@ -93,9 +90,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
     @torch.jit.ignore
     def no_weight_decay(self):
         param_names = {"text_embed.embedding.weight", "pos_queries"}
-        enc_param_names = {
-            "encoder." + n for n in self.encoder.no_weight_decay()
-        }
+        enc_param_names = {"encoder." + n for n in self.encoder.no_weight_decay()}
         return param_names.union(enc_param_names)
 
     def encode(self, img: torch.Tensor):
@@ -129,7 +124,6 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
 
     def forward(
         self,
-        tokenizer: Tokenizer,
         images: Tensor,
         max_length: Optional[int] = None,
     ) -> Tensor:
@@ -149,20 +143,18 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
 
         # Special case for the forward permutation. Faster than using `generate_attn_masks()`
         tgt_mask = query_mask = torch.triu(
-            torch.ones(
-                (num_steps, num_steps), dtype=torch.bool, device=self._device
-            ),
+            torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device),
             1,
         )
 
         if self.decode_ar:
             tgt_in = torch.full(
                 (bs, num_steps),
-                tokenizer.pad_id,
+                self.tokenizer.pad_id,
                 dtype=torch.long,
                 device=self._device,
             )
-            tgt_in[:, 0] = tokenizer.bos_id
+            tgt_in[:, 0] = self.tokenizer.bos_id
 
             logits = []
             for i in range(num_steps):
@@ -186,8 +178,9 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                     tgt_in[:, j] = p_i.squeeze().argmax(-1)
                     # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                     if (
-                        testing
-                        and (tgt_in == tokenizer.eos_id).any(dim=-1).all()
+                        not self.export_onnx
+                        and testing
+                        and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all()
                     ):
                         break
 
@@ -196,7 +189,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # No prior context, so input is just <bos>. We query all positions.
             tgt_in = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -211,7 +204,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 torch.ones(
                     num_steps,
                     num_steps,
-                    dtype=torch.
+                    dtype=torch.int64,
                     device=self._device,
                 ),
                 2,
@@ -219,7 +212,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             ] = 0
             bos = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -227,7 +220,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # Prior context is the previous output.
             tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
             # Mask tokens beyond the first EOS token.
-            tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(
+            tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
                 -1
             ) > 0
             tgt_out = self.decode(
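The new export_onnx flag guards the early exit of the autoregressive loop, presumably so that tracing for ONNX export always runs a fixed num_steps iterations instead of data-dependent control flow. A toy greedy-decode loop showing only that guard (ids and sizes are made up):

```python
import torch

eos_id, bs, num_steps, vocab = 0, 2, 6, 5
export_onnx = False  # set True to force a fixed number of iterations

tgt_in = torch.full((bs, num_steps), 1, dtype=torch.long)
for i in range(1, num_steps):
    tgt_in[:, i] = torch.randint(0, vocab, (bs,))
    # Every sequence has produced at least one EOS -> stop early,
    # unless we are tracing for ONNX export.
    if not export_onnx and (tgt_in == eos_id).any(dim=-1).all():
        break
print(i, tgt_in.tolist())
```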