pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,110 +0,0 @@
|
|
1
|
-
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import paddle
|
20
|
-
from paddle import nn
|
21
|
-
import paddle.nn.functional as F
|
22
|
-
from paddle import ParamAttr
|
23
|
-
|
24
|
-
|
25
|
-
class TableFPN(nn.Layer):
    """Feature-pyramid neck for the table-structure model.

    The four backbone maps ``(c2, c3, c4, c5)`` are projected by 1x1 convs to
    a common width and merged top-down with nearest-neighbour resizing. The
    merged levels are then resized to the spatial size of the ``c5``
    projection, concatenated and fused by a 3x3 conv. The fused map is scaled
    by 0.005 and added onto ``c5`` as a residual.

    Args:
        in_channels: sequence of four channel counts, for ``c2``..``c5``.
        out_channels: accepted for config compatibility but NOT used; the
            internal width is fixed at 512.
        **kwargs: ignored.

    NOTE(review): the residual ``c5 + fuse`` presumably requires ``c5`` to
    have 512 channels (the fuse conv's output width) — TODO confirm.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(TableFPN, self).__init__()
        # The ``out_channels`` argument is deliberately ignored.
        self.out_channels = 512
        init = paddle.nn.initializer.KaimingUniform()

        def conv(cin, cout, ksize):
            # Bias-free conv with 'same' padding for the odd kernels used
            # here (1 -> padding 0, 3 -> padding 1); stride is 1.
            return nn.Conv2D(
                in_channels=cin,
                out_channels=cout,
                kernel_size=ksize,
                padding=ksize // 2,
                weight_attr=ParamAttr(initializer=init),
                bias_attr=False)

        width = self.out_channels
        # 1x1 lateral projections, one per pyramid level. The creation order
        # is kept fixed so auto-generated parameter names stay stable.
        self.in2_conv = conv(in_channels[0], width, 1)
        self.in3_conv = conv(in_channels[1], width, 1)
        self.in4_conv = conv(in_channels[2], width, 1)
        self.in5_conv = conv(in_channels[3], width, 1)
        # 3x3 heads: forward() does not use them, but they are kept so
        # existing checkpoints still load.
        self.p5_conv = conv(width, width // 4, 3)
        self.p4_conv = conv(width, width // 4, 3)
        self.p3_conv = conv(width, width // 4, 3)
        self.p2_conv = conv(width, width // 4, 3)
        self.fuse_conv = conv(width * 4, 512, 3)

    @staticmethod
    def _resize(t, ref):
        """Nearest-neighbour resize of ``t`` to the H, W of ``ref``."""
        return F.upsample(
            t, size=ref.shape[2:4], mode="nearest", align_mode=1)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.in5_conv(c5)
        in4 = self.in4_conv(c4)
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)

        # Top-down merge: each level adds the resized coarser level.
        out4 = in4 + self._resize(in5, in4)  # 1/16
        out3 = in3 + self._resize(out4, in3)  # 1/8
        out2 = in2 + self._resize(out3, in2)  # 1/4

        # Bring every merged level to in5's resolution, then fuse.
        p4 = self._resize(out4, in5)
        p3 = self._resize(out3, in5)
        p2 = self._resize(out2, in5)
        fused = self.fuse_conv(paddle.concat([in5, p4, p3, p2], axis=1))
        # Small scale factor: the fusion acts as a gentle residual on c5.
        return [c5 + fused * 0.005]
|
@@ -1,28 +0,0 @@
|
|
1
|
-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
__all__ = ['build_transform']
|
16
|
-
|
17
|
-
|
18
|
-
def build_transform(config):
    """Build the image-rectification transform described by ``config``.

    Args:
        config (dict): must contain ``'name'`` (``'TPS'`` or ``'STN_ON'``).
            All remaining items are passed to the selected class as keyword
            arguments. ``'name'`` is popped, so the caller's dict is modified
            in place.

    Returns:
        The constructed transform layer.

    Raises:
        KeyError: if ``config`` has no ``'name'`` entry.
        AssertionError: if ``'name'`` is not a supported transform.
    """
    from .tps import TPS
    from .stn import STN_ON

    # Explicit name -> class table instead of eval() on a config string.
    support_dict = {'TPS': TPS, 'STN_ON': STN_ON}

    module_name = config.pop('name')
    if module_name not in support_dict:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``. AssertionError is kept for backward compatibility.
        raise AssertionError(
            'transform only support {}'.format(list(support_dict)))
    transform = support_dict[module_name](**config)
    return transform
|
@@ -1,135 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is based on:
|
16
|
-
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/stn_head.py
|
17
|
-
"""
|
18
|
-
from __future__ import absolute_import
|
19
|
-
from __future__ import division
|
20
|
-
from __future__ import print_function
|
21
|
-
|
22
|
-
import math
|
23
|
-
import paddle
|
24
|
-
from paddle import nn, ParamAttr
|
25
|
-
from paddle.nn import functional as F
|
26
|
-
import numpy as np
|
27
|
-
|
28
|
-
from .tps_spatial_transformer import TPSSpatialTransformer
|
29
|
-
|
30
|
-
|
31
|
-
def conv3x3_block(in_channels, out_channels, stride=1):
    """Return ``Conv2D(3x3, padding=1) -> BatchNorm2D -> ReLU``.

    Conv weights are drawn from a normal distribution with mean 0 and
    std = sqrt(2 / (3 * 3 * out_channels)). The conv bias starts at zero.
    """
    fan_out = 9 * out_channels
    std = math.sqrt(2. / fan_out)
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        weight_attr=nn.initializer.Normal(mean=0.0, std=std),
        bias_attr=nn.initializer.Constant(0))
    return nn.Sequential(conv, nn.BatchNorm2D(out_channels), nn.ReLU())
