pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0

pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py
@@ -1,307 +0,0 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from paddle import nn, ParamAttr
-from paddle.nn import functional as F
-import paddle.fluid as fluid
-import paddle
-import numpy as np
-
-__all__ = ["ResNetFPN"]
-
-
-class ResNetFPN(nn.Layer):
-    def __init__(self, in_channels=1, layers=50, **kwargs):
-        super(ResNetFPN, self).__init__()
-        supported_layers = {
-            18: {
-                'depth': [2, 2, 2, 2],
-                'block_class': BasicBlock
-            },
-            34: {
-                'depth': [3, 4, 6, 3],
-                'block_class': BasicBlock
-            },
-            50: {
-                'depth': [3, 4, 6, 3],
-                'block_class': BottleneckBlock
-            },
-            101: {
-                'depth': [3, 4, 23, 3],
-                'block_class': BottleneckBlock
-            },
-            152: {
-                'depth': [3, 8, 36, 3],
-                'block_class': BottleneckBlock
-            }
-        }
-        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
-        num_filters = [64, 128, 256, 512]
-        self.depth = supported_layers[layers]['depth']
-        self.F = []
-        self.conv = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=64,
-            kernel_size=7,
-            stride=2,
-            act="relu",
-            name="conv1")
-        self.block_list = []
-        in_ch = 64
-        if layers >= 50:
-            for block in range(len(self.depth)):
-                for i in range(self.depth[block]):
-                    if layers in [101, 152] and block == 2:
-                        if i == 0:
-                            conv_name = "res" + str(block + 2) + "a"
-                        else:
-                            conv_name = "res" + str(block + 2) + "b" + str(i)
-                    else:
-                        conv_name = "res" + str(block + 2) + chr(97 + i)
-                    block_list = self.add_sublayer(
-                        "bottleneckBlock_{}_{}".format(block, i),
-                        BottleneckBlock(
-                            in_channels=in_ch,
-                            out_channels=num_filters[block],
-                            stride=stride_list[block] if i == 0 else 1,
-                            name=conv_name))
-                    in_ch = num_filters[block] * 4
-                    self.block_list.append(block_list)
-                    self.F.append(block_list)
-        else:
-            for block in range(len(self.depth)):
-                for i in range(self.depth[block]):
-                    conv_name = "res" + str(block + 2) + chr(97 + i)
-                    if i == 0 and block != 0:
-                        stride = (2, 1)
-                    else:
-                        stride = (1, 1)
-                    basic_block = self.add_sublayer(
-                        conv_name,
-                        BasicBlock(
-                            in_channels=in_ch,
-                            out_channels=num_filters[block],
-                            stride=stride_list[block] if i == 0 else 1,
-                            is_first=block == i == 0,
-                            name=conv_name))
-                    in_ch = basic_block.out_channels
-                    self.block_list.append(basic_block)
-        out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
-        self.base_block = []
-        self.conv_trans = []
-        self.bn_block = []
-        for i in [-2, -3]:
-            in_channels = out_ch_list[i + 1] + out_ch_list[i]
-
-            self.base_block.append(
-                self.add_sublayer(
-                    "F_{}_base_block_0".format(i),
-                    nn.Conv2D(
-                        in_channels=in_channels,
-                        out_channels=out_ch_list[i],
-                        kernel_size=1,
-                        weight_attr=ParamAttr(trainable=True),
-                        bias_attr=ParamAttr(trainable=True))))
-            self.base_block.append(
-                self.add_sublayer(
-                    "F_{}_base_block_1".format(i),
-                    nn.Conv2D(
-                        in_channels=out_ch_list[i],
-                        out_channels=out_ch_list[i],
-                        kernel_size=3,
-                        padding=1,
-                        weight_attr=ParamAttr(trainable=True),
-                        bias_attr=ParamAttr(trainable=True))))
-            self.base_block.append(
-                self.add_sublayer(
-                    "F_{}_base_block_2".format(i),
-                    nn.BatchNorm(
-                        num_channels=out_ch_list[i],
-                        act="relu",
-                        param_attr=ParamAttr(trainable=True),
-                        bias_attr=ParamAttr(trainable=True))))
-        self.base_block.append(
-            self.add_sublayer(
-                "F_{}_base_block_3".format(i),
-                nn.Conv2D(
-                    in_channels=out_ch_list[i],
-                    out_channels=512,
-                    kernel_size=1,
-                    bias_attr=ParamAttr(trainable=True),
-                    weight_attr=ParamAttr(trainable=True))))
-        self.out_channels = 512
-
-    def __call__(self, x):
-        x = self.conv(x)
-        fpn_list = []
-        F = []
-        for i in range(len(self.depth)):
-            fpn_list.append(np.sum(self.depth[:i + 1]))
-
-        for i, block in enumerate(self.block_list):
-            x = block(x)
-            for number in fpn_list:
-                if i + 1 == number:
-                    F.append(x)
-        base = F[-1]
-
-        j = 0
-        for i, block in enumerate(self.base_block):
-            if i % 3 == 0 and i < 6:
-                j = j + 1
-                b, c, w, h = F[-j - 1].shape
-                if [w, h] == list(base.shape[2:]):
-                    base = base
-                else:
-                    base = self.conv_trans[j - 1](base)
-                    base = self.bn_block[j - 1](base)
-                base = paddle.concat([base, F[-j - 1]], axis=1)
-            base = block(base)
-        return base
-
-
-class ConvBNLayer(nn.Layer):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 groups=1,
-                 act=None,
-                 name=None):
-        super(ConvBNLayer, self).__init__()
-        self.conv = nn.Conv2D(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=2 if stride == (1, 1) else kernel_size,
-            dilation=2 if stride == (1, 1) else 1,
-            stride=stride,
-            padding=(kernel_size - 1) // 2,
-            groups=groups,
-            weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
-            bias_attr=False, )
-
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
-        self.bn = nn.BatchNorm(
-            num_channels=out_channels,
-            act=act,
-            param_attr=ParamAttr(name=name + '.output.1.w_0'),
-            bias_attr=ParamAttr(name=name + '.output.1.b_0'),
-            moving_mean_name=bn_name + "_mean",
-            moving_variance_name=bn_name + "_variance")
-
-    def __call__(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        return x
-
-
-class ShortCut(nn.Layer):
-    def __init__(self, in_channels, out_channels, stride, name, is_first=False):
-        super(ShortCut, self).__init__()
-        self.use_conv = True
-
-        if in_channels != out_channels or stride != 1 or is_first == True:
-            if stride == (1, 1):
-                self.conv = ConvBNLayer(
-                    in_channels, out_channels, 1, 1, name=name)
-            else:  # stride==(2,2)
-                self.conv = ConvBNLayer(
-                    in_channels, out_channels, 1, stride, name=name)
-        else:
-            self.use_conv = False
-
-    def forward(self, x):
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
-
-class BottleneckBlock(nn.Layer):
-    def __init__(self, in_channels, out_channels, stride, name):
-        super(BottleneckBlock, self).__init__()
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=1,
-            act='relu',
-            name=name + "_branch2a")
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            act='relu',
-            name=name + "_branch2b")
-
-        self.conv2 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels * 4,
-            kernel_size=1,
-            act=None,
-            name=name + "_branch2c")
-
-        self.short = ShortCut(
-            in_channels=in_channels,
-            out_channels=out_channels * 4,
-            stride=stride,
-            is_first=False,
-            name=name + "_branch1")
-        self.out_channels = out_channels * 4
-
-    def forward(self, x):
-        y = self.conv0(x)
-        y = self.conv1(y)
-        y = self.conv2(y)
-        y = y + self.short(x)
-        y = F.relu(y)
-        return y
-
-
-class BasicBlock(nn.Layer):
-    def __init__(self, in_channels, out_channels, stride, name, is_first):
-        super(BasicBlock, self).__init__()
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            act='relu',
-            stride=stride,
-            name=name + "_branch2a")
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            act=None,
-            name=name + "_branch2b")
-        self.short = ShortCut(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            stride=stride,
-            is_first=is_first,
-            name=name + "_branch1")
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        y = self.conv0(x)
-        y = self.conv1(y)
-        y = y + self.short(x)
-        return F.relu(y)

pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py
@@ -1,286 +0,0 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-
-__all__ = ["ResNet"]
-
-
-class ConvBNLayer(nn.Layer):
-    def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            groups=1,
-            is_vd_mode=False,
-            act=None,
-            name=None, ):
-        super(ConvBNLayer, self).__init__()
-
-        self.is_vd_mode = is_vd_mode
-        self._pool2d_avg = nn.AvgPool2D(
-            kernel_size=stride, stride=stride, padding=0, ceil_mode=True)
-        self._conv = nn.Conv2D(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=1 if is_vd_mode else stride,
-            padding=(kernel_size - 1) // 2,
-            groups=groups,
-            weight_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
-        self._batch_norm = nn.BatchNorm(
-            out_channels,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + '_scale'),
-            bias_attr=ParamAttr(bn_name + '_offset'),
-            moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
-
-    def forward(self, inputs):
-        if self.is_vd_mode:
-            inputs = self._pool2d_avg(inputs)
-        y = self._conv(inputs)
-        y = self._batch_norm(y)
-        return y
-
-
-class BottleneckBlock(nn.Layer):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 shortcut=True,
-                 if_first=False,
-                 name=None):
-        super(BottleneckBlock, self).__init__()
-
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=1,
-            act='relu',
-            name=name + "_branch2a")
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            act='relu',
-            name=name + "_branch2b")
-        self.conv2 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels * 4,
-            kernel_size=1,
-            act=None,
-            name=name + "_branch2c")
-
-        if not shortcut:
-            self.short = ConvBNLayer(
-                in_channels=in_channels,
-                out_channels=out_channels * 4,
-                kernel_size=1,
-                stride=stride,
-                is_vd_mode=not if_first and stride[0] != 1,
-                name=name + "_branch1")
-
-        self.shortcut = shortcut
-
-    def forward(self, inputs):
-        y = self.conv0(inputs)
-
-        conv1 = self.conv1(y)
-        conv2 = self.conv2(conv1)
-
-        if self.shortcut:
-            short = inputs
-        else:
-            short = self.short(inputs)
-        y = paddle.add(x=short, y=conv2)
-        y = F.relu(y)
-        return y
-
-
-class BasicBlock(nn.Layer):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 shortcut=True,
-                 if_first=False,
-                 name=None):
-        super(BasicBlock, self).__init__()
-        self.stride = stride
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            act='relu',
-            name=name + "_branch2a")
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            act=None,
-            name=name + "_branch2b")
-
-        if not shortcut:
-            self.short = ConvBNLayer(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=stride,
-                is_vd_mode=not if_first and stride[0] != 1,
-                name=name + "_branch1")
-
-        self.shortcut = shortcut
-
-    def forward(self, inputs):
-        y = self.conv0(inputs)
-        conv1 = self.conv1(y)
-
-        if self.shortcut:
-            short = inputs
-        else:
-            short = self.short(inputs)
-        y = paddle.add(x=short, y=conv1)
-        y = F.relu(y)
-        return y
-
-
-class ResNet(nn.Layer):
-    def __init__(self, in_channels=3, layers=50, **kwargs):
-        super(ResNet, self).__init__()
-
-        self.layers = layers
-        supported_layers = [18, 34, 50, 101, 152, 200]
-        assert layers in supported_layers, \
-            "supported layers are {} but input layer is {}".format(
-                supported_layers, layers)
-
-        if layers == 18:
-            depth = [2, 2, 2, 2]
-        elif layers == 34 or layers == 50:
-            depth = [3, 4, 6, 3]
-        elif layers == 101:
-            depth = [3, 4, 23, 3]
-        elif layers == 152:
-            depth = [3, 8, 36, 3]
-        elif layers == 200:
-            depth = [3, 12, 48, 3]
-        num_channels = [64, 256, 512,
-                        1024] if layers >= 50 else [64, 64, 128, 256]
-        num_filters = [64, 128, 256, 512]
-
-        self.conv1_1 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=32,
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv1_1")
-        self.conv1_2 = ConvBNLayer(
-            in_channels=32,
-            out_channels=32,
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv1_2")
-        self.conv1_3 = ConvBNLayer(
-            in_channels=32,
-            out_channels=64,
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv1_3")
-        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
-
-        self.block_list = []
-        if layers >= 50:
-            for block in range(len(depth)):
-                shortcut = False
-                for i in range(depth[block]):
-                    if layers in [101, 152, 200] and block == 2:
-                        if i == 0:
-                            conv_name = "res" + str(block + 2) + "a"
-                        else:
-                            conv_name = "res" + str(block + 2) + "b" + str(i)
-                    else:
-                        conv_name = "res" + str(block + 2) + chr(97 + i)
-
-                    if i == 0 and block != 0:
-                        stride = (2, 1)
-                    else:
-                        stride = (1, 1)
-                    bottleneck_block = self.add_sublayer(
-                        'bb_%d_%d' % (block, i),
-                        BottleneckBlock(
-                            in_channels=num_channels[block]
-                            if i == 0 else num_filters[block] * 4,
-                            out_channels=num_filters[block],
-                            stride=stride,
-                            shortcut=shortcut,
-                            if_first=block == i == 0,
-                            name=conv_name))
-                    shortcut = True
-                    self.block_list.append(bottleneck_block)
-                self.out_channels = num_filters[block] * 4
-        else:
-            for block in range(len(depth)):
-                shortcut = False
-                for i in range(depth[block]):
-                    conv_name = "res" + str(block + 2) + chr(97 + i)
-                    if i == 0 and block != 0:
-                        stride = (2, 1)
-                    else:
-                        stride = (1, 1)
-
-                    basic_block = self.add_sublayer(
-                        'bb_%d_%d' % (block, i),
-                        BasicBlock(
-                            in_channels=num_channels[block]
-                            if i == 0 else num_filters[block],
-                            out_channels=num_filters[block],
-                            stride=stride,
-                            shortcut=shortcut,
-                            if_first=block == i == 0,
-                            name=conv_name))
-                    shortcut = True
-                    self.block_list.append(basic_block)
-                self.out_channels = num_filters[block]
-        self.out_pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
-
-    def forward(self, inputs):
-        y = self.conv1_1(inputs)
-        y = self.conv1_2(y)
-        y = self.conv1_3(y)
-        y = self.pool2d_max(y)
-        for block in self.block_list:
-            y = block(y)
-        y = self.out_pool(y)
-        return y

pyxlpr/ppocr/modeling/heads/__init__.py
@@ -1,54 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = ['build_head']
-
-
-def build_head(config):
-    # det head
-    from .det_db_head import DBHead
-    from .det_east_head import EASTHead
-    from .det_sast_head import SASTHead
-    from .det_pse_head import PSEHead
-    from .e2e_pg_head import PGHead
-
-    # rec head
-    from .rec_ctc_head import CTCHead
-    from .rec_att_head import AttentionHead
-    from .rec_srn_head import SRNHead
-    from .rec_nrtr_head import Transformer
-    from .rec_sar_head import SARHead
-    from .rec_aster_head import AsterHead
-
-    # cls head
-    from .cls_head import ClsHead
-
-    #kie head
-    from .kie_sdmgr_head import SDMGRHead
-
-    from .table_att_head import TableAttentionHead
-
-    support_dict = [
-        'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
-        'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
-        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
-    ]
-
-    #table head
-
-    module_name = config.pop('name')
-    assert module_name in support_dict, Exception('head only support {}'.format(
-        support_dict))
-    module_class = eval(module_name)(**config)
-    return module_class

pyxlpr/ppocr/modeling/heads/cls_head.py
@@ -1,52 +0,0 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import paddle
-from paddle import nn, ParamAttr
-import paddle.nn.functional as F
-
-
-class ClsHead(nn.Layer):
-    """
-    Class orientation
-
-    Args:
-
-        params(dict): super parameters for build Class network
-    """
-
-    def __init__(self, in_channels, class_dim, **kwargs):
-        super(ClsHead, self).__init__()
-        self.pool = nn.AdaptiveAvgPool2D(1)
-        stdv = 1.0 / math.sqrt(in_channels * 1.0)
-        self.fc = nn.Linear(
-            in_channels,
-            class_dim,
-            weight_attr=ParamAttr(
-                name="fc_0.w_0",
-                initializer=nn.initializer.Uniform(-stdv, stdv)),
-            bias_attr=ParamAttr(name="fc_0.b_0"), )
-
-    def forward(self, x, targets=None):
-        x = self.pool(x)
-        x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
-        x = self.fc(x)
-        if not self.training:
-            x = F.softmax(x, axis=1)
-        return x