pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,393 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is adapted from:
|
16
|
-
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
|
17
|
-
"""
|
18
|
-
from __future__ import absolute_import
|
19
|
-
from __future__ import division
|
20
|
-
from __future__ import print_function
|
21
|
-
|
22
|
-
import sys
|
23
|
-
|
24
|
-
import paddle
|
25
|
-
from paddle import nn
|
26
|
-
from paddle.nn import functional as F
|
27
|
-
|
28
|
-
|
29
|
-
class AsterHead(nn.Layer):
    """ASTER-style recognition head.

    Wraps an ``AttentionRecognitionHead`` decoder and an ``Embedding`` layer
    that produces a holistic vector from the encoder output. That vector
    seeds the decoder state. Training runs teacher-forced decoding;
    inference runs beam search.

    Args:
        in_channels: channel width of the encoder features.
        out_channels: number of output classes.
        sDim: decoder hidden-state size.
        attDim: attention projection size.
        max_len_labels: maximum number of decoding steps.
        time_step: encoder sequence length expected by ``Embedding``.
        beam_width: beam size used at inference.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sDim,
                 attDim,
                 max_len_labels,
                 time_step=25,
                 beam_width=5,
                 **kwargs):
        super(AsterHead, self).__init__()
        self.num_classes = out_channels
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        # Sub-layers are created in this order (decoder first, then the
        # embedder) so parameter naming/initialization order is stable.
        self.decoder = AttentionRecognitionHead(
            in_channels, out_channels, sDim, attDim, max_len_labels)
        self.time_step = time_step
        self.embeder = Embedding(self.time_step, in_channels)
        self.beam_width = beam_width
        # NOTE(review): EOS id is taken as 3 below num_classes; presumably
        # this mirrors the label encoder's special-token layout — confirm.
        self.eos = self.num_classes - 3

    def forward(self, x, targets=None, embed=None):
        """Decode encoder features ``x``.

        Args:
            x: encoder feature sequence.
            targets: in training mode, a 3-item sequence whose first two
                items are the target ids and their lengths; unused otherwise.
            embed: unused.

        Returns:
            dict with ``rec_pred`` and ``embedding_vectors``. In inference
            mode it also contains ``rec_pred_scores`` (inserted between the
            other two keys).
        """
        holistic = self.embeder(x)

        if not self.training:
            pred, pred_scores = self.decoder.beam_search(
                x, self.beam_width, self.eos, holistic)
            return {
                'rec_pred': pred,
                'rec_pred_scores': pred_scores,
                'embedding_vectors': holistic,
            }

        rec_targets, rec_lengths, _ = targets
        pred = self.decoder([x, rec_targets, rec_lengths], holistic)
        return {'rec_pred': pred, 'embedding_vectors': holistic}
|
70
|
-
|
71
|
-
|
72
|
-
class Embedding(nn.Layer):
    """Flatten encoder output per sample and project it to ``embed_dim``.

    Args:
        in_timestep: number of time steps in the encoder output.
        in_planes: channels per time step.
        mid_dim: stored on the layer but not used by it.
        embed_dim: width of the produced embedding vector.
    """

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep, self.in_planes = in_timestep, in_planes
        self.embed_dim, self.mid_dim = embed_dim, mid_dim
        flat_width = in_timestep * in_planes
        # Embed encoder output to a word-embedding-like vector.
        self.eEmbed = nn.Linear(flat_width, self.embed_dim)

    def forward(self, x):
        """Return ``eEmbed`` applied to ``x`` reshaped to [batch, -1]."""
        batch = paddle.shape(x)[0]
        flattened = paddle.reshape(x, [batch, -1])
        return self.eEmbed(flattened)
