pyxllib-0.3.96-py3-none-any.whl → pyxllib-0.3.197-py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registry.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
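Two structural changes dominate the summary above: the `autogui` modules move out of `pyxllib.ext` into a new top-level `pyxllib.autogui` package, and the entire `pyxlpr` namespace (the PaddleOCR-derived OCR/AI stack) is dropped from the wheel. A minimal import-migration sketch, inferred only from the file moves listed above (the exact re-exports are not confirmed by this diff):

    # Hypothetical sketch based on the renames in the file summary.
    # 0.3.96 layout (module lived under pyxllib.ext):
    #     from pyxllib.ext.autogui import autogui
    # 0.3.197 layout (autogui is now a top-level subpackage):
    from pyxllib.autogui import autogui

    # The pyxlpr package no longer ships in the wheel, so imports such as
    # `import pyxlpr.paddleocr` now raise ModuleNotFoundError in 0.3.197.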
pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py
@@ -1,206 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import paddle
-from paddle import nn
-import paddle.nn.functional as F
-from paddle import ParamAttr
-
-
-class SDMGRHead(nn.Layer):
-    def __init__(self,
-                 in_channels,
-                 num_chars=92,
-                 visual_dim=16,
-                 fusion_dim=1024,
-                 node_input=32,
-                 node_embed=256,
-                 edge_input=5,
-                 edge_embed=256,
-                 num_gnn=2,
-                 num_classes=26,
-                 bidirectional=False):
-        super().__init__()
-
-        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
-        self.node_embed = nn.Embedding(num_chars, node_input, 0)
-        hidden = node_embed // 2 if bidirectional else node_embed
-        self.rnn = nn.LSTM(
-            input_size=node_input, hidden_size=hidden, num_layers=1)
-        self.edge_embed = nn.Linear(edge_input, edge_embed)
-        self.gnn_layers = nn.LayerList(
-            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
-        self.node_cls = nn.Linear(node_embed, num_classes)
-        self.edge_cls = nn.Linear(edge_embed, 2)
-
-    def forward(self, input, targets):
-        relations, texts, x = input
-        node_nums, char_nums = [], []
-        for text in texts:
-            node_nums.append(text.shape[0])
-            char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
-
-        max_num = max([char_num.max() for char_num in char_nums])
-        all_nodes = paddle.concat([
-            paddle.concat(
-                [text, paddle.zeros(
-                    (text.shape[0], max_num - text.shape[1]))], -1)
-            for text in texts
-        ])
-        temp = paddle.clip(all_nodes, min=0).astype(int)
-        embed_nodes = self.node_embed(temp)
-        rnn_nodes, _ = self.rnn(embed_nodes)
-
-        b, h, w = rnn_nodes.shape
-        nodes = paddle.zeros([b, w])
-        all_nums = paddle.concat(char_nums)
-        valid = paddle.nonzero((all_nums > 0).astype(int))
-        temp_all_nums = (
-            paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
-        temp_all_nums = paddle.expand(temp_all_nums, [
-            temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
-        ])
-        temp_all_nodes = paddle.gather(rnn_nodes, valid)
-        N, C, A = temp_all_nodes.shape
-        one_hot = F.one_hot(
-            temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
-        one_hot = paddle.multiply(
-            temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
-        t = one_hot.expand([N, 1, A]).squeeze(1)
-        nodes = paddle.scatter(nodes, valid.squeeze(1), t)
-
-        if x is not None:
-            nodes = self.fusion([x, nodes])
-
-        all_edges = paddle.concat(
-            [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
-        embed_edges = self.edge_embed(all_edges.astype('float32'))
-        embed_edges = F.normalize(embed_edges)
-
-        for gnn_layer in self.gnn_layers:
-            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
-
-        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
-        return node_cls, edge_cls
-
-
-class GNNLayer(nn.Layer):
-    def __init__(self, node_dim=256, edge_dim=256):
-        super().__init__()
-        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
-        self.coef_fc = nn.Linear(node_dim, 1)
-        self.out_fc = nn.Linear(node_dim, node_dim)
-        self.relu = nn.ReLU()
-
-    def forward(self, nodes, edges, nums):
-        start, cat_nodes = 0, []
-        for num in nums:
-            sample_nodes = nodes[start:start + num]
-            cat_nodes.append(
-                paddle.concat([
-                    paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
-                    paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
-                ], -1).reshape([num**2, -1]))
-            start += num
-        cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
-        cat_nodes = self.relu(self.in_fc(cat_nodes))
-        coefs = self.coef_fc(cat_nodes)
-
-        start, residuals = 0, []
-        for num in nums:
-            residual = F.softmax(
-                -paddle.eye(num).unsqueeze(-1) * 1e9 +
-                coefs[start:start + num**2].reshape([num, num, -1]), 1)
-            residuals.append((residual * cat_nodes[start:start + num**2]
-                              .reshape([num, num, -1])).sum(1))
-            start += num**2
-
-        nodes += self.relu(self.out_fc(paddle.concat(residuals)))
-        return [nodes, cat_nodes]
-
-
-class Block(nn.Layer):
-    def __init__(self,
-                 input_dims,
-                 output_dim,
-                 mm_dim=1600,
-                 chunks=20,
-                 rank=15,
-                 shared=False,
-                 dropout_input=0.,
-                 dropout_pre_lin=0.,
-                 dropout_output=0.,
-                 pos_norm='before_cat'):
-        super().__init__()
-        self.rank = rank
-        self.dropout_input = dropout_input
-        self.dropout_pre_lin = dropout_pre_lin
-        self.dropout_output = dropout_output
-        assert (pos_norm in ['before_cat', 'after_cat'])
-        self.pos_norm = pos_norm
-        # Modules
-        self.linear0 = nn.Linear(input_dims[0], mm_dim)
-        self.linear1 = (self.linear0
-                        if shared else nn.Linear(input_dims[1], mm_dim))
-        self.merge_linears0 = nn.LayerList()
-        self.merge_linears1 = nn.LayerList()
-        self.chunks = self.chunk_sizes(mm_dim, chunks)
-        for size in self.chunks:
-            ml0 = nn.Linear(size, size * rank)
-            self.merge_linears0.append(ml0)
-            ml1 = ml0 if shared else nn.Linear(size, size * rank)
-            self.merge_linears1.append(ml1)
-        self.linear_out = nn.Linear(mm_dim, output_dim)
-
-    def forward(self, x):
-        x0 = self.linear0(x[0])
-        x1 = self.linear1(x[1])
-        bs = x1.shape[0]
-        if self.dropout_input > 0:
-            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
-            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
-        x0_chunks = paddle.split(x0, self.chunks, -1)
-        x1_chunks = paddle.split(x1, self.chunks, -1)
-        zs = []
-        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
-                                      self.merge_linears1):
-            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
-            m = m.reshape([bs, self.rank, -1])
-            z = paddle.sum(m, 1)
-            if self.pos_norm == 'before_cat':
-                z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
-                z = F.normalize(z)
-            zs.append(z)
-        z = paddle.concat(zs, 1)
-        if self.pos_norm == 'after_cat':
-            z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
-            z = F.normalize(z)
-
-        if self.dropout_pre_lin > 0:
-            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
-        z = self.linear_out(z)
-        if self.dropout_output > 0:
-            z = F.dropout(z, p=self.dropout_output, training=self.training)
-        return z
-
-    def chunk_sizes(self, dim, chunks):
-        split_size = (dim + chunks - 1) // chunks
-        sizes_list = [split_size] * chunks
-        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
-        return sizes_list
pyxlpr/ppocr/modeling/heads/multiheadAttention.py
@@ -1,163 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle
-from paddle import nn
-import paddle.nn.functional as F
-from paddle.nn import Linear
-from paddle.nn.initializer import XavierUniform as xavier_uniform_
-from paddle.nn.initializer import Constant as constant_
-from paddle.nn.initializer import XavierNormal as xavier_normal_
-
-zeros_ = constant_(value=0.)
-ones_ = constant_(value=1.)
-
-
-class MultiheadAttention(nn.Layer):
-    """Allows the model to jointly attend to information
-    from different representation subspaces.
-    See reference: Attention Is All You Need
-
-    .. math::
-        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
-        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
-
-    Args:
-        embed_dim: total dimension of the model
-        num_heads: parallel attention layers, or heads
-
-    """
-
-    def __init__(self,
-                 embed_dim,
-                 num_heads,
-                 dropout=0.,
-                 bias=True,
-                 add_bias_kv=False,
-                 add_zero_attn=False):
-        super(MultiheadAttention, self).__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim**-0.5
-        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
-        self._reset_parameters()
-        self.conv1 = paddle.nn.Conv2D(
-            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
-        self.conv2 = paddle.nn.Conv2D(
-            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
-        self.conv3 = paddle.nn.Conv2D(
-            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
-
-    def _reset_parameters(self):
-        xavier_uniform_(self.out_proj.weight)
-
-    def forward(self,
-                query,
-                key,
-                value,
-                key_padding_mask=None,
-                incremental_state=None,
-                attn_mask=None):
-        """
-        Inputs of forward function
-            query: [target length, batch size, embed dim]
-            key: [sequence length, batch size, embed dim]
-            value: [sequence length, batch size, embed dim]
-            key_padding_mask: if True, mask padding based on batch size
-            incremental_state: if provided, previous time steps are cashed
-            need_weights: output attn_output_weights
-            static_kv: key and value are static
-
-        Outputs of forward function
-            attn_output: [target length, batch size, embed dim]
-            attn_output_weights: [batch size, target length, sequence length]
-        """
-        q_shape = paddle.shape(query)
-        src_shape = paddle.shape(key)
-        q = self._in_proj_q(query)
-        k = self._in_proj_k(key)
-        v = self._in_proj_v(value)
-        q *= self.scaling
-        q = paddle.transpose(
-            paddle.reshape(
-                q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
-            [1, 2, 0, 3])
-        k = paddle.transpose(
-            paddle.reshape(
-                k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
-            [1, 2, 0, 3])
-        v = paddle.transpose(
-            paddle.reshape(
-                v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
-            [1, 2, 0, 3])
-        if key_padding_mask is not None:
-            assert key_padding_mask.shape[0] == q_shape[1]
-            assert key_padding_mask.shape[1] == src_shape[0]
-        attn_output_weights = paddle.matmul(q,
-                                            paddle.transpose(k, [0, 1, 3, 2]))
-        if attn_mask is not None:
-            attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
-            attn_output_weights += attn_mask
-        if key_padding_mask is not None:
-            attn_output_weights = paddle.reshape(
-                attn_output_weights,
-                [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
-            key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
-            key = paddle.cast(key, 'float32')
-            y = paddle.full(
-                shape=paddle.shape(key), dtype='float32', fill_value='-inf')
-            y = paddle.where(key == 0., key, y)
-            attn_output_weights += y
-        attn_output_weights = F.softmax(
-            attn_output_weights.astype('float32'),
-            axis=-1,
-            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
-            else attn_output_weights.dtype)
-        attn_output_weights = F.dropout(
-            attn_output_weights, p=self.dropout, training=self.training)
-
-        attn_output = paddle.matmul(attn_output_weights, v)
-        attn_output = paddle.reshape(
-            paddle.transpose(attn_output, [2, 0, 1, 3]),
-            [q_shape[0], q_shape[1], self.embed_dim])
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output
-
-    def _in_proj_q(self, query):
-        query = paddle.transpose(query, [1, 2, 0])
-        query = paddle.unsqueeze(query, axis=2)
-        res = self.conv1(query)
-        res = paddle.squeeze(res, axis=2)
-        res = paddle.transpose(res, [2, 0, 1])
-        return res
-
-    def _in_proj_k(self, key):
-        key = paddle.transpose(key, [1, 2, 0])
-        key = paddle.unsqueeze(key, axis=2)
-        res = self.conv2(key)
-        res = paddle.squeeze(res, axis=2)
-        res = paddle.transpose(res, [2, 0, 1])
-        return res
-
-    def _in_proj_v(self, value):
-        value = paddle.transpose(value, [1, 2, 0])  #(1, 2, 0)
-        value = paddle.unsqueeze(value, axis=2)
-        res = self.conv3(value)
-        res = paddle.squeeze(res, axis=2)
-        res = paddle.transpose(res, [2, 0, 1])
-        return res