|
45
|
-
|
46
|
-
|
47
|
-
class STN(nn.Layer):
    """Localisation head that regresses TPS control points from an image.

    Args:
        in_channels: channels of the input.
        num_ctrlpoints: number of control points. The head outputs
            ``num_ctrlpoints * 2`` values, reshaped to
            ``[-1, num_ctrlpoints, 2]``.
        activation: ``'sigmoid'`` squashes the output with a sigmoid, and the
            fc2 bias is initialised in logit space to match. Any other value
            (e.g. the default ``'none'``) leaves the output linear.
    """

    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
        super(STN, self).__init__()
        self.in_channels = in_channels
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation

        # Six 3x3 conv blocks with a 2x2 max-pool between consecutive blocks.
        # fc1 below expects 2 * 256 flattened features, i.e. a 256-channel,
        # 1x2 final map (the original comments trace 32x64 -> ... -> 1x2).
        widths = [in_channels, 32, 64, 128, 256, 256, 256]
        layers = []
        for idx in range(len(widths) - 1):
            if idx > 0:
                layers.append(nn.MaxPool2D(kernel_size=2, stride=2))
            layers.append(conv3x3_block(widths[idx], widths[idx + 1]))
        self.stn_convnet = nn.Sequential(*layers)

        self.stn_fc1 = nn.Sequential(
            nn.Linear(
                2 * 256,
                512,
                weight_attr=nn.initializer.Normal(0, 0.001),
                bias_attr=nn.initializer.Constant(0)),
            nn.BatchNorm1D(512),
            nn.ReLU())
        # Zero weights mean fc2 initially outputs exactly its bias: the
        # default control-point layout produced by init_stn().
        self.stn_fc2 = nn.Linear(
            512,
            num_ctrlpoints * 2,
            weight_attr=nn.initializer.Constant(0.0),
            bias_attr=nn.initializer.Assign(self.init_stn()))

    def init_stn(self):
        """Return the flat float32 fc2 bias holding the initial control points.

        Points are spread evenly along the top (y = margin) and bottom
        (y = 1 - margin) edges of the unit square, with x running from margin
        to 1 - margin. With the sigmoid activation they are passed through
        the inverse sigmoid, so that sigmoid(bias) recovers them.
        """
        margin = 0.01
        per_side = int(self.num_ctrlpoints / 2)
        xs = np.linspace(margin, 1. - margin, per_side)
        top = np.stack([xs, np.ones(per_side) * margin], axis=1)
        bottom = np.stack([xs, np.ones(per_side) * (1 - margin)], axis=1)
        points = np.concatenate([top, bottom], axis=0).astype(np.float32)
        if self.activation == 'sigmoid':
            points = -np.log(1. / points - 1.)
        points = paddle.to_tensor(points)
        return paddle.reshape(
            points, shape=[points.shape[0] * points.shape[1]])

    def forward(self, x):
        """Return ``(img_feat, ctrl_points)``.

        ``img_feat`` is the fc1 output. ``ctrl_points`` has shape
        ``[-1, num_ctrlpoints, 2]``.
        """
        feat = self.stn_convnet(x)
        batch_size = feat.shape[0]
        flat = paddle.reshape(feat, shape=(batch_size, -1))
        img_feat = self.stn_fc1(flat)
        points = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            points = F.sigmoid(points)
        points = paddle.reshape(points, shape=[-1, self.num_ctrlpoints, 2])
        return img_feat, points
