pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,402 +0,0 @@
|
|
1
|
-
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is refer from:
|
16
|
-
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
|
17
|
-
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
|
18
|
-
"""
|
19
|
-
|
20
|
-
from __future__ import absolute_import
|
21
|
-
from __future__ import division
|
22
|
-
from __future__ import print_function
|
23
|
-
|
24
|
-
import math
|
25
|
-
import paddle
|
26
|
-
from paddle import ParamAttr
|
27
|
-
import paddle.nn as nn
|
28
|
-
import paddle.nn.functional as F
|
29
|
-
|
30
|
-
|
31
|
-
class SAREncoder(nn.Layer):
|
32
|
-
"""
|
33
|
-
Args:
|
34
|
-
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
35
|
-
enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
|
36
|
-
enc_gru (bool): If True, use GRU, else LSTM in encoder.
|
37
|
-
d_model (int): Dim of channels from backbone.
|
38
|
-
d_enc (int): Dim of encoder RNN layer.
|
39
|
-
mask (bool): If True, mask padding in RNN sequence.
|
40
|
-
"""
|
41
|
-
|
42
|
-
def __init__(self,
|
43
|
-
enc_bi_rnn=False,
|
44
|
-
enc_drop_rnn=0.1,
|
45
|
-
enc_gru=False,
|
46
|
-
d_model=512,
|
47
|
-
d_enc=512,
|
48
|
-
mask=True,
|
49
|
-
**kwargs):
|
50
|
-
super().__init__()
|
51
|
-
assert isinstance(enc_bi_rnn, bool)
|
52
|
-
assert isinstance(enc_drop_rnn, (int, float))
|
53
|
-
assert 0 <= enc_drop_rnn < 1.0
|
54
|
-
assert isinstance(enc_gru, bool)
|
55
|
-
assert isinstance(d_model, int)
|
56
|
-
assert isinstance(d_enc, int)
|
57
|
-
assert isinstance(mask, bool)
|
58
|
-
|
59
|
-
self.enc_bi_rnn = enc_bi_rnn
|
60
|
-
self.enc_drop_rnn = enc_drop_rnn
|
61
|
-
self.mask = mask
|
62
|
-
|
63
|
-
# LSTM Encoder
|
64
|
-
if enc_bi_rnn:
|
65
|
-
direction = 'bidirectional'
|
66
|
-
else:
|
67
|
-
direction = 'forward'
|
68
|
-
kwargs = dict(
|
69
|
-
input_size=d_model,
|
70
|
-
hidden_size=d_enc,
|
71
|
-
num_layers=2,
|
72
|
-
time_major=False,
|
73
|
-
dropout=enc_drop_rnn,
|
74
|
-
direction=direction)
|
75
|
-
if enc_gru:
|
76
|
-
self.rnn_encoder = nn.GRU(**kwargs)
|
77
|
-
else:
|
78
|
-
self.rnn_encoder = nn.LSTM(**kwargs)
|
79
|
-
|
80
|
-
# global feature transformation
|
81
|
-
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
|
82
|
-
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
|
83
|
-
|
84
|
-
def forward(self, feat, img_metas=None):
|
85
|
-
if img_metas is not None:
|
86
|
-
assert len(img_metas[0]) == feat.shape[0]
|
87
|
-
|
88
|
-
valid_ratios = None
|
89
|
-
if img_metas is not None and self.mask:
|
90
|
-
valid_ratios = img_metas[-1]
|
91
|
-
|
92
|
-
h_feat = feat.shape[2] # bsz c h w
|
93
|
-
feat_v = F.max_pool2d(
|
94
|
-
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
|
95
|
-
feat_v = feat_v.squeeze(2) # bsz * C * W
|
96
|
-
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
|
97
|
-
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
|
98
|
-
|
99
|
-
if valid_ratios is not None:
|
100
|
-
valid_hf = []
|
101
|
-
T = holistic_feat.shape[1]
|
102
|
-
for i, valid_ratio in enumerate(valid_ratios):
|
103
|
-
valid_step = min(T, math.ceil(T * valid_ratio)) - 1
|
104
|
-
valid_hf.append(holistic_feat[i, valid_step, :])
|
105
|
-
valid_hf = paddle.stack(valid_hf, axis=0)
|
106
|
-
else:
|
107
|
-
valid_hf = holistic_feat[:, -1, :] # bsz * C
|
108
|
-
holistic_feat = self.linear(valid_hf) # bsz * C
|
109
|
-
|
110
|
-
return holistic_feat
|
111
|
-
|
112
|
-
|
113
|
-
class BaseDecoder(nn.Layer):
|
114
|
-
def __init__(self, **kwargs):
|
115
|
-
super().__init__()
|
116
|
-
|
117
|
-
def forward_train(self, feat, out_enc, targets, img_metas):
|
118
|
-
raise NotImplementedError
|
119
|
-
|
120
|
-
def forward_test(self, feat, out_enc, img_metas):
|
121
|
-
raise NotImplementedError
|
122
|
-
|
123
|
-
def forward(self,
|
124
|
-
feat,
|
125
|
-
out_enc,
|
126
|
-
label=None,
|
127
|
-
img_metas=None,
|
128
|
-
train_mode=True):
|
129
|
-
self.train_mode = train_mode
|
130
|
-
|
131
|
-
if train_mode:
|
132
|
-
return self.forward_train(feat, out_enc, label, img_metas)
|
133
|
-
return self.forward_test(feat, out_enc, img_metas)
|
134
|
-
|
135
|
-
|
136
|
-
class ParallelSARDecoder(BaseDecoder):
|
137
|
-
"""
|
138
|
-
Args:
|
139
|
-
out_channels (int): Output class number.
|
140
|
-
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
141
|
-
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
|
142
|
-
dec_drop_rnn (float): Dropout of RNN layer in decoder.
|
143
|
-
dec_gru (bool): If True, use GRU, else LSTM in decoder.
|
144
|
-
d_model (int): Dim of channels from backbone.
|
145
|
-
d_enc (int): Dim of encoder RNN layer.
|
146
|
-
d_k (int): Dim of channels of attention module.
|
147
|
-
pred_dropout (float): Dropout probability of prediction layer.
|
148
|
-
max_seq_len (int): Maximum sequence length for decoding.
|
149
|
-
mask (bool): If True, mask padding in feature map.
|
150
|
-
start_idx (int): Index of start token.
|
151
|
-
padding_idx (int): Index of padding token.
|
152
|
-
pred_concat (bool): If True, concat glimpse feature from
|
153
|
-
attention with holistic feature and hidden state.
|
154
|
-
"""
|
155
|
-
|
156
|
-
def __init__(
|
157
|
-
self,
|
158
|
-
out_channels, # 90 + unknown + start + padding
|
159
|
-
enc_bi_rnn=False,
|
160
|
-
dec_bi_rnn=False,
|
161
|
-
dec_drop_rnn=0.0,
|
162
|
-
dec_gru=False,
|
163
|
-
d_model=512,
|
164
|
-
d_enc=512,
|
165
|
-
d_k=64,
|
166
|
-
pred_dropout=0.1,
|
167
|
-
max_text_length=30,
|
168
|
-
mask=True,
|
169
|
-
pred_concat=True,
|
170
|
-
**kwargs):
|
171
|
-
super().__init__()
|
172
|
-
|
173
|
-
self.num_classes = out_channels
|
174
|
-
self.enc_bi_rnn = enc_bi_rnn
|
175
|
-
self.d_k = d_k
|
176
|
-
self.start_idx = out_channels - 2
|
177
|
-
self.padding_idx = out_channels - 1
|
178
|
-
self.max_seq_len = max_text_length
|
179
|
-
self.mask = mask
|
180
|
-
self.pred_concat = pred_concat
|
181
|
-
|
182
|
-
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
|
183
|
-
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
|
184
|
-
|
185
|
-
# 2D attention layer
|
186
|
-
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
|
187
|
-
self.conv3x3_1 = nn.Conv2D(
|
188
|
-
d_model, d_k, kernel_size=3, stride=1, padding=1)
|
189
|
-
self.conv1x1_2 = nn.Linear(d_k, 1)
|
190
|
-
|
191
|
-
# Decoder RNN layer
|
192
|
-
if dec_bi_rnn:
|
193
|
-
direction = 'bidirectional'
|
194
|
-
else:
|
195
|
-
direction = 'forward'
|
196
|
-
|
197
|
-
kwargs = dict(
|
198
|
-
input_size=encoder_rnn_out_size,
|
199
|
-
hidden_size=encoder_rnn_out_size,
|
200
|
-
num_layers=2,
|
201
|
-
time_major=False,
|
202
|
-
dropout=dec_drop_rnn,
|
203
|
-
direction=direction)
|
204
|
-
if dec_gru:
|
205
|
-
self.rnn_decoder = nn.GRU(**kwargs)
|
206
|
-
else:
|
207
|
-
self.rnn_decoder = nn.LSTM(**kwargs)
|
208
|
-
|
209
|
-
# Decoder input embedding
|
210
|
-
self.embedding = nn.Embedding(
|
211
|
-
self.num_classes,
|
212
|
-
encoder_rnn_out_size,
|
213
|
-
padding_idx=self.padding_idx)
|
214
|
-
|
215
|
-
# Prediction layer
|
216
|
-
self.pred_dropout = nn.Dropout(pred_dropout)
|
217
|
-
pred_num_classes = self.num_classes - 1
|
218
|
-
if pred_concat:
|
219
|
-
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
|
220
|
-
else:
|
221
|
-
fc_in_channel = d_model
|
222
|
-
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
|
223
|
-
|
224
|
-
def _2d_attention(self,
|
225
|
-
decoder_input,
|
226
|
-
feat,
|
227
|
-
holistic_feat,
|
228
|
-
valid_ratios=None):
|
229
|
-
|
230
|
-
y = self.rnn_decoder(decoder_input)[0]
|
231
|
-
# y: bsz * (seq_len + 1) * hidden_size
|
232
|
-
|
233
|
-
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
|
234
|
-
bsz, seq_len, attn_size = attn_query.shape
|
235
|
-
attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
|
236
|
-
# (bsz, seq_len + 1, attn_size, 1, 1)
|
237
|
-
|
238
|
-
attn_key = self.conv3x3_1(feat)
|
239
|
-
# bsz * attn_size * h * w
|
240
|
-
attn_key = attn_key.unsqueeze(1)
|
241
|
-
# bsz * 1 * attn_size * h * w
|
242
|
-
|
243
|
-
attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
|
244
|
-
|
245
|
-
# bsz * (seq_len + 1) * attn_size * h * w
|
246
|
-
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
|
247
|
-
# bsz * (seq_len + 1) * h * w * attn_size
|
248
|
-
attn_weight = self.conv1x1_2(attn_weight)
|
249
|
-
# bsz * (seq_len + 1) * h * w * 1
|
250
|
-
bsz, T, h, w, c = attn_weight.shape
|
251
|
-
assert c == 1
|
252
|
-
|
253
|
-
if valid_ratios is not None:
|
254
|
-
# cal mask of attention weight
|
255
|
-
for i, valid_ratio in enumerate(valid_ratios):
|
256
|
-
valid_width = min(w, math.ceil(w * valid_ratio))
|
257
|
-
if valid_width < w:
|
258
|
-
attn_weight[i, :, :, valid_width:, :] = float('-inf')
|
259
|
-
|
260
|
-
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
|
261
|
-
attn_weight = F.softmax(attn_weight, axis=-1)
|
262
|
-
|
263
|
-
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
|
264
|
-
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
|
265
|
-
# attn_weight: bsz * T * c * h * w
|
266
|
-
# feat: bsz * c * h * w
|
267
|
-
attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
|
268
|
-
(3, 4),
|
269
|
-
keepdim=False)
|
270
|
-
# bsz * (seq_len + 1) * C
|
271
|
-
|
272
|
-
# Linear transformation
|
273
|
-
if self.pred_concat:
|
274
|
-
hf_c = holistic_feat.shape[-1]
|
275
|
-
holistic_feat = paddle.expand(
|
276
|
-
holistic_feat, shape=[bsz, seq_len, hf_c])
|
277
|
-
y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
|
278
|
-
else:
|
279
|
-
y = self.prediction(attn_feat)
|
280
|
-
# bsz * (seq_len + 1) * num_classes
|
281
|
-
if self.train_mode:
|
282
|
-
y = self.pred_dropout(y)
|
283
|
-
|
284
|
-
return y
|
285
|
-
|
286
|
-
def forward_train(self, feat, out_enc, label, img_metas):
|
287
|
-
'''
|
288
|
-
img_metas: [label, valid_ratio]
|
289
|
-
'''
|
290
|
-
if img_metas is not None:
|
291
|
-
assert len(img_metas[0]) == feat.shape[0]
|
292
|
-
|
293
|
-
valid_ratios = None
|
294
|
-
if img_metas is not None and self.mask:
|
295
|
-
valid_ratios = img_metas[-1]
|
296
|
-
|
297
|
-
lab_embedding = self.embedding(label)
|
298
|
-
# bsz * seq_len * emb_dim
|
299
|
-
out_enc = out_enc.unsqueeze(1)
|
300
|
-
# bsz * 1 * emb_dim
|
301
|
-
in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
|
302
|
-
# bsz * (seq_len + 1) * C
|
303
|
-
out_dec = self._2d_attention(
|
304
|
-
in_dec, feat, out_enc, valid_ratios=valid_ratios)
|
305
|
-
# bsz * (seq_len + 1) * num_classes
|
306
|
-
|
307
|
-
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
|
308
|
-
|
309
|
-
def forward_test(self, feat, out_enc, img_metas):
|
310
|
-
if img_metas is not None:
|
311
|
-
assert len(img_metas[0]) == feat.shape[0]
|
312
|
-
|
313
|
-
valid_ratios = None
|
314
|
-
if img_metas is not None and self.mask:
|
315
|
-
valid_ratios = img_metas[-1]
|
316
|
-
|
317
|
-
seq_len = self.max_seq_len
|
318
|
-
bsz = feat.shape[0]
|
319
|
-
start_token = paddle.full(
|
320
|
-
(bsz, ), fill_value=self.start_idx, dtype='int64')
|
321
|
-
# bsz
|
322
|
-
start_token = self.embedding(start_token)
|
323
|
-
# bsz * emb_dim
|
324
|
-
emb_dim = start_token.shape[1]
|
325
|
-
start_token = start_token.unsqueeze(1)
|
326
|
-
start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
|
327
|
-
# bsz * seq_len * emb_dim
|
328
|
-
out_enc = out_enc.unsqueeze(1)
|
329
|
-
# bsz * 1 * emb_dim
|
330
|
-
decoder_input = paddle.concat((out_enc, start_token), axis=1)
|
331
|
-
# bsz * (seq_len + 1) * emb_dim
|
332
|
-
|
333
|
-
outputs = []
|
334
|
-
for i in range(1, seq_len + 1):
|
335
|
-
decoder_output = self._2d_attention(
|
336
|
-
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
|
337
|
-
char_output = decoder_output[:, i, :] # bsz * num_classes
|
338
|
-
char_output = F.softmax(char_output, -1)
|
339
|
-
outputs.append(char_output)
|
340
|
-
max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
|
341
|
-
char_embedding = self.embedding(max_idx) # bsz * emb_dim
|
342
|
-
if i < seq_len:
|
343
|
-
decoder_input[:, i + 1, :] = char_embedding
|
344
|
-
|
345
|
-
outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
|
346
|
-
|
347
|
-
return outputs
|
348
|
-
|
349
|
-
|
350
|
-
class SARHead(nn.Layer):
|
351
|
-
def __init__(self,
|
352
|
-
out_channels,
|
353
|
-
enc_bi_rnn=False,
|
354
|
-
enc_drop_rnn=0.1,
|
355
|
-
enc_gru=False,
|
356
|
-
dec_bi_rnn=False,
|
357
|
-
dec_drop_rnn=0.0,
|
358
|
-
dec_gru=False,
|
359
|
-
d_k=512,
|
360
|
-
pred_dropout=0.1,
|
361
|
-
max_text_length=30,
|
362
|
-
pred_concat=True,
|
363
|
-
**kwargs):
|
364
|
-
super(SARHead, self).__init__()
|
365
|
-
|
366
|
-
# encoder module
|
367
|
-
self.encoder = SAREncoder(
|
368
|
-
enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
|
369
|
-
|
370
|
-
# decoder module
|
371
|
-
self.decoder = ParallelSARDecoder(
|
372
|
-
out_channels=out_channels,
|
373
|
-
enc_bi_rnn=enc_bi_rnn,
|
374
|
-
dec_bi_rnn=dec_bi_rnn,
|
375
|
-
dec_drop_rnn=dec_drop_rnn,
|
376
|
-
dec_gru=dec_gru,
|
377
|
-
d_k=d_k,
|
378
|
-
pred_dropout=pred_dropout,
|
379
|
-
max_text_length=max_text_length,
|
380
|
-
pred_concat=pred_concat)
|
381
|
-
|
382
|
-
def forward(self, feat, targets=None):
|
383
|
-
'''
|
384
|
-
img_metas: [label, valid_ratio]
|
385
|
-
'''
|
386
|
-
holistic_feat = self.encoder(feat, targets) # bsz c
|
387
|
-
|
388
|
-
if self.training:
|
389
|
-
label = targets[0] # label
|
390
|
-
label = paddle.to_tensor(label, dtype='int64')
|
391
|
-
final_out = self.decoder(
|
392
|
-
feat, holistic_feat, label, img_metas=targets)
|
393
|
-
if not self.training:
|
394
|
-
final_out = self.decoder(
|
395
|
-
feat,
|
396
|
-
holistic_feat,
|
397
|
-
label=None,
|
398
|
-
img_metas=targets,
|
399
|
-
train_mode=False)
|
400
|
-
# (bsz, seq_len, num_classes)
|
401
|
-
|
402
|
-
return final_out
|
@@ -1,280 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import math
|
20
|
-
import paddle
|
21
|
-
from paddle import nn, ParamAttr
|
22
|
-
from paddle.nn import functional as F
|
23
|
-
import paddle.fluid as fluid
|
24
|
-
import numpy as np
|
25
|
-
from .self_attention import WrapEncoderForFeature
|
26
|
-
from .self_attention import WrapEncoder
|
27
|
-
from paddle.static import Program
|
28
|
-
from pyxlpr.ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
|
29
|
-
import paddle.fluid.framework as framework
|
30
|
-
|
31
|
-
from collections import OrderedDict
|
32
|
-
gradient_clip = 10
|
33
|
-
|
34
|
-
|
35
|
-
class PVAM(nn.Layer):
    """Attention block that pools encoder features per output position.

    The input feature map is flattened into a sequence, passed through
    ``WrapEncoderForFeature``, and then attended over once for each of the
    ``max_text_length`` positions, giving one feature vector per position.

    Args:
        in_channels: Feature width used by ``fc0``, ``emb`` and ``fc1``.
        char_num: Stored on the instance; not used by this class itself.
        max_text_length: Number of output positions (rows of the embedding
            table and of the attention map).
        num_heads: Number of attention heads of the wrapped encoder.
        num_encoder_tus: Number of encoder layers (``n_layer``).
        hidden_dims: Model width of the wrapped encoder.
    """

    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, hidden_dims):
        super(PVAM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.hidden_dims = hidden_dims

        # Transformer encoder applied to the flattened feature map.
        # 256 is the fixed ``max_length`` handed to the encoder.
        t = 256
        self.wrap_encoder_for_feature = WrapEncoderForFeature(
            src_vocab_size=1,
            max_length=t,
            n_layer=self.num_encoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

        # PVAM layers. flatten0 / flatten1 are not called by forward(); they
        # are kept so the set of attributes on the layer stays unchanged.
        self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels,
            out_features=in_channels, )
        self.emb = paddle.nn.Embedding(
            num_embeddings=self.max_length, embedding_dim=in_channels)
        self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
        self.fc1 = paddle.nn.Linear(
            in_features=in_channels, out_features=1, bias_attr=False)

    def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
        """Compute one attended feature vector per output position.

        Args:
            inputs: 4-D tensor whose shape is unpacked as ``(b, c, h, w)``.
            encoder_word_pos: Position input forwarded to the wrapped encoder.
            gsrm_word_pos: Indices looked up in the position embedding
                ``emb``.

        Returns:
            Tensor of shape ``[b, max_length, c]``.
        """
        _, c, h, w = inputs.shape
        # [b, c, h, w] -> [b, c, h*w] -> [b, h*w, c]
        conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
        conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])

        # Transformer encoder; the third list slot is passed as None.
        enc_inputs = [conv_features, encoder_word_pos, None]
        word_features = self.wrap_encoder_for_feature(enc_inputs)

        # PVAM attention: t and c are taken from the encoder output.
        _, t, c = word_features.shape
        word_features = self.fc0(word_features)
        # Repeat the features once per output position:
        # [-1, 1, t, c] -> [-1, max_length, t, c]
        word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
        word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
        # Repeat the position embeddings once per encoder step:
        # [-1, max_length, 1, c] -> [-1, max_length, t, c]
        word_pos_feature = self.emb(gsrm_word_pos)
        word_pos_feature_ = paddle.reshape(word_pos_feature,
                                           [-1, self.max_length, 1, c])
        word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
        y = word_pos_feature_ + word_features_
        y = F.tanh(y)
        # fc1 maps every (position, step) pair to one score.
        attention_weight = self.fc1(y)
        attention_weight = paddle.reshape(
            attention_weight, shape=[-1, self.max_length, t])
        # Normalise the scores over the encoder steps.
        attention_weight = F.softmax(attention_weight, axis=-1)
        pvam_features = paddle.matmul(attention_weight,
                                      word_features)  # [b, max_length, c]
        return pvam_features
|
102
|
-
|
103
|
-
|
104
|
-
class GSRM(nn.Layer):
    """Semantic reasoning block built from two ``WrapEncoder`` instances.

    Per-position features are first classified into character ids; those ids
    are then fed to two encoders (one on a right-shifted copy, one on the ids
    as they are) and the two results are summed.

    Args:
        in_channels: Width of the incoming per-position features.
        char_num: Number of character classes; also used as the padding id.
        max_text_length: ``max_length`` handed to both encoders.
        num_heads: Number of attention heads of both encoders.
        num_encoder_tus: Stored on the instance; not used by this class.
        num_decoder_tus: Number of layers (``n_layer``) of both encoders.
        hidden_dims: Model width of both encoders.
    """

    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, num_decoder_tus, hidden_dims):
        super(GSRM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.num_decoder_TUs = num_decoder_tus
        self.hidden_dims = hidden_dims

        self.fc0 = paddle.nn.Linear(
            in_features=in_channels, out_features=self.char_num)

        # Both encoders are configured identically, so the argument list is
        # written once. The vocabulary has one extra entry because
        # ``char_num`` itself is used as the padding id in forward().
        encoder_kwargs = dict(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)
        self.wrap_encoder0 = WrapEncoder(**encoder_kwargs)
        self.wrap_encoder1 = WrapEncoder(**encoder_kwargs)

        # Output projection reuses the first encoder's embedding matrix
        # (transposed). The weight is looked up at call time, not captured.
        self.mul = lambda x: paddle.matmul(
            x=x,
            y=self.wrap_encoder0.prepare_decoder.emb0.weight,
            transpose_y=True)

    def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2):
        """Run the visual-to-semantic embedding and the reasoning block.

        Args:
            inputs: 3-D tensor whose shape is unpacked as ``(b, t, c)``.
            gsrm_word_pos: Position input passed to both encoders.
            gsrm_slf_attn_bias1: Attention bias passed to ``wrap_encoder0``.
            gsrm_slf_attn_bias2: Attention bias passed to ``wrap_encoder1``.

        Returns:
            Tuple ``(gsrm_features, word_out, gsrm_out)``: the summed encoder
            features, the output of ``fc0`` on the flattened inputs, and the
            projected features reshaped to two dimensions.
        """
        # ===== GSRM Visual-to-semantic embedding block =====
        _, t, c = inputs.shape
        pvam_features = paddle.reshape(inputs, [-1, c])
        word_out = self.fc0(pvam_features)
        word_ids = paddle.argmax(F.softmax(word_out), axis=1)
        word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])

        # ===== GSRM Semantic reasoning block =====
        # This module is achieved through bi-transformers:
        # gsrm_feature1 is the forward one, gsrm_feature2 is the backward one.
        pad_idx = self.char_num

        # Shift the ids one step to the right: pad one entry of ``pad_idx``
        # at the front of the length axis, then drop the last entry. The
        # round trip through float32 surrounds the F.pad call.
        word1 = paddle.cast(word_ids, "float32")
        word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
        word1 = paddle.cast(word1, "int64")
        word1 = word1[:, :-1, :]
        word2 = word_ids

        enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
        enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]

        gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
        gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)

        # Shift the second result one step to the left: pad one zero entry at
        # the end of the length axis, then drop the first entry.
        gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
                              value=0.,
                              data_format="NLC")
        gsrm_feature2 = gsrm_feature2[:, 1:, ]
        gsrm_features = gsrm_feature1 + gsrm_feature2

        gsrm_out = self.mul(gsrm_features)

        # Flatten to two dimensions, keeping the last axis.
        _, _, out_dim = gsrm_out.shape
        gsrm_out = paddle.reshape(gsrm_out, [-1, out_dim])

        return gsrm_features, word_out, gsrm_out
|
193
|
-
|
194
|
-
|
195
|
-
class VSFD(nn.Layer):
    """Gated fusion of two feature tensors followed by a classifier.

    A sigmoid gate is computed from the concatenation of both inputs and used
    to blend them element-wise; the blend is then mapped to ``char_num``
    outputs.

    Args:
        in_channels: ``fc0`` takes ``in_channels * 2`` input features.
        pvam_ch: Output width of ``fc0`` and input width of ``fc1``.
        char_num: Output width of ``fc1``.
    """

    def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
        super(VSFD, self).__init__()
        self.char_num = char_num
        # Gate projection, applied to the concatenated features.
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels * 2, out_features=pvam_ch)
        # Final classifier, applied to the fused features.
        self.fc1 = paddle.nn.Linear(
            in_features=pvam_ch, out_features=self.char_num)

    def forward(self, pvam_feature, gsrm_feature):
        """Blend the two inputs and return the classifier output.

        Args:
            pvam_feature: 3-D tensor whose shape is unpacked as
                ``(b, t, c1)``.
            gsrm_feature: 3-D tensor whose shape is unpacked as
                ``(b, t, c2)``.

        Returns:
            Output of ``fc1`` on the fused features flattened to
            ``[-1, c1]``.
        """
        _, _, pvam_dim = pvam_feature.shape
        _, seq_len, gsrm_dim = gsrm_feature.shape

        # Concatenate along the last axis and flatten to rows.
        stacked = paddle.concat([pvam_feature, gsrm_feature], axis=2)
        stacked_rows = paddle.reshape(stacked, shape=[-1, pvam_dim + gsrm_dim])

        # Gate values in (0, 1), restored to [-1, t, c1].
        gate = F.sigmoid(self.fc0(stacked_rows))
        gate = paddle.reshape(gate, shape=[-1, seq_len, pvam_dim])

        # Convex blend of the two inputs, weighted by the gate.
        fused = gate * pvam_feature + (1.0 - gate) * gsrm_feature
        fused_rows = paddle.reshape(fused, shape=[-1, pvam_dim])

        return self.fc1(fused_rows)
|
220
|
-
|
221
|
-
|
222
|
-
class SRNHead(nn.Layer):
    """Recognition head that chains ``PVAM``, ``GSRM`` and ``VSFD``.

    Args:
        in_channels: Feature width handed to all three sub-modules.
        out_channels: Number of character classes (stored as ``char_num``).
        max_text_length: Number of output positions.
        num_heads: Number of attention heads.
        num_encoder_TUs: Number of encoder layers used by ``PVAM``.
        num_decoder_TUs: Number of layers of the two ``GSRM`` encoders.
        hidden_dims: Model width of the encoders.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, max_text_length, num_heads,
                 num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
        super(SRNHead, self).__init__()
        self.char_num = out_channels
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_TUs
        self.num_decoder_TUs = num_decoder_TUs
        self.hidden_dims = hidden_dims

        self.pvam = PVAM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            hidden_dims=self.hidden_dims)

        self.gsrm = GSRM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            num_decoder_tus=self.num_decoder_TUs,
            hidden_dims=self.hidden_dims)
        self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)

        # Make the second GSRM encoder use the very same embedding layer
        # object as the first one.
        self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0

    def forward(self, inputs, targets=None):
        """Run the three sub-modules and collect their outputs.

        Args:
            inputs: Feature map passed to ``PVAM``.
            targets: Sequence whose last four items are, in order,
                ``encoder_word_pos``, ``gsrm_word_pos``,
                ``gsrm_slf_attn_bias1`` and ``gsrm_slf_attn_bias2``.
                Required despite the ``None`` default.

        Returns:
            ``OrderedDict`` with the keys ``predict``, ``pvam_feature``,
            ``decoded_out``, ``word_out`` and ``gsrm_out``.

        Raises:
            TypeError: If ``targets`` is ``None``.
        """
        if targets is None:
            # Same exception type as slicing None would raise, but with a
            # message that says what the caller has to provide.
            raise TypeError(
                "SRNHead.forward() requires `targets`; its last four items "
                "must be encoder_word_pos, gsrm_word_pos, "
                "gsrm_slf_attn_bias1 and gsrm_slf_attn_bias2")

        others = targets[-4:]
        encoder_word_pos = others[0]
        gsrm_word_pos = others[1]
        gsrm_slf_attn_bias1 = others[2]
        gsrm_slf_attn_bias2 = others[3]

        pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)

        gsrm_feature, word_out, gsrm_out = self.gsrm(
            pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2)

        final_out = self.vsfd(pvam_feature, gsrm_feature)
        # Outside training, turn the scores into probabilities.
        if not self.training:
            final_out = F.softmax(final_out, axis=1)

        # Keep only the indices of the top-1 entries.
        _, decoded_out = paddle.topk(final_out, k=1)

        predicts = OrderedDict([
            ('predict', final_out),
            ('pvam_feature', pvam_feature),
            ('decoded_out', decoded_out),
            ('word_out', word_out),
            ('gsrm_out', gsrm_out),
        ])

        return predicts
|