pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxlpr/ppocr/losses/det_sast_loss.py
@@ -1,121 +0,0 @@
-# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import paddle
-from paddle import nn
-from .det_basic_loss import DiceLoss
-import numpy as np
-
-
-class SASTLoss(nn.Layer):
-    """
-    """
-
-    def __init__(self, eps=1e-6, **kwargs):
-        super(SASTLoss, self).__init__()
-        self.dice_loss = DiceLoss(eps=eps)
-
-    def forward(self, predicts, labels):
-        """
-        tcl_pos: N x 128 x 3
-        tcl_mask: N x 128 x 1
-        tcl_label: N x X list or LoDTensor
-        """
-
-        f_score = predicts['f_score']
-        f_border = predicts['f_border']
-        f_tvo = predicts['f_tvo']
-        f_tco = predicts['f_tco']
-
-        l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
-
-        #score_loss
-        intersection = paddle.sum(f_score * l_score * l_mask)
-        union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
-        score_loss = 1.0 - 2 * intersection / (union + 1e-5)
-
-        #border loss
-        l_border_split, l_border_norm = paddle.split(
-            l_border, num_or_sections=[4, 1], axis=1)
-        f_border_split = f_border
-        border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
-        l_border_norm_split = paddle.expand(
-            x=l_border_norm, shape=border_ex_shape)
-        l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
-        l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
-
-        border_diff = l_border_split - f_border_split
-        abs_border_diff = paddle.abs(border_diff)
-        border_sign = abs_border_diff < 1.0
-        border_sign = paddle.cast(border_sign, dtype='float32')
-        border_sign.stop_gradient = True
-        border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
-                         (abs_border_diff - 0.5) * (1.0 - border_sign)
-        border_out_loss = l_border_norm_split * border_in_loss
-        border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
-                      (paddle.sum(l_border_score * l_border_mask) + 1e-5)
-
-        #tvo_loss
-        l_tvo_split, l_tvo_norm = paddle.split(
-            l_tvo, num_or_sections=[8, 1], axis=1)
-        f_tvo_split = f_tvo
-        tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
-        l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
-        l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
-        l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
-        #
-        tvo_geo_diff = l_tvo_split - f_tvo_split
-        abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
-        tvo_sign = abs_tvo_geo_diff < 1.0
-        tvo_sign = paddle.cast(tvo_sign, dtype='float32')
-        tvo_sign.stop_gradient = True
-        tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
-                      (abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
-        tvo_out_loss = l_tvo_norm_split * tvo_in_loss
-        tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
-                   (paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
-
-        #tco_loss
-        l_tco_split, l_tco_norm = paddle.split(
-            l_tco, num_or_sections=[2, 1], axis=1)
-        f_tco_split = f_tco
-        tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
-        l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
-        l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
-        l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
-
-        tco_geo_diff = l_tco_split - f_tco_split
-        abs_tco_geo_diff = paddle.abs(tco_geo_diff)
-        tco_sign = abs_tco_geo_diff < 1.0
-        tco_sign = paddle.cast(tco_sign, dtype='float32')
-        tco_sign.stop_gradient = True
-        tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
-                      (abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
-        tco_out_loss = l_tco_norm_split * tco_in_loss
-        tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
-                   (paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
-
-        # total loss
-        tvo_lw, tco_lw = 1.5, 1.5
-        score_lw, border_lw = 1.0, 1.0
-        total_loss = score_loss * score_lw + border_loss * border_lw + \
-                     tvo_loss * tvo_lw + tco_loss * tco_lw
-
-        losses = {'loss':total_loss, "score_loss":score_loss,\
-                  "border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
-        return losses
pyxlpr/ppocr/losses/distillation_loss.py
@@ -1,272 +0,0 @@
-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-import paddle
-import paddle.nn as nn
-import numpy as np
-import cv2
-
-from .rec_ctc_loss import CTCLoss
-from .basic_loss import DMLLoss
-from .basic_loss import DistanceLoss
-from .det_db_loss import DBLoss
-from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
-
-
-def _sum_loss(loss_dict):
-    if "loss" in loss_dict.keys():
-        return loss_dict
-    else:
-        loss_dict["loss"] = 0.
-        for k, value in loss_dict.items():
-            if k == "loss":
-                continue
-            else:
-                loss_dict["loss"] += value
-        return loss_dict
-
-
-class DistillationDMLLoss(DMLLoss):
-    """
-    """
-
-    def __init__(self,
-                 model_name_pairs=[],
-                 act=None,
-                 use_log=False,
-                 key=None,
-                 maps_name=None,
-                 name="dml"):
-        super().__init__(act=act, use_log=use_log)
-        assert isinstance(model_name_pairs, list)
-        self.key = key
-        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
-        self.name = name
-        self.maps_name = self._check_maps_name(maps_name)
-
-    def _check_model_name_pairs(self, model_name_pairs):
-        if not isinstance(model_name_pairs, list):
-            return []
-        elif isinstance(model_name_pairs[0], list) and isinstance(
-                model_name_pairs[0][0], str):
-            return model_name_pairs
-        else:
-            return [model_name_pairs]
-
-    def _check_maps_name(self, maps_name):
-        if maps_name is None:
-            return None
-        elif type(maps_name) == str:
-            return [maps_name]
-        elif type(maps_name) == list:
-            return [maps_name]
-        else:
-            return None
-
-    def _slice_out(self, outs):
-        new_outs = {}
-        for k in self.maps_name:
-            if k == "thrink_maps":
-                new_outs[k] = outs[:, 0, :, :]
-            elif k == "threshold_maps":
-                new_outs[k] = outs[:, 1, :, :]
-            elif k == "binary_maps":
-                new_outs[k] = outs[:, 2, :, :]
-            else:
-                continue
-        return new_outs
-
-    def forward(self, predicts, batch):
-        loss_dict = dict()
-        for idx, pair in enumerate(self.model_name_pairs):
-            out1 = predicts[pair[0]]
-            out2 = predicts[pair[1]]
-            if self.key is not None:
-                out1 = out1[self.key]
-                out2 = out2[self.key]
-
-            if self.maps_name is None:
-                loss = super().forward(out1, out2)
-                if isinstance(loss, dict):
-                    for key in loss:
-                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
-                                                       idx)] = loss[key]
-                else:
-                    loss_dict["{}_{}".format(self.name, idx)] = loss
-            else:
-                outs1 = self._slice_out(out1)
-                outs2 = self._slice_out(out2)
-                for _c, k in enumerate(outs1.keys()):
-                    loss = super().forward(outs1[k], outs2[k])
-                    if isinstance(loss, dict):
-                        for key in loss:
-                            loss_dict["{}_{}_{}_{}_{}".format(key, pair[
-                                0], pair[1], self.maps_name, idx)] = loss[key]
-                    else:
-                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[
-                            _c], idx)] = loss
-
-        loss_dict = _sum_loss(loss_dict)
-
-        return loss_dict
-
-
-class DistillationCTCLoss(CTCLoss):
-    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
-        super().__init__()
-        self.model_name_list = model_name_list
-        self.key = key
-        self.name = name
-
-    def forward(self, predicts, batch):
-        loss_dict = dict()
-        for idx, model_name in enumerate(self.model_name_list):
-            out = predicts[model_name]
-            if self.key is not None:
-                out = out[self.key]
-            loss = super().forward(out, batch)
-            if isinstance(loss, dict):
-                for key in loss:
-                    loss_dict["{}_{}_{}".format(self.name, model_name,
-                                                idx)] = loss[key]
-            else:
-                loss_dict["{}_{}".format(self.name, model_name)] = loss
-        return loss_dict
-
-
-class DistillationDBLoss(DBLoss):
-    def __init__(self,
-                 model_name_list=[],
-                 balance_loss=True,
-                 main_loss_type='DiceLoss',
-                 alpha=5,
-                 beta=10,
-                 ohem_ratio=3,
-                 eps=1e-6,
-                 name="db",
-                 **kwargs):
-        super().__init__()
-        self.model_name_list = model_name_list
-        self.name = name
-        self.key = None
-
-    def forward(self, predicts, batch):
-        loss_dict = {}
-        for idx, model_name in enumerate(self.model_name_list):
-            out = predicts[model_name]
-            if self.key is not None:
-                out = out[self.key]
-            loss = super().forward(out, batch)
-
-            if isinstance(loss, dict):
-                for key in loss.keys():
-                    if key == "loss":
-                        continue
-                    name = "{}_{}_{}".format(self.name, model_name, key)
-                    loss_dict[name] = loss[key]
-            else:
-                loss_dict["{}_{}".format(self.name, model_name)] = loss
-
-        loss_dict = _sum_loss(loss_dict)
-        return loss_dict
-
-
-class DistillationDilaDBLoss(DBLoss):
-    def __init__(self,
-                 model_name_pairs=[],
-                 key=None,
-                 balance_loss=True,
-                 main_loss_type='DiceLoss',
-                 alpha=5,
-                 beta=10,
-                 ohem_ratio=3,
-                 eps=1e-6,
-                 name="dila_dbloss"):
-        super().__init__()
-        self.model_name_pairs = model_name_pairs
-        self.name = name
-        self.key = key
-
-    def forward(self, predicts, batch):
-        loss_dict = dict()
-        for idx, pair in enumerate(self.model_name_pairs):
-            stu_outs = predicts[pair[0]]
-            tch_outs = predicts[pair[1]]
-            if self.key is not None:
-                stu_preds = stu_outs[self.key]
-                tch_preds = tch_outs[self.key]
-
-            stu_shrink_maps = stu_preds[:, 0, :, :]
-            stu_binary_maps = stu_preds[:, 2, :, :]
-
-            # dilation to teacher prediction
-            dilation_w = np.array([[1, 1], [1, 1]])
-            th_shrink_maps = tch_preds[:, 0, :, :]
-            th_shrink_maps = th_shrink_maps.numpy() > 0.3  # thresh = 0.3
-            dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
-            for i in range(th_shrink_maps.shape[0]):
-                dilate_maps[i] = cv2.dilate(
-                    th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
-            th_shrink_maps = paddle.to_tensor(dilate_maps)
-
-            label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
-                1:]
-
-            # calculate the shrink map loss
-            bce_loss = self.alpha * self.bce_loss(
-                stu_shrink_maps, th_shrink_maps, label_shrink_mask)
-            loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
-                                              label_shrink_mask)
-
-            # k = f"{self.name}_{pair[0]}_{pair[1]}"
-            k = "{}_{}_{}".format(self.name, pair[0], pair[1])
-            loss_dict[k] = bce_loss + loss_binary_maps
-
-        loss_dict = _sum_loss(loss_dict)
-        return loss_dict
-
-
-class DistillationDistanceLoss(DistanceLoss):
-    """
-    """
-
-    def __init__(self,
-                 mode="l2",
-                 model_name_pairs=[],
-                 key=None,
-                 name="loss_distance",
-                 **kargs):
-        super().__init__(mode=mode, **kargs)
-        assert isinstance(model_name_pairs, list)
-        self.key = key
-        self.model_name_pairs = model_name_pairs
-        self.name = name + "_l2"
-
-    def forward(self, predicts, batch):
-        loss_dict = dict()
-        for idx, pair in enumerate(self.model_name_pairs):
-            out1 = predicts[pair[0]]
-            out2 = predicts[pair[1]]
-            if self.key is not None:
-                out1 = out1[self.key]
-                out2 = out2[self.key]
-            loss = super().forward(out1, out2)
-            if isinstance(loss, dict):
-                for key in loss:
-                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
-                        key]
-            else:
-                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
-                                               idx)] = loss
-        return loss_dict
pyxlpr/ppocr/losses/e2e_pg_loss.py
@@ -1,140 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from paddle import nn
-import paddle
-
-from .det_basic_loss import DiceLoss
-from pyxlpr.ppocr.utils.e2e_utils.extract_batchsize import pre_process
-
-
-class PGLoss(nn.Layer):
-    def __init__(self,
-                 tcl_bs,
-                 max_text_length,
-                 max_text_nums,
-                 pad_num,
-                 eps=1e-6,
-                 **kwargs):
-        super(PGLoss, self).__init__()
-        self.tcl_bs = tcl_bs
-        self.max_text_nums = max_text_nums
-        self.max_text_length = max_text_length
-        self.pad_num = pad_num
-        self.dice_loss = DiceLoss(eps=eps)
-
-    def border_loss(self, f_border, l_border, l_score, l_mask):
-        l_border_split, l_border_norm = paddle.tensor.split(
-            l_border, num_or_sections=[4, 1], axis=1)
-        f_border_split = f_border
-        b, c, h, w = l_border_norm.shape
-        l_border_norm_split = paddle.expand(
-            x=l_border_norm, shape=[b, 4 * c, h, w])
-        b, c, h, w = l_score.shape
-        l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
-        b, c, h, w = l_mask.shape
-        l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
-        border_diff = l_border_split - f_border_split
-        abs_border_diff = paddle.abs(border_diff)
-        border_sign = abs_border_diff < 1.0
-        border_sign = paddle.cast(border_sign, dtype='float32')
-        border_sign.stop_gradient = True
-        border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
-                         (abs_border_diff - 0.5) * (1.0 - border_sign)
-        border_out_loss = l_border_norm_split * border_in_loss
-        border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
-                      (paddle.sum(l_border_score * l_border_mask) + 1e-5)
-        return border_loss
-
-    def direction_loss(self, f_direction, l_direction, l_score, l_mask):
-        l_direction_split, l_direction_norm = paddle.tensor.split(
-            l_direction, num_or_sections=[2, 1], axis=1)
-        f_direction_split = f_direction
-        b, c, h, w = l_direction_norm.shape
-        l_direction_norm_split = paddle.expand(
-            x=l_direction_norm, shape=[b, 2 * c, h, w])
-        b, c, h, w = l_score.shape
-        l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
-        b, c, h, w = l_mask.shape
-        l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
-        direction_diff = l_direction_split - f_direction_split
-        abs_direction_diff = paddle.abs(direction_diff)
-        direction_sign = abs_direction_diff < 1.0
-        direction_sign = paddle.cast(direction_sign, dtype='float32')
-        direction_sign.stop_gradient = True
-        direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
-                            (abs_direction_diff - 0.5) * (1.0 - direction_sign)
-        direction_out_loss = l_direction_norm_split * direction_in_loss
-        direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
-                         (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
-        return direction_loss
-
-    def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
-        f_char = paddle.transpose(f_char, [0, 2, 3, 1])
-        tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
-        tcl_pos = paddle.cast(tcl_pos, dtype=int)
-        f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
-        f_tcl_char = paddle.reshape(f_tcl_char,
-                                    [-1, 64, 37])  # len(Lexicon_Table)+1
-        f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
-        f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
-        b, c, l = tcl_mask.shape
-        tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
-        tcl_mask_fg.stop_gradient = True
-        f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-            -20.0)
-        f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
-        f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
-        N, B, _ = f_tcl_char_ld.shape
-        input_lengths = paddle.to_tensor([N] * B, dtype='int64')
-        cost = paddle.nn.functional.ctc_loss(
-            log_probs=f_tcl_char_ld,
-            labels=tcl_label,
-            input_lengths=input_lengths,
-            label_lengths=label_t,
-            blank=self.pad_num,
-            reduction='none')
-        cost = cost.mean()
-        return cost
-
-    def forward(self, predicts, labels):
-        images, tcl_maps, tcl_label_maps, border_maps \
-            , direction_maps, training_masks, label_list, pos_list, pos_mask = labels
-        # for all the batch_size
-        pos_list, pos_mask, label_list, label_t = pre_process(
-            label_list, pos_list, pos_mask, self.max_text_length,
-            self.max_text_nums, self.pad_num, self.tcl_bs)
-
-        f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
-                                                 predicts['f_char']
-        score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
-        border_loss = self.border_loss(f_border, border_maps, tcl_maps,
-                                       training_masks)
-        direction_loss = self.direction_loss(f_direction, direction_maps,
-                                             tcl_maps, training_masks)
-        ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
-        loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
-
-        losses = {
-            'loss': loss_all,
-            "score_loss": score_loss,
-            "border_loss": border_loss,
-            "direction_loss": direction_loss,
-            "ctc_loss": ctc_loss
-        }
-        return losses
pyxlpr/ppocr/losses/kie_sdmgr_loss.py
@@ -1,113 +0,0 @@
-# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from paddle import nn
-import paddle
-
-
-class SDMGRLoss(nn.Layer):
-    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
-        super().__init__()
-        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
-        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
-        self.node_weight = node_weight
-        self.edge_weight = edge_weight
-        self.ignore = ignore
-
-    def pre_process(self, gts, tag):
-        gts, tag = gts.numpy(), tag.numpy().tolist()
-        temp_gts = []
-        batch = len(tag)
-        for i in range(batch):
-            num, recoder_len = tag[i][0], tag[i][1]
-            temp_gts.append(
-                paddle.to_tensor(
-                    gts[i, :num, :num + 1], dtype='int64'))
-        return temp_gts
-
-    def accuracy(self, pred, target, topk=1, thresh=None):
-        """Calculate accuracy according to the prediction and target.
-
-        Args:
-            pred (torch.Tensor): The model prediction, shape (N, num_class)
-            target (torch.Tensor): The target of each prediction, shape (N, )
-            topk (int | tuple[int], optional): If the predictions in ``topk``
-                matches the target, the predictions will be regarded as
-                correct ones. Defaults to 1.
-            thresh (float, optional): If not None, predictions with scores under
-                this threshold are considered incorrect. Default to None.
-
-        Returns:
-            float | tuple[float]: If the input ``topk`` is a single integer,
-                the function will return a single float as accuracy. If
-                ``topk`` is a tuple containing multiple integers, the
-                function will return a tuple containing accuracies of
-                each ``topk`` number.
-        """
-        assert isinstance(topk, (int, tuple))
-        if isinstance(topk, int):
-            topk = (topk, )
-            return_single = True
-        else:
-            return_single = False
-
-        maxk = max(topk)
-        if pred.shape[0] == 0:
-            accu = [pred.new_tensor(0.) for i in range(len(topk))]
-            return accu[0] if return_single else accu
-        pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
-        pred_label = pred_label.transpose(
-            [1, 0])  # transpose to shape (maxk, N)
-        correct = paddle.equal(pred_label,
-                               (target.reshape([1, -1]).expand_as(pred_label)))
-        res = []
-        for k in topk:
-            correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
-                                   axis=0,
-                                   keepdim=True)
-            res.append(
-                paddle.multiply(correct_k,
-                                paddle.to_tensor(100.0 / pred.shape[0])))
-        return res[0] if return_single else res
-
-    def forward(self, pred, batch):
-        node_preds, edge_preds = pred
-        gts, tag = batch[4], batch[5]
-        gts = self.pre_process(gts, tag)
-        node_gts, edge_gts = [], []
-        for gt in gts:
-            node_gts.append(gt[:, 0])
-            edge_gts.append(gt[:, 1:].reshape([-1]))
-        node_gts = paddle.concat(node_gts)
-        edge_gts = paddle.concat(edge_gts)
-
-        node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
-        edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
-        loss_node = self.loss_node(node_preds, node_gts)
-        loss_edge = self.loss_edge(edge_preds, edge_gts)
-        loss = self.node_weight * loss_node + self.edge_weight * loss_edge
-        return dict(
-            loss=loss,
-            loss_node=loss_node,
-            loss_edge=loss_edge,
-            acc_node=self.accuracy(
-                paddle.gather(node_preds, node_valids),
-                paddle.gather(node_gts, node_valids)),
-            acc_edge=self.accuracy(
-                paddle.gather(edge_preds, edge_valids),
-                paddle.gather(edge_gts, edge_valids)))