magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.2.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,43 @@
|
|
1
|
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
__all__ = ["build_head"]
|
16
|
+
|
17
|
+
|
18
|
+
def build_head(config, **kwargs):
    """Instantiate a head module from a config dict.

    Args:
        config (dict): Head configuration. Must contain a ``"name"`` key naming
            one of the supported head classes. ``"name"`` and ``"char_num"``
            are popped from ``config`` (the caller's dict is modified in place);
            every remaining key is passed to the head's constructor.
        **kwargs: Extra keyword arguments forwarded to the head constructor.

    Returns:
        nn.Module: The constructed head instance.

    Raises:
        KeyError: If ``config`` has no ``"name"`` key.
        ValueError: If ``config["name"]`` is not a supported head.
    """
    # det head
    from .det_db_head import DBHead, PFHeadLocal

    # rec head
    from .rec_ctc_head import CTCHead
    from .rec_multi_head import MultiHead

    # cls head
    from .cls_head import ClsHead

    # Explicit registry instead of eval(): only these classes can ever be
    # constructed, and the check below cannot be stripped by `python -O`
    # the way an `assert` can.
    support_dict = {
        "DBHead": DBHead,
        "CTCHead": CTCHead,
        "ClsHead": ClsHead,
        "MultiHead": MultiHead,
        "PFHeadLocal": PFHeadLocal,
    }

    module_name = config.pop("name")
    # Popped only so that it is not forwarded to the head via **config; the
    # value itself is not used here.
    config.pop("char_num", 6625)
    if module_name not in support_dict:
        raise ValueError("head only support {}".format(list(support_dict)))
    module_class = support_dict[module_name](**config, **kwargs)
    return module_class
|
@@ -0,0 +1,23 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn.functional as F
|
3
|
+
from torch import nn
|
4
|
+
|
5
|
+
|
6
|
+
class ClsHead(nn.Module):
    """Text-orientation classification head.

    Global-average-pools the incoming feature map down to one value per
    channel, maps that channel vector to ``class_dim`` logits with a linear
    layer, and returns softmax probabilities over the classes.

    Args:
        in_channels (int): Channel count of the incoming feature map.
        class_dim (int): Number of output classes.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self, in_channels, class_dim, **kwargs):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, class_dim, bias=True)

    def forward(self, x):
        # assumes x is (N, C, H, W) with C == in_channels — TODO confirm
        pooled = self.pool(x)
        batch, channels = pooled.shape[0], pooled.shape[1]
        flat = torch.reshape(pooled, shape=[batch, channels])
        logits = self.fc(flat)
        return F.softmax(logits, dim=1)
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from ..common import Activation
|
5
|
+
from ..backbones.det_mobilenet_v3 import ConvBNLayer
|
6
|
+
|
7
|
+
class Head(nn.Module):
    """Single DB prediction branch.

    The layer sequence is:
    1. 3x3 conv (``in_channels`` -> ``in_channels // 4``), BatchNorm, ``Activation('relu')``.
    2. Stride-2 2x2 transposed conv (doubles H and W), BatchNorm, ``Activation('relu')``.
    3. A second stride-2 2x2 transposed conv down to a single channel, then a sigmoid.

    Args:
        in_channels (int): Channel count of the incoming feature map.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        mid = in_channels // 4
        # Registration order and attribute names are kept stable so existing
        # state_dict keys still match.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.conv_bn1 = nn.BatchNorm2d(mid)
        self.relu1 = Activation(act_type='relu')

        self.conv2 = nn.ConvTranspose2d(
            in_channels=mid,
            out_channels=mid,
            kernel_size=2,
            stride=2,
        )
        self.conv_bn2 = nn.BatchNorm2d(mid)
        self.relu2 = Activation(act_type='relu')

        self.conv3 = nn.ConvTranspose2d(
            in_channels=mid,
            out_channels=1,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x, return_f=False):
        """Return the sigmoid map, plus the pre-``conv3`` features when
        ``return_f is True`` (identity check, so truthy non-True values
        return only the map)."""
        feat = x
        for layer in (self.conv1, self.conv_bn1, self.relu1,
                      self.conv2, self.conv_bn2, self.relu2):
            feat = layer(feat)
        prob_map = torch.sigmoid(self.conv3(feat))
        if return_f is True:
            return prob_map, feat
        return prob_map
|
49
|
+
|
50
|
+
|
51
|
+
class DBHead(nn.Module):
    """Differentiable Binarization (DB) head for text detection.

    See https://arxiv.org/abs/1911.08947.

    Args:
        in_channels (int): Channel count of the incoming feature map.
        k (int): Amplification factor used by :meth:`step_function`.
        **kwargs: Forwarded to both :class:`Head` branches.
    """

    def __init__(self, in_channels, k=50, **kwargs):
        super(DBHead, self).__init__()
        self.k = k
        # Shrink/probability-map branch: the only branch used by forward().
        self.binarize = Head(in_channels, **kwargs)
        # Threshold-map branch. Not used by forward(); presumably kept so that
        # pretrained state_dicts with "thresh.*" keys still load — confirm
        # before removing.
        self.thresh = Head(in_channels, **kwargs)

    def step_function(self, x, y):
        """Approximate binarization: ``1 / (1 + exp(-k * (x - y)))``."""
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))

    def forward(self, x):
        """Return ``{'maps': shrink_map}`` from the binarize branch only."""
        shrink_maps = self.binarize(x)
        return {'maps': shrink_maps}