|
87
|
-
|
88
|
-
|
89
|
-
class AttentionRecognitionHead(nn.Layer):
    """Attention decoder turning an encoder feature sequence into symbols.

    All decoding modes share a single ``DecoderUnit``:
      * ``forward``     -- teacher-forced pass; returns per-step logits
                           of shape [b, steps, num_classes].
      * ``sample``      -- greedy decoding from a zero initial state.
      * ``beam_search`` -- beam search seeded from a holistic embedding.

    Encoder features are consumed as [b, T, in_planes]. The first decoder
    input is the id ``num_classes``, the extra last row of the decoder's
    target embedding, used as <BOS>.
    """

    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
        super(AttentionRecognitionHead, self).__init__()
        self.num_classes = out_channels  # output classes, including <EOS>
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels

        self.decoder = DecoderUnit(
            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)

    def forward(self, x, embed):
        """Teacher-forced decoding.

        Args:
            x: 3-item sequence ``(features, targets, lengths)``. ``targets``
                is indexed as ``targets[:, i]``; ``max(lengths)`` sets the
                number of decoding steps.
            embed: holistic embedding used to seed the decoder state.

        Returns:
            Tensor of per-step logits, [b, max(lengths), num_classes].
        """
        x, targets, lengths = x
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = self.decoder.get_initial_state(embed)
        outputs = []
        for i in range(max(lengths)):
            if i == 0:
                # <BOS> id
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = targets[:, i - 1]
            output, state = self.decoder(x, state, y_prev)
            # The decoder returns a [b, sDim] state. Restore the leading axis
            # (as beam_search does) so its squeeze(0) cannot collapse the
            # state to [sDim] when b == 1.
            state = paddle.unsqueeze(state, axis=0)
            outputs.append(output)
        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs

    # inference stage.
    def sample(self, x):
        """Greedy decoding for ``max_len_labels`` steps.

        Args:
            x: 3-item sequence whose first item is the encoder features; the
                other two items are ignored.

        Returns:
            ``(predicted_ids, predicted_scores)``, each [b, max_len_labels]:
            the arg-max class per step and its softmax probability.
        """
        x, _, _ = x
        # Paddle's Tensor.size is not callable (x.size(0) is a PyTorch idiom).
        batch_size = paddle.shape(x)[0]
        # Decoder state starts at zero here (forward/beam_search instead seed
        # it from the holistic embedding).
        state = paddle.zeros([1, batch_size, self.sDim])

        predicted_ids, predicted_scores = [], []
        # First input is the <BOS> id.
        y_prev = paddle.full(shape=[batch_size], fill_value=self.num_classes)
        for _ in range(self.max_len_labels):
            output, state = self.decoder(x, state, y_prev)
            # Keep state as [1, b, sDim] so it survives squeeze(0) at b == 1.
            state = paddle.unsqueeze(state, axis=0)
            output = F.softmax(output, axis=1)
            # paddle.max returns values only (no indices), so take the
            # arg-max separately.
            score = paddle.max(output, axis=1)
            predicted = paddle.argmax(output, axis=1)
            predicted_ids.append(predicted.unsqueeze(1))
            predicted_scores.append(score.unsqueeze(1))
            y_prev = predicted
        # concat takes (list_of_tensors, axis), not a single nested list.
        predicted_ids = paddle.concat(predicted_ids, 1)
        predicted_scores = paddle.concat(predicted_scores, 1)
        return predicted_ids, predicted_scores

    def beam_search(self, x, beam_width, eos, embed):
        """Beam-search decoding (adapted from IBM pytorch-seq2seq TopKDecoder).

        Args:
            x: encoder features, unpacked as [b, l, d].
            beam_width: number of beams kept per sample.
            eos: id of the end-of-sequence symbol.
            embed: holistic embedding used to seed the decoder state.

        Returns:
            ``(p, ones_like(p))``: ``p`` is the top beam's symbol ids,
            [b, max_len_labels]. The second item is a placeholder of ones,
            not real scores.
        """
        def _inflate(tensor, times, dim):
            # Tile ``tensor`` ``times`` times along axis ``dim``.
            repeat_dims = [1] * tensor.dim()
            repeat_dims[dim] = times
            output = paddle.tile(tensor, repeat_dims)
            return output

        # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
        batch_size, l, d = x.shape
        # Repeat each sample beam_width times consecutively -> [b*k, l, d]
        # (sample-major, matching get_initial_state's tiling).
        x = paddle.tile(
            paddle.transpose(
                x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
        inflated_encoder_feats = paddle.reshape(
            paddle.transpose(
                x, perm=[1, 0, 2, 3]), [-1, l, d])

        # Initialize the decoder
        state = self.decoder.get_initial_state(embed, tile_times=beam_width)

        # Offset of each sample's first beam in the flattened b*k axis.
        pos_index = paddle.reshape(
            paddle.arange(batch_size) * beam_width, shape=[-1, 1])

        # Initialize the scores: only the first beam of each sample is live,
        # so the first topk does not pick duplicate beams.
        sequence_scores = paddle.full(
            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
        index = [i * beam_width for i in range(0, batch_size)]
        sequence_scores[index] = 0.0

        # Initialize the input vector with the <BOS> id
        y_prev = paddle.full(
            shape=[batch_size * beam_width], fill_value=self.num_classes)

        # Store decisions for backtracking
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()

        for i in range(self.max_len_labels):
            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
            state = paddle.unsqueeze(state, axis=0)
            log_softmax_output = paddle.nn.functional.log_softmax(
                output, axis=1)

            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
            sequence_scores += log_softmax_output
            scores, candidates = paddle.topk(
                paddle.reshape(sequence_scores, [batch_size, -1]),
                beam_width,
                axis=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            y_prev = paddle.reshape(
                candidates % self.num_classes, shape=[batch_size * beam_width])
            sequence_scores = paddle.reshape(
                scores, shape=[batch_size * beam_width, 1])

            # Update fields for next timestep
            pos_index = paddle.expand_as(pos_index, candidates)
            # candidate // num_classes is the source beam within the sample;
            # true division then an int64 cast truncates, which equals floor
            # division for these non-negative values.
            predecessors = paddle.cast(
                candidates / self.num_classes + pos_index, dtype='int64')
            predecessors = paddle.reshape(
                predecessors, shape=[batch_size * beam_width, 1])
            state = paddle.index_select(
                state, index=predecessors.squeeze(), axis=1)

            # Update sequence scores and erase scores for <eos> symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
            eos_prev = paddle.full_like(y_prev, fill_value=eos)
            mask = eos_prev == y_prev
            mask = paddle.nonzero(mask)
            # NOTE(review): paddle.nonzero returns a 2-D index tensor, so this
            # branch always runs; with zero matches it writes nothing.
            if mask.dim() > 0:
                sequence_scores = sequence_scores.numpy()
                mask = mask.numpy()
                sequence_scores[mask] = -float('inf')
                sequence_scores = paddle.to_tensor(sequence_scores)

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            y_prev = paddle.squeeze(y_prev)
            stored_emitted_symbols.append(y_prev)

        # Do backtracking to return the optimal values
        # ====== backtrack ====== #
        # Initialize return variables given different types
        p = list()
        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
             ]  # Placeholder for lengths of top-k sequences

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = paddle.topk(
            paddle.reshape(
                stored_scores[-1], shape=[batch_size, beam_width]),
            beam_width)

        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        batch_eos_found = [0] * batch_size  # the number of EOS found
        # in the backward loop below for each batch
        t = self.max_len_labels - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = paddle.reshape(
            sorted_idx + pos_index.expand_as(sorted_idx),
            shape=[batch_size * beam_width])
        while t >= 0:
            # Re-order the variables with the back pointer
            current_symbol = paddle.index_select(
                stored_emitted_symbols[t], index=t_predecessors, axis=0)
            t_predecessors = paddle.index_select(
                stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
            eos_indices = stored_emitted_symbols[t] == eos
            eos_indices = paddle.nonzero(eos_indices)

            if eos_indices.dim() > 0:
                for i in range(eos_indices.shape[0] - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] / beam_width)
                    # The indices of the replacing position: sequences that
                    # ended early fill the sample's beams from the last slot
                    # backwards.
                    res_k_idx = beam_width - (batch_eos_found[b_idx] %
                                              beam_width) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * beam_width + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
                    current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
                    s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            p.append(current_symbol)
            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(beam_width)
        for b_idx in range(batch_size):
            l[b_idx] = [
                l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
            ]

        re_sorted_idx = paddle.reshape(
            re_sorted_idx + pos_index.expand_as(re_sorted_idx),
            [batch_size * beam_width])

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        p = [
            paddle.reshape(
                paddle.index_select(step, re_sorted_idx, 0),
                shape=[batch_size, beam_width, -1]) for step in reversed(p)
        ]
        # Keep only the best beam of each sample.
        p = paddle.concat(p, -1)[:, 0, :]
        return p, paddle.ones_like(p)
|
310
|
-
|
311
|
-
|
312
|
-
class AttentionUnit(nn.Layer):
    """Additive attention over an encoder feature sequence.

    Computes ``softmax_T(wEmbed(tanh(sEmbed(s) + xEmbed(x))))``, giving one
    weight per time step for each sample.

    Args:
        sDim: decoder state size.
        xDim: encoder feature size.
        attDim: attention projection size.
    """

    def __init__(self, sDim, xDim, attDim):
        super(AttentionUnit, self).__init__()

        self.sDim = sDim
        self.xDim = xDim
        self.attDim = attDim

        self.sEmbed = nn.Linear(sDim, attDim)
        self.xEmbed = nn.Linear(xDim, attDim)
        self.wEmbed = nn.Linear(attDim, 1)

    def forward(self, x, sPrev):
        """Return attention weights of shape [b, T].

        Args:
            x: encoder features, [b, T, xDim].
            sPrev: previous decoder state, either [1, b, sDim] or [b, sDim].
        """
        batch_size, T, _ = x.shape  # [b x T x xDim]
        x = paddle.reshape(x, [-1, self.xDim])  # [(b x T) x xDim]
        xProj = self.xEmbed(x)  # [(b x T) x attDim]
        xProj = paddle.reshape(xProj, [batch_size, T, -1])  # [b x T x attDim]

        # Only drop the leading axis of a 3-D state. An unconditional
        # squeeze(0) would turn a 2-D [1, sDim] state (batch size 1) into
        # [sDim] and break the broadcast below.
        if sPrev.dim() == 3:
            sPrev = sPrev.squeeze(0)
        sProj = self.sEmbed(sPrev)  # [b x attDim]
        sProj = paddle.unsqueeze(sProj, 1)  # [b x 1 x attDim]
        sProj = paddle.expand(sProj,
                              [batch_size, T, self.attDim])  # [b x T x attDim]

        sumTanh = paddle.tanh(sProj + xProj)
        sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])

        vProj = self.wEmbed(sumTanh)  # [(b x T) x 1]
        vProj = paddle.reshape(vProj, [batch_size, T])
        # attention weights for each sample in the minibatch
        alpha = F.softmax(vProj, axis=1)
        return alpha
|
344
|
-
|
345
|
-
|
346
|
-
class DecoderUnit(nn.Layer):
    """One step of an attention GRU decoder.

    Each step attends over the encoder features with the previous state,
    concatenates the previous symbol's embedding with the attention context,
    advances a GRU cell, and projects the new hidden state to class logits.

    Args:
        sDim: GRU hidden-state size.
        xDim: encoder feature size.
        yDim: number of output classes. The target embedding has ``yDim + 1``
            rows; the last row is used for <BOS>.
        attDim: attention size, also used as the symbol-embedding size.
        embed_dim: width of the holistic embedding accepted by
            ``get_initial_state`` (default 300, as before).
    """

    def __init__(self, sDim, xDim, yDim, attDim, embed_dim=300):
        super(DecoderUnit, self).__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.yDim = yDim
        self.attDim = attDim
        self.emdDim = attDim
        self.embed_dim = embed_dim

        # Sub-layers keep their original creation order.
        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
        self.tgt_embedding = nn.Embedding(
            yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
                std=0.01))  # the last is used for <BOS>
        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
        self.fc = nn.Linear(
            sDim,
            yDim,
            weight_attr=nn.initializer.Normal(std=0.01),
            bias_attr=nn.initializer.Constant(value=0))
        self.embed_fc = nn.Linear(self.embed_dim, self.sDim)

    def get_initial_state(self, embed, tile_times=1):
        """Project a holistic embedding to the initial decoder state.

        Args:
            embed: [N, embed_dim] holistic embedding.
            tile_times: repeat each sample's state this many times,
                consecutively (sample-major: s0, s0, ..., s1, s1, ...).

        Returns:
            State tensor of shape [1, N * tile_times, sDim].

        Raises:
            ValueError: if ``embed.shape[1] != embed_dim``.
        """
        if embed.shape[1] != self.embed_dim:
            raise ValueError(
                'expected embed of width {}, got shape {}'.format(
                    self.embed_dim, embed.shape))
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            state = state.unsqueeze(1)
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.reshape(trans_state, shape=[-1, self.sDim])
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, x, sPrev, yPrev):
        """Run one decoding step.

        Args:
            x: encoder feature sequence, [b, T, xDim].
            sPrev: previous state, [1, b, sDim] or [b, sDim].
            yPrev: previous symbol ids, [b] (cast to int64).

        Returns:
            ``(output, state)``: logits [b, yDim] and the new GRU state
            [b, sDim].
        """
        # Normalize the state to [1, b, sDim]. The GRU returns a [b, sDim]
        # state; when b == 1, squeeze(0) (here and inside AttentionUnit)
        # would otherwise collapse it to [sDim].
        if sPrev.dim() == 2:
            sPrev = sPrev.unsqueeze(0)
        alpha = self.attention_unit(x, sPrev)  # [b, T]
        # Attention-weighted sum of the encoder features: [b, xDim].
        context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
        yPrev = paddle.cast(yPrev, dtype="int64")
        yProj = self.tgt_embedding(yPrev)  # [b, emdDim]

        concat_context = paddle.concat([yProj, context], 1)
        sPrev = paddle.squeeze(sPrev, 0)  # [b, sDim]
        output, state = self.gru(concat_context, sPrev)
        output = self.fc(output)
        return output, state
