pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,88 +0,0 @@
|
|
1
|
-
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
#Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
#you may not use this file except in compliance with the License.
|
5
|
-
#You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
#Unless required by applicable law or agreed to in writing, software
|
10
|
-
#distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
#See the License for the specific language governing permissions and
|
13
|
-
#limitations under the License.
|
14
|
-
|
15
|
-
# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss
|
16
|
-
|
17
|
-
from __future__ import absolute_import
|
18
|
-
from __future__ import division
|
19
|
-
from __future__ import print_function
|
20
|
-
import os
|
21
|
-
import pickle
|
22
|
-
|
23
|
-
import paddle
|
24
|
-
import paddle.nn as nn
|
25
|
-
import paddle.nn.functional as F
|
26
|
-
|
27
|
-
|
28
|
-
class CenterLoss(nn.Layer):
|
29
|
-
"""
|
30
|
-
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
31
|
-
"""
|
32
|
-
|
33
|
-
def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
|
34
|
-
super().__init__()
|
35
|
-
self.num_classes = num_classes
|
36
|
-
self.feat_dim = feat_dim
|
37
|
-
self.centers = paddle.randn(
|
38
|
-
shape=[self.num_classes, self.feat_dim]).astype("float64")
|
39
|
-
|
40
|
-
if center_file_path is not None:
|
41
|
-
assert os.path.exists(
|
42
|
-
center_file_path
|
43
|
-
), f"center path({center_file_path}) must exist when it is not None."
|
44
|
-
with open(center_file_path, 'rb') as f:
|
45
|
-
char_dict = pickle.load(f)
|
46
|
-
for key in char_dict.keys():
|
47
|
-
self.centers[key] = paddle.to_tensor(char_dict[key])
|
48
|
-
|
49
|
-
def __call__(self, predicts, batch):
|
50
|
-
assert isinstance(predicts, (list, tuple))
|
51
|
-
features, predicts = predicts
|
52
|
-
|
53
|
-
feats_reshape = paddle.reshape(
|
54
|
-
features, [-1, features.shape[-1]]).astype("float64")
|
55
|
-
label = paddle.argmax(predicts, axis=2)
|
56
|
-
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
|
57
|
-
|
58
|
-
batch_size = feats_reshape.shape[0]
|
59
|
-
|
60
|
-
#calc l2 distance between feats and centers
|
61
|
-
square_feat = paddle.sum(paddle.square(feats_reshape),
|
62
|
-
axis=1,
|
63
|
-
keepdim=True)
|
64
|
-
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
|
65
|
-
|
66
|
-
square_center = paddle.sum(paddle.square(self.centers),
|
67
|
-
axis=1,
|
68
|
-
keepdim=True)
|
69
|
-
square_center = paddle.expand(
|
70
|
-
square_center, [self.num_classes, batch_size]).astype("float64")
|
71
|
-
square_center = paddle.transpose(square_center, [1, 0])
|
72
|
-
|
73
|
-
distmat = paddle.add(square_feat, square_center)
|
74
|
-
feat_dot_center = paddle.matmul(feats_reshape,
|
75
|
-
paddle.transpose(self.centers, [1, 0]))
|
76
|
-
distmat = distmat - 2.0 * feat_dot_center
|
77
|
-
|
78
|
-
#generate the mask
|
79
|
-
classes = paddle.arange(self.num_classes).astype("int64")
|
80
|
-
label = paddle.expand(
|
81
|
-
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
|
82
|
-
mask = paddle.equal(
|
83
|
-
paddle.expand(classes, [batch_size, self.num_classes]),
|
84
|
-
label).astype("float64")
|
85
|
-
dist = paddle.multiply(distmat, mask)
|
86
|
-
|
87
|
-
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
|
88
|
-
return {'loss_center': loss}
|
pyxlpr/ppocr/losses/cls_loss.py
DELETED
@@ -1,30 +0,0 @@
|
|
1
|
-
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
from paddle import nn
|
20
|
-
|
21
|
-
|
22
|
-
class ClsLoss(nn.Layer):
|
23
|
-
def __init__(self, **kwargs):
|
24
|
-
super(ClsLoss, self).__init__()
|
25
|
-
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
26
|
-
|
27
|
-
def forward(self, predicts, batch):
|
28
|
-
label = batch[1].astype("int64")
|
29
|
-
loss = self.loss_func(input=predicts, label=label)
|
30
|
-
return {'loss': loss}
|
@@ -1,67 +0,0 @@
|
|
1
|
-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
import paddle
|
16
|
-
import paddle.nn as nn
|
17
|
-
|
18
|
-
from .rec_ctc_loss import CTCLoss
|
19
|
-
from .center_loss import CenterLoss
|
20
|
-
from .ace_loss import ACELoss
|
21
|
-
|
22
|
-
from .distillation_loss import DistillationCTCLoss
|
23
|
-
from .distillation_loss import DistillationDMLLoss
|
24
|
-
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
25
|
-
|
26
|
-
|
27
|
-
class CombinedLoss(nn.Layer):
|
28
|
-
"""
|
29
|
-
CombinedLoss:
|
30
|
-
a combionation of loss function
|
31
|
-
"""
|
32
|
-
|
33
|
-
def __init__(self, loss_config_list=None):
|
34
|
-
super().__init__()
|
35
|
-
self.loss_func = []
|
36
|
-
self.loss_weight = []
|
37
|
-
assert isinstance(loss_config_list, list), (
|
38
|
-
'operator config should be a list')
|
39
|
-
for config in loss_config_list:
|
40
|
-
assert isinstance(config,
|
41
|
-
dict) and len(config) == 1, "yaml format error"
|
42
|
-
name = list(config)[0]
|
43
|
-
param = config[name]
|
44
|
-
assert "weight" in param, "weight must be in param, but param just contains {}".format(
|
45
|
-
param.keys())
|
46
|
-
self.loss_weight.append(param.pop("weight"))
|
47
|
-
self.loss_func.append(eval(name)(**param))
|
48
|
-
|
49
|
-
def forward(self, input, batch, **kargs):
|
50
|
-
loss_dict = {}
|
51
|
-
loss_all = 0.
|
52
|
-
for idx, loss_func in enumerate(self.loss_func):
|
53
|
-
loss = loss_func(input, batch, **kargs)
|
54
|
-
if isinstance(loss, paddle.Tensor):
|
55
|
-
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
56
|
-
|
57
|
-
weight = self.loss_weight[idx]
|
58
|
-
|
59
|
-
loss = {key: loss[key] * weight for key in loss}
|
60
|
-
|
61
|
-
if "loss" in loss:
|
62
|
-
loss_all += loss["loss"]
|
63
|
-
else:
|
64
|
-
loss_all += paddle.add_n(list(loss.values()))
|
65
|
-
loss_dict.update(loss)
|
66
|
-
loss_dict["loss"] = loss_all
|
67
|
-
return loss_dict
|
@@ -1,208 +0,0 @@
|
|
1
|
-
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is refer from:
|
16
|
-
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
|
17
|
-
"""
|
18
|
-
from __future__ import absolute_import
|
19
|
-
from __future__ import division
|
20
|
-
from __future__ import print_function
|
21
|
-
|
22
|
-
import numpy as np
|
23
|
-
|
24
|
-
import paddle
|
25
|
-
from paddle import nn
|
26
|
-
import paddle.nn.functional as F
|
27
|
-
|
28
|
-
|
29
|
-
class BalanceLoss(nn.Layer):
|
30
|
-
def __init__(self,
|
31
|
-
balance_loss=True,
|
32
|
-
main_loss_type='DiceLoss',
|
33
|
-
negative_ratio=3,
|
34
|
-
return_origin=False,
|
35
|
-
eps=1e-6,
|
36
|
-
**kwargs):
|
37
|
-
"""
|
38
|
-
The BalanceLoss for Differentiable Binarization text detection
|
39
|
-
args:
|
40
|
-
balance_loss (bool): whether balance loss or not, default is True
|
41
|
-
main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
|
42
|
-
'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
|
43
|
-
negative_ratio (int|float): float, default is 3.
|
44
|
-
return_origin (bool): whether return unbalanced loss or not, default is False.
|
45
|
-
eps (float): default is 1e-6.
|
46
|
-
"""
|
47
|
-
super(BalanceLoss, self).__init__()
|
48
|
-
self.balance_loss = balance_loss
|
49
|
-
self.main_loss_type = main_loss_type
|
50
|
-
self.negative_ratio = negative_ratio
|
51
|
-
self.return_origin = return_origin
|
52
|
-
self.eps = eps
|
53
|
-
|
54
|
-
if self.main_loss_type == "CrossEntropy":
|
55
|
-
self.loss = nn.CrossEntropyLoss()
|
56
|
-
elif self.main_loss_type == "Euclidean":
|
57
|
-
self.loss = nn.MSELoss()
|
58
|
-
elif self.main_loss_type == "DiceLoss":
|
59
|
-
self.loss = DiceLoss(self.eps)
|
60
|
-
elif self.main_loss_type == "BCELoss":
|
61
|
-
self.loss = BCELoss(reduction='none')
|
62
|
-
elif self.main_loss_type == "MaskL1Loss":
|
63
|
-
self.loss = MaskL1Loss(self.eps)
|
64
|
-
else:
|
65
|
-
loss_type = [
|
66
|
-
'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
|
67
|
-
]
|
68
|
-
raise Exception(
|
69
|
-
"main_loss_type in BalanceLoss() can only be one of {}".format(
|
70
|
-
loss_type))
|
71
|
-
|
72
|
-
def forward(self, pred, gt, mask=None):
|
73
|
-
"""
|
74
|
-
The BalanceLoss for Differentiable Binarization text detection
|
75
|
-
args:
|
76
|
-
pred (variable): predicted feature maps.
|
77
|
-
gt (variable): ground truth feature maps.
|
78
|
-
mask (variable): masked maps.
|
79
|
-
return: (variable) balanced loss
|
80
|
-
"""
|
81
|
-
# if self.main_loss_type in ['DiceLoss']:
|
82
|
-
# # For the loss that returns to scalar value, perform ohem on the mask
|
83
|
-
# mask = ohem_batch(pred, gt, mask, self.negative_ratio)
|
84
|
-
# loss = self.loss(pred, gt, mask)
|
85
|
-
# return loss
|
86
|
-
|
87
|
-
positive = gt * mask
|
88
|
-
negative = (1 - gt) * mask
|
89
|
-
|
90
|
-
positive_count = int(positive.sum())
|
91
|
-
negative_count = int(
|
92
|
-
min(negative.sum(), positive_count * self.negative_ratio))
|
93
|
-
loss = self.loss(pred, gt, mask=mask)
|
94
|
-
|
95
|
-
if not self.balance_loss:
|
96
|
-
return loss
|
97
|
-
|
98
|
-
positive_loss = positive * loss
|
99
|
-
negative_loss = negative * loss
|
100
|
-
negative_loss = paddle.reshape(negative_loss, shape=[-1])
|
101
|
-
if negative_count > 0:
|
102
|
-
sort_loss = negative_loss.sort(descending=True)
|
103
|
-
negative_loss = sort_loss[:negative_count]
|
104
|
-
# negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
|
105
|
-
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
|
106
|
-
positive_count + negative_count + self.eps)
|
107
|
-
else:
|
108
|
-
balance_loss = positive_loss.sum() / (positive_count + self.eps)
|
109
|
-
if self.return_origin:
|
110
|
-
return balance_loss, loss
|
111
|
-
|
112
|
-
return balance_loss
|
113
|
-
|
114
|
-
|
115
|
-
class DiceLoss(nn.Layer):
|
116
|
-
def __init__(self, eps=1e-6):
|
117
|
-
super(DiceLoss, self).__init__()
|
118
|
-
self.eps = eps
|
119
|
-
|
120
|
-
def forward(self, pred, gt, mask, weights=None):
|
121
|
-
"""
|
122
|
-
DiceLoss function.
|
123
|
-
"""
|
124
|
-
|
125
|
-
assert pred.shape == gt.shape
|
126
|
-
assert pred.shape == mask.shape
|
127
|
-
if weights is not None:
|
128
|
-
assert weights.shape == mask.shape
|
129
|
-
mask = weights * mask
|
130
|
-
intersection = paddle.sum(pred * gt * mask)
|
131
|
-
|
132
|
-
union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
|
133
|
-
loss = 1 - 2.0 * intersection / union
|
134
|
-
assert loss <= 1
|
135
|
-
return loss
|
136
|
-
|
137
|
-
|
138
|
-
class MaskL1Loss(nn.Layer):
|
139
|
-
def __init__(self, eps=1e-6):
|
140
|
-
super(MaskL1Loss, self).__init__()
|
141
|
-
self.eps = eps
|
142
|
-
|
143
|
-
def forward(self, pred, gt, mask):
|
144
|
-
"""
|
145
|
-
Mask L1 Loss
|
146
|
-
"""
|
147
|
-
loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
|
148
|
-
loss = paddle.mean(loss)
|
149
|
-
return loss
|
150
|
-
|
151
|
-
|
152
|
-
class BCELoss(nn.Layer):
|
153
|
-
def __init__(self, reduction='mean'):
|
154
|
-
super(BCELoss, self).__init__()
|
155
|
-
self.reduction = reduction
|
156
|
-
|
157
|
-
def forward(self, input, label, mask=None, weight=None, name=None):
|
158
|
-
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
|
159
|
-
return loss
|
160
|
-
|
161
|
-
|
162
|
-
def ohem_single(score, gt_text, training_mask, ohem_ratio):
|
163
|
-
pos_num = (int)(np.sum(gt_text > 0.5)) - (
|
164
|
-
int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
|
165
|
-
|
166
|
-
if pos_num == 0:
|
167
|
-
# selected_mask = gt_text.copy() * 0 # may be not good
|
168
|
-
selected_mask = training_mask
|
169
|
-
selected_mask = selected_mask.reshape(
|
170
|
-
1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
171
|
-
return selected_mask
|
172
|
-
|
173
|
-
neg_num = (int)(np.sum(gt_text <= 0.5))
|
174
|
-
neg_num = (int)(min(pos_num * ohem_ratio, neg_num))
|
175
|
-
|
176
|
-
if neg_num == 0:
|
177
|
-
selected_mask = training_mask
|
178
|
-
selected_mask = selected_mask.reshape(
|
179
|
-
1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
180
|
-
return selected_mask
|
181
|
-
|
182
|
-
neg_score = score[gt_text <= 0.5]
|
183
|
-
# 将负样本得分从高到低排序
|
184
|
-
neg_score_sorted = np.sort(-neg_score)
|
185
|
-
threshold = -neg_score_sorted[neg_num - 1]
|
186
|
-
# 选出 得分高的 负样本 和正样本 的 mask
|
187
|
-
selected_mask = ((score >= threshold) |
|
188
|
-
(gt_text > 0.5)) & (training_mask > 0.5)
|
189
|
-
selected_mask = selected_mask.reshape(
|
190
|
-
1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
191
|
-
return selected_mask
|
192
|
-
|
193
|
-
|
194
|
-
def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
|
195
|
-
scores = scores.numpy()
|
196
|
-
gt_texts = gt_texts.numpy()
|
197
|
-
training_masks = training_masks.numpy()
|
198
|
-
|
199
|
-
selected_masks = []
|
200
|
-
for i in range(scores.shape[0]):
|
201
|
-
selected_masks.append(
|
202
|
-
ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
|
203
|
-
i, :, :], ohem_ratio))
|
204
|
-
|
205
|
-
selected_masks = np.concatenate(selected_masks, 0)
|
206
|
-
selected_masks = paddle.to_tensor(selected_masks)
|
207
|
-
|
208
|
-
return selected_masks
|
@@ -1,80 +0,0 @@
|
|
1
|
-
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is refer from:
|
16
|
-
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
|
17
|
-
"""
|
18
|
-
|
19
|
-
from __future__ import absolute_import
|
20
|
-
from __future__ import division
|
21
|
-
from __future__ import print_function
|
22
|
-
|
23
|
-
from paddle import nn
|
24
|
-
|
25
|
-
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
26
|
-
|
27
|
-
|
28
|
-
class DBLoss(nn.Layer):
    """Training loss for Differentiable Binarization (DB) text detection.

    The total loss is
    ``alpha * shrink_map_loss + beta * threshold_map_loss + binary_map_loss``.

    Args:
        balance_loss: forwarded to ``BalanceLoss`` as ``balance_loss``.
        main_loss_type: forwarded to ``BalanceLoss`` as ``main_loss_type``.
        alpha: weight applied to the shrink-map loss.
        beta: weight applied to the threshold-map loss.
        ohem_ratio: forwarded to ``BalanceLoss`` as ``negative_ratio``.
        eps: forwarded to ``DiceLoss`` and ``MaskL1Loss``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 **kwargs):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        # Instantiate the three component losses.
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)

    def forward(self, predicts, labels):
        """Compute the weighted DB loss terms.

        Args:
            predicts: mapping whose ``'maps'`` entry is indexed as
                ``[:, c, :, :]`` with c = 0 (shrink), 1 (threshold),
                2 (binary).
            labels: sequence whose first element is not used; the next four
                are, in order, the threshold map, threshold mask, shrink map
                and shrink mask targets.

        Returns:
            dict with keys ``loss``, ``loss_shrink_maps``,
            ``loss_threshold_maps`` and ``loss_binary_maps``; the shrink and
            threshold entries are already multiplied by their weights.
        """
        maps = predicts['maps']
        (threshold_target, threshold_mask,
         shrink_target, shrink_mask) = labels[1:]

        shrink_pred = maps[:, 0, :, :]
        threshold_pred = maps[:, 1, :, :]
        binary_pred = maps[:, 2, :, :]

        # 1. Shrink (probability) map: balanced loss configured in __init__.
        shrink_term = self.bce_loss(shrink_pred, shrink_target, shrink_mask)
        # 2. Threshold map: masked L1 distance.
        threshold_term = self.l1_loss(threshold_pred, threshold_target,
                                      threshold_mask)
        # 3. Binary map: dice loss against the shrink-map target.
        binary_term = self.dice_loss(binary_pred, shrink_target, shrink_mask)

        # 4. Apply the per-term weights before summing.
        shrink_term = self.alpha * shrink_term
        threshold_term = self.beta * threshold_term
        total = shrink_term + threshold_term + binary_term

        return {
            'loss': total,
            "loss_shrink_maps": shrink_term,
            "loss_threshold_maps": threshold_term,
            "loss_binary_maps": binary_term,
        }
|
@@ -1,63 +0,0 @@
|
|
1
|
-
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import paddle
|
20
|
-
from paddle import nn
|
21
|
-
from .det_basic_loss import DiceLoss
|
22
|
-
|
23
|
-
|
24
|
-
class EASTLoss(nn.Layer):
    """Loss for EAST text detection: a dice term on the score map plus a
    smooth-L1 style term on the geometry map.

    Args:
        eps: forwarded to ``DiceLoss``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, eps=1e-6, **kwargs):
        super().__init__()
        self.dice_loss = DiceLoss(eps=eps)

    def forward(self, predicts, labels):
        """Compute the EAST loss terms.

        Args:
            predicts: mapping with ``'f_score'`` and ``'f_geo'`` entries.
            labels: sequence whose first element is not used; the next three
                are the score target, geometry target and mask.

        Returns:
            dict with keys ``loss``, ``dice_loss`` (already scaled by 0.01)
            and ``smooth_l1_loss``.
        """
        l_score, l_geo, l_mask = labels[1:]
        f_score = predicts['f_score']
        f_geo = predicts['f_geo']

        score_term = self.dice_loss(f_score, l_score, l_mask)

        # Smooth-L1 style geometry term.
        # The target is split into one more section than the prediction; its
        # last section is used as a per-pixel weight, not compared to f_geo.
        n_geo = 8
        target_parts = paddle.split(l_geo, num_or_sections=n_geo + 1, axis=1)
        pred_parts = paddle.split(f_geo, num_or_sections=n_geo, axis=1)
        channel_weight = target_parts[-1] / n_geo

        geo_accum = 0
        # zip stops after the n_geo prediction sections, so the trailing
        # weight section of the target is never paired.
        for target_c, pred_c in zip(target_parts, pred_parts):
            abs_diff = paddle.abs(target_c - pred_c)
            below = paddle.less_than(abs_diff, l_score)
            below = paddle.cast(below, dtype='float32')
            # Quadratic where |diff| < l_score, linear (|diff| - 0.5) elsewhere.
            piecewise = abs_diff * abs_diff * below + \
                (abs_diff - 0.5) * (1.0 - below)
            geo_accum += channel_weight * piecewise * l_score
        geo_term = paddle.mean(geo_accum * l_score)

        score_term = score_term * 0.01
        total = score_term + geo_term
        return {
            "loss": total,
            "dice_loss": score_term,
            "smooth_l1_loss": geo_term,
        }