|
114
|
-
|
115
|
-
|
116
|
-
class STN_ON(nn.Layer):
    """STN localisation head followed by a TPS warp.

    The input is resized (bilinear, ``align_corners=True``) to
    ``tps_inputsize`` and fed to the STN head, which predicts control points.
    The TPS transformer, configured with
    ``output_image_size=tuple(tps_outputsize)``, then warps the
    original-resolution image using those points.
    """

    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
                 num_control_points, tps_margins, stn_activation):
        super(STN_ON, self).__init__()
        self.tps = TPSSpatialTransformer(
            output_image_size=tuple(tps_outputsize),
            num_control_points=num_control_points,
            margins=tuple(tps_margins))
        self.stn_head = STN(
            in_channels=in_channels,
            num_ctrlpoints=num_control_points,
            activation=stn_activation)
        self.tps_inputsize = tps_inputsize
        # The rectifier reports the same channel count it receives.
        self.out_channels = in_channels

    def forward(self, image):
        small = F.interpolate(
            image, self.tps_inputsize, mode="bilinear", align_corners=True)
        _, ctrl_points = self.stn_head(small)
        rectified, _ = self.tps(image, ctrl_points)
        return rectified
|
@@ -1,308 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is based on:
|
16
|
-
https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
|
17
|
-
"""
|
18
|
-
|
19
|
-
from __future__ import absolute_import
|
20
|
-
from __future__ import division
|
21
|
-
from __future__ import print_function
|
22
|
-
|
23
|
-
import math
|
24
|
-
import paddle
|
25
|
-
from paddle import nn, ParamAttr
|
26
|
-
from paddle.nn import functional as F
|
27
|
-
import numpy as np
|
28
|
-
|
29
|
-
|
30
|
-
class ConvBNLayer(nn.Layer):
    """Bias-free ``Conv2D`` followed by ``BatchNorm`` (with optional ``act``).

    Padding is ``(kernel_size - 1) // 2``, which keeps the spatial size for
    odd kernels at stride 1.

    Args:
        in_channels, out_channels, kernel_size, stride, groups: forwarded to
            ``nn.Conv2D``.
        act: activation name passed to ``nn.BatchNorm``.
        name: prefix for explicit parameter names (``<name>_weights`` and
            ``bn_<name>_scale`` / ``_offset`` / ``_mean`` / ``_variance``).
            When None, Paddle's automatic parameter naming is used. This
            default previously raised TypeError on ``None + "_weights"``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        conv_w_name = None if name is None else name + "_weights"
        bn_name = None if name is None else "bn_" + name

        def bn_param(suffix):
            # Explicit BN parameter name, or None for auto-naming.
            return None if bn_name is None else bn_name + suffix

        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=conv_w_name),
            bias_attr=False)
        self.bn = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_param('_scale')),
            bias_attr=ParamAttr(bn_param('_offset')),
            moving_mean_name=bn_param('_mean'),
            moving_variance_name=bn_param('_variance'))

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
|
62
|
-
|
63
|
-
|
64
|
-
class LocalizationNetwork(nn.Layer):
    """Regress ``num_fiducial`` TPS fiducial points from an input batch.

    Four ConvBN+ReLU stages feed two fully connected layers. A 2x2 max-pool
    follows each stage except the last, which is followed by global average
    pooling. fc2 starts with zero weights and a bias equal to
    ``get_initial_fiducials()``, so the initial prediction is that layout.

    Args:
        in_channels: channels of the input.
        num_fiducial: number of fiducial points F. The output has shape
            ``[-1, F, 2]``.
        loc_lr: learning-rate multiplier for the fc1 weight and the fc2
            weight and bias.
        model_name: ``"large"`` selects the wide configuration; any other
            value selects the small one.
    """

    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
        super(LocalizationNetwork, self).__init__()
        self.F = num_fiducial
        if model_name == "large":
            num_filters_list, fc_dim = [64, 128, 256, 512], 256
        else:
            num_filters_list, fc_dim = [16, 32, 64, 128], 64

        # Alternating [conv, pool, conv, pool, ...]. Convs are registered via
        # add_sublayer; pools have no parameters and live only in this list.
        self.block_list = []
        last_stage = len(num_filters_list) - 1
        for stage, num_filters in enumerate(num_filters_list):
            conv_name = "loc_conv%d" % stage
            conv = self.add_sublayer(
                conv_name,
                ConvBNLayer(
                    in_channels=in_channels,
                    out_channels=num_filters,
                    kernel_size=3,
                    act='relu',
                    name=conv_name))
            self.block_list.append(conv)
            if stage == last_stage:
                pool = nn.AdaptiveAvgPool2D(1)
            else:
                pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
            self.block_list.append(pool)
            in_channels = num_filters

        fc1_name = "loc_fc1"
        stdv = 1.0 / math.sqrt(num_filters_list[-1] * 1.0)
        self.fc1 = nn.Linear(
            in_channels,
            fc_dim,
            weight_attr=ParamAttr(
                learning_rate=loc_lr,
                name=fc1_name + "_w",
                initializer=nn.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name=fc1_name + '.b_0'),
            name=fc1_name)

        # fc2: zero weights plus a bias holding the default fiducial layout.
        fc2_name = "loc_fc2"
        initial_bias = self.get_initial_fiducials().reshape(-1)
        zeros = np.zeros([fc_dim, num_fiducial * 2])
        self.fc2 = nn.Linear(
            fc_dim,
            num_fiducial * 2,
            weight_attr=ParamAttr(
                learning_rate=loc_lr,
                initializer=nn.initializer.Assign(zeros),
                name=fc2_name + "_w"),
            bias_attr=ParamAttr(
                learning_rate=loc_lr,
                initializer=nn.initializer.Assign(initial_bias),
                name=fc2_name + "_b"),
            name=fc2_name)
        self.out_channels = num_fiducial * 2

    def forward(self, x):
        """Estimate the parameters of the geometric transformation.

        Args:
            x: input batch.
        Return:
            batch_C_prime: predicted fiducial points, shape ``[-1, self.F, 2]``.
        """
        for layer in self.block_list:
            x = layer(x)
        # After global average pooling the two spatial dims are 1; drop them.
        x = x.squeeze(axis=2).squeeze(axis=2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.reshape(shape=[-1, self.F, 2])

    def get_initial_fiducials(self):
        """Return the default fiducial layout (see RARE paper Fig. 6 (a)).

        Returns a float64 numpy array of shape ``(2 * int(F / 2), 2)``. The
        first half are top points (x from -1 to 1, y from 0 to -1). The
        second half are bottom points (same x, y from 1 to 0).
        """
        half = int(self.F / 2)
        xs = np.linspace(-1.0, 1.0, half)
        top = np.stack([xs, np.linspace(0.0, -1.0, num=half)], axis=1)
        bottom = np.stack([xs, np.linspace(1.0, 0.0, num=half)], axis=1)
        return np.concatenate([top, bottom], axis=0)
|
157
|
-
|
158
|
-
|
159
|
-
class GridGenerator(nn.Layer):
    """Turn predicted fiducial points into a thin-plate-spline sampling grid.

    The TPS system is solved against a fixed set of base fiducials ``C`` (see
    ``build_C_paddle``). It is then evaluated at every pixel centre ``P`` of
    the rectified image (see ``build_P_paddle``), producing coordinates
    intended for ``grid_sample``.

    Args:
        in_channels: input width of ``self.fc``. It must equal the flattened
            size of ``batch_C_prime``, i.e. ``F * 2`` for a ``[B, F, 2]`` input.
        num_fiducial: number of fiducial points ``F``. Half are placed on the
            top edge and half on the bottom edge, so it is assumed to be even.
    """

    def __init__(self, in_channels, num_fiducial):
        super(GridGenerator, self).__init__()
        # Added to the RBF distance before taking the log to avoid log(0).
        self.eps = 1e-6
        self.F = num_fiducial

        # Extra head producing the 3 rows appended to the fiducials, forming
        # the (F + 3)-row right-hand side of the TPS system. Weight and bias
        # are initialised to 0 with learning_rate=0.0, so this head outputs
        # zeros unless a loaded checkpoint overwrites them.
        name = "ex_fc"
        initializer = nn.initializer.Constant(value=0.0)
        param_attr = ParamAttr(
            learning_rate=0.0, initializer=initializer, name=name + "_w")
        bias_attr = ParamAttr(
            learning_rate=0.0, initializer=initializer, name=name + "_b")
        self.fc = nn.Linear(
            in_channels,
            6,
            weight_attr=param_attr,
            bias_attr=bias_attr,
            name=name)

    def forward(self, batch_C_prime, I_r_size):
        """Generate the grid for the grid_sampler.

        Args:
            batch_C_prime: the matrix of the geometric transformation, a rank-3
                ``[B, F, 2]`` tensor (presumably float32; TODO confirm).
            I_r_size: ``(height, width)`` of the input image.

        Returns:
            batch_P_prime: ``[B, height * width, 2]`` sampling coordinates.
        """
        C = self.build_C_paddle()
        P = self.build_P_paddle(I_r_size)

        # The TPS system depends only on F and I_r_size; no gradient needed.
        inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
        # build_P_paddle already returns a tensor, so it is passed directly
        # instead of being copied through paddle.to_tensor.
        P_hat_tensor = self.build_P_hat_paddle(C, P).astype('float32')
        inv_delta_C_tensor.stop_gradient = True
        P_hat_tensor.stop_gradient = True

        batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime)
        batch_C_ex_part_tensor.stop_gradient = True

        # [B, F, 2] ++ [B, 3, 2] -> [B, F + 3, 2]
        batch_C_prime_with_zeros = paddle.concat(
            [batch_C_prime, batch_C_ex_part_tensor], axis=1)
        # [F + 3, F + 3] @ [B, F + 3, 2] -> [B, F + 3, 2]
        batch_T = paddle.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros)
        # [n, F + 3] @ [B, F + 3, 2] -> [B, n, 2]
        batch_P_prime = paddle.matmul(P_hat_tensor, batch_T)
        return batch_P_prime

    def build_C_paddle(self):
        """Return coordinates of the base fiducial points in I_r (``C``).

        ``int(F / 2)`` points are evenly spaced on x in [-1, 1] along the top
        edge (y = -1). The same x positions follow along the bottom edge
        (y = 1).

        Returns:
            float64 tensor of shape ``[2 * int(F / 2), 2]`` (``F x 2`` for
            even ``F``).
        """
        num_f = self.F
        half = int(num_f / 2)
        ctrl_pts_x = paddle.linspace(-1.0, 1.0, half, dtype='float64')
        ctrl_pts_y_top = -1 * paddle.ones([half], dtype='float64')
        ctrl_pts_y_bottom = paddle.ones([half], dtype='float64')
        ctrl_pts_top = paddle.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = paddle.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        C = paddle.concat([ctrl_pts_top, ctrl_pts_bottom], axis=0)
        return C  # F x 2

    def build_P_paddle(self, I_r_size):
        """Return the normalized pixel-centre grid ``P`` of the rectified image.

        Pixel ``i`` along an axis of length ``L`` maps to ``(2 * i + 1 - L) / L``,
        which lies in (-1, 1).

        Args:
            I_r_size: ``(height, width)`` of the rectified image.

        Returns:
            float64 tensor ``[height * width, 2]`` of (x, y) pairs, ordered
            row by row (y outer, x inner).
        """
        I_r_height, I_r_width = I_r_size
        # Divide by an explicit float64 tensor so both operands share a dtype
        # and no implicit int64 -> float64 cast (and its warning) is involved.
        I_r_grid_x = (paddle.arange(
            -I_r_width, I_r_width, 2, dtype='float64') + 1.0
                      ) / paddle.to_tensor([I_r_width], dtype='float64')

        I_r_grid_y = (paddle.arange(
            -I_r_height, I_r_height, 2, dtype='float64') + 1.0
                      ) / paddle.to_tensor([I_r_height], dtype='float64')

        # P: self.I_r_width x self.I_r_height x 2
        P = paddle.stack(paddle.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
        P = paddle.transpose(P, perm=[1, 0, 2])
        # n (= self.I_r_width x self.I_r_height) x 2
        return P.reshape([-1, 2])

    def build_inv_delta_C_paddle(self, C):
        """Return ``inv_delta_C``, which is needed to calculate ``T``.

        ``delta_C`` is the ``(F + 3) x (F + 3)`` float64 TPS system matrix::

            rows 0..F-1 : [ 1 | C     | R       ]   (F x (F + 3))
            rows F..F+1 : [ 0  0  0   | C^T     ]   (2 x (F + 3))
            row  F+2    : [ 0  0  0   | 1 ... 1 ]   (1 x (F + 3))

        Here ``R[i, j] = r_ij**2 * log(r_ij)`` with ``r_ij = |C_i - C_j|``, and
        ``R`` is 0 on the diagonal.
        """
        num_f = self.F
        hat_eye = paddle.eye(num_f, dtype='float64')  # F x F
        # Adding the identity makes the diagonal distance 1, so the kernel
        # below gives 1**2 * log(1) == 0 there instead of evaluating log(0).
        hat_C = paddle.norm(
            C.reshape([1, num_f, 2]) - C.reshape([num_f, 1, 2]),
            axis=2) + hat_eye
        hat_C = (hat_C**2) * paddle.log(hat_C)
        delta_C = paddle.concat(  # F+3 x F+3
            [
                paddle.concat(
                    [paddle.ones((num_f, 1), dtype='float64'), C, hat_C],
                    axis=1),  # F x F+3
                paddle.concat(
                    [
                        paddle.zeros((2, 3), dtype='float64'),
                        paddle.transpose(C, perm=[1, 0])
                    ],
                    axis=1),  # 2 x F+3
                paddle.concat(
                    [
                        paddle.zeros((1, 3), dtype='float64'),
                        paddle.ones((1, num_f), dtype='float64')
                    ],
                    axis=1)  # 1 x F+3
            ],
            axis=0)
        inv_delta_C = paddle.inverse(delta_C)
        return inv_delta_C  # F+3 x F+3

    def build_P_hat_paddle(self, C, P):
        """Return ``P_hat = [1, P, rbf(P, C)]`` with shape ``n x (F + 3)``.

        Args:
            C: base fiducials, ``F x 2`` (float64).
            P: pixel-centre grid, ``n x 2`` (float64), ``n = height * width``.

        Returns:
            float64 tensor ``n x (F + 3)``. The RBF part is
            ``r**2 * log(r + eps)`` of each pixel's distance to each fiducial.
        """
        num_f = self.F
        eps = self.eps
        n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
        # P_tile: n x 2 -> n x 1 x 2 -> n x F x 2
        P_tile = paddle.tile(paddle.unsqueeze(P, axis=1), (1, num_f, 1))
        C_tile = paddle.unsqueeze(C, axis=0)  # 1 x F x 2
        P_diff = P_tile - C_tile  # n x F x 2
        # rbf_norm: n x F
        rbf_norm = paddle.norm(P_diff, p=2, axis=2, keepdim=False)

        # rbf: n x F
        rbf = paddle.multiply(
            paddle.square(rbf_norm), paddle.log(rbf_norm + eps))
        P_hat = paddle.concat(
            [paddle.ones((n, 1), dtype='float64'), P, rbf], axis=1)
        return P_hat  # n x F+3

    def get_expand_tensor(self, batch_C_prime):
        """Map the flattened fiducials through ``self.fc`` to the 3 extra rows.

        Args:
            batch_C_prime: rank-3 tensor ``[B, H, C]``; ``H * C`` must equal the
                ``in_channels`` given to ``__init__``.

        Returns:
            tensor of shape ``[B, 3, 2]``.
        """
        B, H, C = batch_C_prime.shape
        batch_C_prime = batch_C_prime.reshape([B, H * C])
        batch_C_ex_part_tensor = self.fc(batch_C_prime)
        batch_C_ex_part_tensor = batch_C_ex_part_tensor.reshape([-1, 3, 2])
        return batch_C_ex_part_tensor
|
290
|
-
|
291
|
-
|
292
|
-
class TPS(nn.Layer):
    """Thin-plate-spline rectification layer.

    A ``LocalizationNetwork`` regresses fiducial points from the image.
    ``GridGenerator`` turns them into a sampling grid, and the image is
    resampled with ``F.grid_sample``. The channel count is unchanged
    (``out_channels == in_channels``).
    """

    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
        super().__init__()
        self.loc_net = LocalizationNetwork(
            in_channels, num_fiducial, loc_lr, model_name)
        self.grid_generator = GridGenerator(
            self.loc_net.out_channels, num_fiducial)
        # Rectification only moves pixels; the channel count is preserved.
        self.out_channels = in_channels

    def forward(self, image):
        """Rectify ``image`` and return the resampled batch.

        Note: this sets ``image.stop_gradient = False`` on the caller's tensor.
        ``image`` is presumably ``[B, C, H, W]`` (TODO confirm); index 2 and 3
        are read as the grid height and width.
        """
        image.stop_gradient = False
        spatial_size = image.shape[2:]
        fiducials = self.loc_net(image)
        grid = self.grid_generator(fiducials, spatial_size)
        grid = grid.reshape([-1, spatial_size[0], spatial_size[1], 2])
        return F.grid_sample(x=image, grid=grid)
|