|
@@ -1,202 +0,0 @@
|
|
1
|
-
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import paddle
|
20
|
-
import paddle.nn as nn
|
21
|
-
import paddle.nn.functional as F
|
22
|
-
import numpy as np
|
23
|
-
|
24
|
-
|
25
|
-
class AttentionHead(nn.Layer):
    """Attention recognition head built on ``AttentionGRUCell``.

    With ``targets`` it decodes teacher-forced; without them it decodes
    greedily, starting from class id 0. In eval mode the per-step outputs
    are softmax-normalized over classes.

    Args:
        in_channels: encoder feature size.
        out_channels: number of classes.
        hidden_size: attention cell hidden size.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionHead, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionGRUCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        """One-hot encode class ids to ``onehot_dim`` classes."""
        input_ont_hot = F.one_hot(input_char, onehot_dim)
        return input_ont_hot

    def forward(self, inputs, targets=None, batch_max_length=25):
        """Decode ``batch_max_length`` steps.

        Args:
            inputs: encoder features (batch on axis 0).
            targets: optional ground-truth ids indexed as ``targets[:, i]``;
                enables teacher forcing.
            batch_max_length: number of decoding steps.

        Returns:
            [b, batch_max_length, num_classes] scores: logits in training
            mode, softmax probabilities otherwise.
        """
        batch_size = paddle.shape(inputs)[0]
        num_steps = batch_max_length

        hidden = paddle.zeros((batch_size, self.hidden_size))

        if targets is not None:
            # Teacher forcing: step i is fed the ground-truth char targets[:, i].
            output_hiddens = []
            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                (outputs, hidden), _ = self.attention_cell(hidden, inputs,
                                                           char_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            # Greedy decoding: start from id 0, then feed back each step's
            # arg-max. Steps are collected and concatenated once, instead of
            # re-concatenating the growing tensor every step.
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            step_probs = []
            for _ in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets, onehot_dim=self.num_classes)
                (outputs, hidden), _ = self.attention_cell(hidden, inputs,
                                                           char_onehots)
                probs_step = self.generator(outputs)
                step_probs.append(paddle.unsqueeze(probs_step, axis=1))
                targets = probs_step.argmax(axis=1)
            probs = paddle.concat(step_probs, axis=1)
        if not self.training:
            probs = paddle.nn.functional.softmax(probs, axis=2)
        return probs
|
80
|
-
|
81
|
-
|
82
|
-
class AttentionGRUCell(nn.Layer):
    """Single decoding step: additive attention over encoder features, then a GRU update.

    Args:
        input_size: feature size of ``batch_H`` (last axis).
        hidden_size: GRU hidden size and attention projection size.
        num_embeddings: width of the one-hot previous-char vector.
        use_gru: accepted for signature parity with ``AttentionLSTMCell``;
            ignored here, because this cell always uses ``nn.GRUCell``.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionGRUCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)

        # The GRU input is [attention context, previous-char one-hot].
        gru_in = input_size + num_embeddings
        self.rnn = nn.GRUCell(input_size=gru_in, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Attend over ``batch_H`` with ``prev_hidden`` and advance the GRU one step.

        ``batch_H`` must be rank 3 (the transpose below uses a 3-axis perm);
        presumably (batch, time, input_size) -- TODO confirm.

        Returns:
            ``(rnn_result, alpha)``, where ``rnn_result`` is whatever
            ``nn.GRUCell`` returns and ``alpha`` holds the attention weights
            with the softmaxed axis moved last.
        """
        # Energies: score(tanh(i2h(H) + h2h(h_prev))); h_prev is broadcast over axis 1.
        energy = self.score(paddle.tanh(paddle.add(
            self.i2h(batch_H),
            paddle.unsqueeze(self.h2h(prev_hidden), axis=1))))

        weights = paddle.transpose(F.softmax(energy, axis=1), [0, 2, 1])
        context = paddle.squeeze(paddle.mm(weights, batch_H), axis=1)

        rnn_input = paddle.concat([context, char_onehots], 1)
        return self.rnn(rnn_input, prev_hidden), weights
|
111
|
-
|
112
|
-
|
113
|
-
class AttentionLSTM(nn.Layer):
    """Attention-based sequence recognition head built on ``AttentionLSTMCell``.

    Decodes exactly ``batch_max_length`` steps. With ``targets`` it uses
    teacher forcing; without it, decoding is greedy from class index 0.
    Unlike ``AttentionHead``, no softmax is applied in eval mode: raw logits
    are always returned.

    Args:
        in_channels: feature size of the encoder output consumed by attention.
        out_channels: number of output classes (also the one-hot width).
        hidden_size: size of the LSTM hidden/cell state.
    """

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionLSTM, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        """Return the class indices ``input_char`` encoded as one-hot vectors."""
        return F.one_hot(input_char, onehot_dim)

    def forward(self, inputs, targets=None, batch_max_length=25):
        """Run the decoder.

        Args:
            inputs: encoder features passed to the attention cell; assumed
                (batch, time, in_channels) -- TODO confirm.
            targets: optional class indices indexed as ``targets[:, i]``.
            batch_max_length: number of decoding steps.

        Returns:
            Per-step logits concatenated along axis 1.
        """
        # paddle.shape (rather than .shape) keeps this valid when the batch
        # dimension is dynamic (-1) in static-graph export; this matches
        # AttentionHead.
        batch_size = paddle.shape(inputs)[0]
        num_steps = batch_max_length

        # LSTM state is an (h, c) pair, both zero-initialised.
        hidden = (paddle.zeros((batch_size, self.hidden_size)),
                  paddle.zeros((batch_size, self.hidden_size)))

        if targets is not None:
            output_hiddens = []
            for i in range(num_steps):
                # One-hot vectors for the i-th ground-truth char.
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                # The cell returns the LSTMCell result: (outputs, (h, c)).
                (_, (h, c)), _ = self.attention_cell(
                    hidden, inputs, char_onehots)
                hidden = (h, c)
                output_hiddens.append(paddle.unsqueeze(h, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            step_probs = []
            for _ in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets, onehot_dim=self.num_classes)
                (outputs, (h, c)), _ = self.attention_cell(
                    hidden, inputs, char_onehots)
                probs_step = self.generator(outputs)
                hidden = (h, c)
                step_probs.append(paddle.unsqueeze(probs_step, axis=1))
                # Greedy: the predicted class becomes the next input.
                targets = probs_step.argmax(axis=1)
            # Single concat instead of quadratic re-concatenation per step.
            probs = paddle.concat(step_probs, axis=1)

        return probs
|
172
|
-
|
173
|
-
|
174
|
-
class AttentionLSTMCell(nn.Layer):
    """Single decoding step: additive attention over encoder features, then an RNN update.

    Args:
        input_size: feature size of ``batch_H`` (last axis).
        hidden_size: RNN hidden size and attention projection size.
        num_embeddings: width of the one-hot previous-char vector.
        use_gru: build an ``nn.GRUCell`` instead of the default ``nn.LSTMCell``.
            NOTE(review): ``forward`` reads ``prev_hidden[0]``, which assumes
            the LSTM ``(h, c)`` state tuple. With a GRU state tensor this
            would index the first axis instead -- confirm before using
            ``use_gru=True``.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionLSTMCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)

        cell_cls = nn.GRUCell if use_gru else nn.LSTMCell
        self.rnn = cell_cls(
            input_size=input_size + num_embeddings, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        """Attend over ``batch_H`` using ``prev_hidden[0]`` and advance the RNN one step.

        ``batch_H`` must be rank 3 (the transpose below uses a 3-axis perm).

        Returns:
            ``(rnn_result, alpha)``, where ``rnn_result`` is the raw return
            value of ``self.rnn`` and ``alpha`` holds the attention weights.
        """
        keys = self.i2h(batch_H)
        query = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
        energy = self.score(paddle.tanh(paddle.add(keys, query)))

        weights = F.softmax(energy, axis=1)
        weights = paddle.transpose(weights, [0, 2, 1])
        context = paddle.squeeze(paddle.mm(weights, batch_H), axis=1)

        rnn_input = paddle.concat([context, char_onehots], 1)
        rnn_result = self.rnn(rnn_input, prev_hidden)
        return rnn_result, weights
|
@@ -1,88 +0,0 @@
|
|
1
|
-
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import math
|
20
|
-
|
21
|
-
import paddle
|
22
|
-
from paddle import ParamAttr, nn
|
23
|
-
from paddle.nn import functional as F
|
24
|
-
|
25
|
-
|
26
|
-
def get_para_bias_attr(l2_decay, k):
    """Build the ``[weight_attr, bias_attr]`` pair for a linear layer.

    Both attributes share one L2 regularizer with coefficient ``l2_decay``
    and one uniform initializer on ``[-1/sqrt(k), 1/sqrt(k)]``. Callers pass
    the layer's input size as ``k``.
    """
    bound = 1.0 / math.sqrt(k * 1.0)
    decay = paddle.regularizer.L2Decay(l2_decay)
    init = nn.initializer.Uniform(-bound, bound)
    return [ParamAttr(regularizer=decay, initializer=init) for _ in range(2)]
|
33
|
-
|
34
|
-
|
35
|
-
class CTCHead(nn.Layer):
    """CTC recognition head: one or two linear layers mapping features to class scores.

    Args:
        in_channels: feature size of the input's last axis.
        out_channels: number of classes.
        fc_decay: L2 decay coefficient for every weight and bias.
        mid_channels: if not None, a hidden linear layer of this width
            (``fc1``) precedes the classifier (``fc2``); otherwise a single
            layer ``fc`` is used.
        return_feats: in training mode, also return the tensor fed into the
            final linear layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 fc_decay=0.0004,
                 mid_channels=None,
                 return_feats=False,
                 **kwargs):
        super(CTCHead, self).__init__()
        if mid_channels is not None:
            w_attr1, b_attr1 = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels)
            self.fc1 = nn.Linear(in_channels, mid_channels,
                                 weight_attr=w_attr1, bias_attr=b_attr1)

            w_attr2, b_attr2 = get_para_bias_attr(
                l2_decay=fc_decay, k=mid_channels)
            self.fc2 = nn.Linear(mid_channels, out_channels,
                                 weight_attr=w_attr2, bias_attr=b_attr2)
        else:
            w_attr, b_attr = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels)
            self.fc = nn.Linear(in_channels, out_channels,
                                weight_attr=w_attr, bias_attr=b_attr)
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, targets=None):
        """Project ``x`` to class scores.

        ``targets`` is accepted for interface compatibility and unused.

        Returns:
            In eval mode, softmax probabilities over axis 2 (``x`` must
            therefore be at least rank 3; presumably (batch, time, channels)),
            regardless of ``return_feats``. In training mode, the logits, or
            ``(feats, logits)`` when ``return_feats`` is set. ``feats`` is the
            input to the final layer: ``x`` itself when ``mid_channels`` is
            None, otherwise ``fc1``'s output.
        """
        if self.mid_channels is not None:
            x = self.fc1(x)
            predicts = self.fc2(x)
        else:
            predicts = self.fc(x)

        if not self.training:
            # Eval mode always returns probabilities only.
            return F.softmax(predicts, axis=2)
        if self.return_feats:
            return x, predicts
        return predicts
|