|
@@ -1,149 +0,0 @@
|
|
1
|
-
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
"""
|
15
|
-
This code is adapted from:
|
16
|
-
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
|
17
|
-
"""
|
18
|
-
|
19
|
-
import paddle
|
20
|
-
from paddle import nn
|
21
|
-
from paddle.nn import functional as F
|
22
|
-
import numpy as np
|
23
|
-
from pyxlpr.ppocr.utils.iou import iou
|
24
|
-
|
25
|
-
|
26
|
-
class PSELoss(nn.Layer):
    """Loss for PSENet text detection: dice loss on the text map (with online
    hard example mining) combined with the mean dice loss over kernel maps.

    Args:
        alpha: weight of the text loss; the kernel loss gets ``1 - alpha``.
        ohem_ratio: maximum number of negatives kept per positive during
            online hard example mining.
        kernel_sample_mask: ``'gt'`` masks kernel losses with the ground-truth
            text region, ``'pred'`` with the thresholded text prediction.
        reduction: ``'sum'``, ``'mean'`` or ``'none'``; applied to every entry
            of the returned dict.
        eps: added to the dice-loss denominators.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 alpha,
                 ohem_ratio=3,
                 kernel_sample_mask='pred',
                 reduction='sum',
                 eps=1e-6,
                 **kwargs):
        """Implement PSE Loss.
        """
        super(PSELoss, self).__init__()
        assert reduction in ['sum', 'mean', 'none']
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.kernel_sample_mask = kernel_sample_mask
        self.reduction = reduction
        self.eps = eps

    def forward(self, outputs, labels):
        """Compute text/kernel losses and IoU metrics.

        Args:
            outputs: mapping whose ``'maps'`` entry holds the prediction;
                channel 0 is the text map, the remaining channels are kernels.
            labels: sequence whose first element is not used; the next three
                are the text target, kernel targets and training mask.

        Returns:
            dict with keys ``loss_text``, ``iou_text``, ``loss_kernels``,
            ``iou_kernel`` and ``loss``, reduced according to
            ``self.reduction``.
        """
        predicts = outputs['maps']
        # Upsample the prediction by a factor of 4 before comparing to labels.
        predicts = F.interpolate(predicts, scale_factor=4)

        texts = predicts[:, 0, :, :]
        kernels = predicts[:, 1:, :, :]
        gt_texts, gt_kernels, training_masks = labels[1:]

        # text loss
        # Pass the configured ratio; previously the constructor argument was
        # stored but never used and the default of 3 always applied.
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks,
                                         self.ohem_ratio)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)
        iou_text = iou((texts > 0).astype('int64'),
                       gt_texts,
                       training_masks,
                       reduce=False)
        losses = dict(loss_text=loss_text, iou_text=iou_text)

        # kernel loss
        # NOTE(review): for any other kernel_sample_mask value the OHEM text
        # masks computed above are reused unchanged - confirm this is intended.
        if self.kernel_sample_mask == 'gt':
            selected_masks = gt_texts * training_masks
        elif self.kernel_sample_mask == 'pred':
            selected_masks = (
                F.sigmoid(texts) > 0.5).astype('float32') * training_masks

        loss_kernels = [
            self.dice_loss(kernels[:, i, :, :], gt_kernels[:, i, :, :],
                           selected_masks)
            for i in range(kernels.shape[1])
        ]
        loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
        iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
                         gt_kernels[:, -1, :, :],
                         training_masks * gt_texts,
                         reduce=False)
        losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))

        loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
        losses['loss'] = loss

        if self.reduction == 'sum':
            losses = {x: paddle.sum(v) for x, v in losses.items()}
        elif self.reduction == 'mean':
            losses = {x: paddle.mean(v) for x, v in losses.items()}
        return losses

    def dice_loss(self, input, target, mask):
        """Return ``1 - dice`` per sample along the first axis.

        ``input`` is passed through a sigmoid; all three arguments are
        flattened to ``[shape[0], -1]`` and ``mask`` is applied to both
        ``input`` and ``target`` before the overlap is computed.
        """
        input = F.sigmoid(input)

        input = input.reshape([input.shape[0], -1])
        target = target.reshape([target.shape[0], -1])
        mask = mask.reshape([mask.shape[0], -1])

        input = input * mask
        target = target * mask

        a = paddle.sum(input * target, 1)
        b = paddle.sum(input * input, 1) + self.eps
        c = paddle.sum(target * target, 1) + self.eps
        d = (2 * a) / (b + c)
        return 1 - d

    @staticmethod
    def _as_batch_mask(mask):
        """Reshape a 2-D mask to ``[1, H, W]`` float32 for concatenation."""
        return mask.reshape([1, mask.shape[0],
                             mask.shape[1]]).astype('float32')

    def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
        """Select the pixels of one sample that contribute to the text loss.

        Keeps every positive pixel (``gt_text > 0.5``) and the highest-scoring
        negatives, at most ``ohem_ratio`` per usable positive, restricted to
        ``training_mask > 0.5``. When there are no usable positives or no
        negatives, ``training_mask`` itself is returned.

        Returns:
            float32 mask of shape ``[1, H, W]``.
        """
        # Positives that are not excluded by the training mask.
        pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
            paddle.sum(
                paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
                .astype('float32')))

        if pos_num == 0:
            return self._as_batch_mask(training_mask)

        neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
        neg_num = int(min(pos_num * ohem_ratio, neg_num))

        if neg_num == 0:
            # The original used a PyTorch-style ``.view(1, H, W)`` call here,
            # which is not valid for paddle tensors; reshape like the other
            # branches instead.
            return self._as_batch_mask(training_mask)

        # Score of the neg_num-th highest-scoring negative pixel.
        neg_score = paddle.masked_select(score, gt_text <= 0.5)
        neg_score_sorted = paddle.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = paddle.logical_and(
            paddle.logical_or((score >= threshold), (gt_text > 0.5)),
            (training_mask > 0.5))
        return self._as_batch_mask(selected_mask)

    def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
        """Apply ``ohem_single`` to every sample and concatenate the masks
        along the first axis into one float32 tensor."""
        selected_masks = [
            self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
                             training_masks[i, :, :], ohem_ratio)
            for i in range(scores.shape[0])
        ]
        return paddle.concat(selected_masks, 0).astype('float32')
|