|
79
|
+
|
80
|
+
|
81
|
+
class LocalModule(nn.Module):
    """Local refinement block used by :class:`PFHeadLocal`.

    Concatenates a one-channel initial map in front of the feature tensor
    along dim 1, then applies ``ConvBNLayer`` (3x3, ReLU) followed by a
    1x1 conv down to one output channel.

    Args:
        in_c (int): Channel count of the feature tensor ``x``; the extra
            input channel for ``init_map`` is added internally.
        mid_c (int): Hidden channel count between the two convs.
        use_distance (bool): Accepted for config compatibility; unused.
    """

    def __init__(self, in_c, mid_c, use_distance=True):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as LocalModule is subclassed.
        super().__init__()
        self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
        self.last_1 = nn.Conv2d(mid_c, 1, 1, 1, 0)

    def forward(self, x, init_map, distance_map):
        """Return the refined one-channel map.

        ``distance_map`` is accepted for interface compatibility but unused.
        """
        outf = torch.cat([init_map, x], dim=1)
        # last Conv
        out = self.last_1(self.last_3(outf))
        return out
|
92
|
+
|
93
|
+
class PFHeadLocal(DBHead):
    """DB head with an extra local refinement branch.

    The binarize branch's shrink map is refined by a :class:`LocalModule`
    fed with the 2x nearest-upsampled intermediate features of that branch.
    The output map is the mean of the base and refined maps.

    Args:
        in_channels (int): Channel count of the incoming feature map.
        k (int): Forwarded to :class:`DBHead`.
        mode (str): ``'large'`` uses ``in_channels // 4`` hidden channels in
            the refinement block; ``'small'`` uses ``in_channels // 8``.
        **kwargs: Forwarded to :class:`DBHead`.

    Raises:
        ValueError: If ``mode`` is not ``'large'`` or ``'small'``.
    """

    def __init__(self, in_channels, k=50, mode='small', **kwargs):
        super(PFHeadLocal, self).__init__(in_channels, k, **kwargs)
        self.mode = mode

        self.up_conv = nn.Upsample(scale_factor=2, mode="nearest")
        if self.mode == 'large':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
        elif self.mode == 'small':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)
        else:
            # Previously cbn_layer was silently left undefined and forward()
            # failed later with an AttributeError; fail fast instead.
            raise ValueError(
                "PFHeadLocal mode must be 'large' or 'small', got {!r}".format(mode)
            )

    def forward(self, x, targets=None):
        """Return ``{'maps': averaged_map, 'cbn_maps': refined_map}``.

        ``targets`` is accepted for interface compatibility but unused.
        """
        shrink_maps, f = self.binarize(x, return_f=True)
        base_maps = shrink_maps
        cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
        # torch.sigmoid: torch.nn.functional.sigmoid is deprecated.
        cbn_maps = torch.sigmoid(cbn_maps)
        return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import torch.nn.functional as F
|
2
|
+
from torch import nn
|
3
|
+
|
4
|
+
|
5
|
+
class CTCHead(nn.Module):
    """Linear CTC recognition head.

    Projects features to ``out_channels`` class logits. The projection is
    either a single linear layer or, when ``mid_channels`` is given, two
    stacked linear layers.

    Args:
        in_channels (int): Size of the last dimension of the input.
        out_channels (int): Number of output classes.
        fc_decay (float): Accepted for config compatibility; unused here.
        mid_channels (int | None): Hidden size of the optional two-layer
            projection.
        return_feats (bool): In training mode, also return the features that
            were fed to the final projection, as ``(feats, logits)``.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels=6625,
        fc_decay=0.0004,
        mid_channels=None,
        return_feats=False,
        **kwargs
    ):
        super().__init__()
        if mid_channels is not None:
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)
        else:
            self.fc = nn.Linear(in_channels, out_channels, bias=True)

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, labels=None):
        if self.mid_channels is not None:
            x = self.fc1(x)
            logits = self.fc2(x)
        else:
            logits = self.fc(x)

        if not self.training:
            # Inference always returns only the probabilities over dim 2,
            # even when return_feats is set.
            return F.softmax(logits, dim=2)
        if self.return_feats:
            return (x, logits)
        return logits
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from torch import nn
|
2
|
+
|
3
|
+
from ..necks.rnn import Im2Seq, SequenceEncoder
|
4
|
+
from .rec_ctc_head import CTCHead
|
5
|
+
|
6
|
+
|
7
|
+
class FCTranspose(nn.Module):
    """Swap dims 1 and 2 of a 3-D tensor, optionally followed by a bias-free
    linear projection from ``in_channels`` to ``out_channels``.

    Args:
        in_channels (int): Input size of the projection.
        out_channels (int): Output size of the projection.
        only_transpose (bool): If true, skip the projection entirely (no
            ``fc`` submodule is created).
    """

    def __init__(self, in_channels, out_channels, only_transpose=False):
        super().__init__()
        self.only_transpose = only_transpose
        if only_transpose:
            return
        self.fc = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        swapped = x.permute([0, 2, 1])
        return swapped if self.only_transpose else self.fc(swapped)
|
19
|
+
|
20
|
+
|
21
|
+
class MultiHead(nn.Module):
    """Multi-branch recognition head that only builds and runs its CTC branch.

    Args:
        in_channels (int): Channel count of the incoming feature map.
        out_channels_list (dict): Output sizes per decoder; only the
            ``"CTCLabelDecode"`` entry is read.
        **kwargs: Must contain ``head_list``. Presumably this is a list of
            single-key dicts mapping a head name to its config, since each
            entry is read via ``list(entry)[0]`` and ``entry[name]`` — verify
            against the arch config.

    Raises:
        NotImplementedError: For any head name other than ``SARHead``,
            ``NRTRHead`` (both accepted and ignored) or ``CTCHead``.
    """

    def __init__(self, in_channels, out_channels_list, **kwargs):
        super().__init__()
        self.head_list = kwargs.pop("head_list")

        self.gtc_head = "sar"
        assert len(self.head_list) >= 2
        for head_name in self.head_list:
            name = list(head_name)[0]
            if name == "SARHead":
                pass

            elif name == "NRTRHead":
                pass
            elif name == "CTCHead":
                head_cfg = head_name[name]
                # ctc neck
                self.encoder_reshape = Im2Seq(in_channels)
                # Copy before popping: popping "name" from the caller's dict
                # made a second MultiHead built from the same config fail
                # with KeyError('name').
                neck_args = dict(head_cfg["Neck"])
                encoder_type = neck_args.pop("name")
                self.ctc_encoder = SequenceEncoder(
                    in_channels=in_channels, encoder_type=encoder_type, **neck_args
                )
                # ctc head
                head_args = head_cfg.get("Head", {})
                if head_args is None:
                    head_args = {}

                self.ctc_head = CTCHead(
                    in_channels=self.ctc_encoder.out_channels,
                    out_channels=out_channels_list["CTCLabelDecode"],
                    **head_args,
                )
            else:
                raise NotImplementedError(f"{name} is not supported in MultiHead yet")

    def forward(self, x, data=None):
        """Run only the CTC branch. ``data`` is accepted but unused."""
        ctc_encoder = self.ctc_encoder(x)
        return self.ctc_head(ctc_encoder)
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
__all__ = ["build_neck"]
|
16
|
+
|
17
|
+
|
18
|
+
def build_neck(config):
    """Instantiate a neck module from a config dict.

    Args:
        config (dict): Neck configuration. Must contain a ``"name"`` key naming
            one of the supported neck classes. ``"name"`` is popped (the
            caller's dict is modified in place); every remaining key is passed
            to the neck's constructor.

    Returns:
        nn.Module: The constructed neck instance.

    Raises:
        KeyError: If ``config`` has no ``"name"`` key.
        ValueError: If ``config["name"]`` is not a supported neck.
    """
    from .db_fpn import DBFPN, LKPAN, RSEFPN
    from .rnn import SequenceEncoder

    # Explicit registry instead of eval(): only these classes can ever be
    # constructed, and the check below cannot be stripped by `python -O`
    # the way an `assert` can.
    support_dict = {
        "DBFPN": DBFPN,
        "SequenceEncoder": SequenceEncoder,
        "RSEFPN": RSEFPN,
        "LKPAN": LKPAN,
    }

    module_name = config.pop("name")
    if module_name not in support_dict:
        raise ValueError("neck only support {}".format(list(support_dict)))
    module_class = support_dict[module_name](**config)
    return module_class
|