openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -15,22 +15,27 @@ name_to_module = {
|
|
|
15
15
|
'MGPLoss': '.mgp_loss',
|
|
16
16
|
'PARSeqLoss': '.parseq_loss',
|
|
17
17
|
'RobustScannerLoss': '.robustscanner_loss',
|
|
18
|
+
'SEEDLoss': '.seed_loss',
|
|
19
|
+
'SMTRLoss': '.smtr_loss',
|
|
18
20
|
'SRNLoss': '.srn_loss',
|
|
19
21
|
'VisionLANLoss': '.visionlan_loss',
|
|
20
22
|
'CAMLoss': '.cam_loss',
|
|
21
|
-
'
|
|
23
|
+
'MDiffLoss': '.mdiff_loss',
|
|
24
|
+
'UniRecLoss': '.unirec_loss',
|
|
25
|
+
'CMERLoss': '.cmer_loss',
|
|
22
26
|
}
|
|
23
27
|
|
|
24
28
|
|
|
25
29
|
def build_loss(config):
|
|
26
30
|
config = copy.deepcopy(config)
|
|
27
31
|
module_name = config.pop('name')
|
|
28
|
-
assert module_name in name_to_module, Exception(
|
|
29
|
-
'loss only support {}'.format(list(name_to_module.keys())))
|
|
30
32
|
|
|
31
33
|
if module_name in globals():
|
|
32
34
|
module_class = globals()[module_name]
|
|
33
35
|
else:
|
|
36
|
+
assert module_name in name_to_module, Exception(
|
|
37
|
+
'{} is not supported. The losses in {} are supportes'.format(
|
|
38
|
+
module_name, list(name_to_module.keys())))
|
|
34
39
|
module_path = name_to_module[module_name]
|
|
35
40
|
module = import_module(module_path, package=__package__)
|
|
36
41
|
module_class = getattr(module, module_name)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CMERLoss(nn.Module):
    """Pass-through loss for CMER: reports the loss the model already computed.

    ``forward`` reads ``pred.loss`` and returns it under the ``'loss'`` key.
    ``label_smoothing`` and any other keyword arguments are accepted but are
    not used by this module.
    """

    def __init__(self, label_smoothing=0.1, **kwargs):
        super().__init__()

    def forward(self, pred, batch):
        """Return ``{'loss': pred.loss}``; ``batch`` is ignored."""
        return dict(loss=pred.loss)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class UniRecLoss(nn.Module):
    """Pass-through loss for UniRec: reports the loss the model already computed.

    ``forward`` reads ``pred.loss`` and returns it under the ``'loss'`` key.
    ``label_smoothing`` and any other keyword arguments are accepted but are
    not used by this module.
    """

    def __init__(self, label_smoothing=0.1, **kwargs):
        super().__init__()

    def forward(self, pred, batch):
        """Return ``{'loss': pred.loss}``; ``batch`` is ignored."""
        return dict(loss=pred.loss)
|
|
@@ -6,8 +6,11 @@ from .rec_metric import RecMetric
|
|
|
6
6
|
from .rec_metric_gtc import RecGTCMetric
|
|
7
7
|
from .rec_metric_long import RecMetricLong
|
|
8
8
|
from .rec_metric_mgp import RecMPGMetric
|
|
9
|
+
from .rec_metric_cmer import CMERMetric
|
|
9
10
|
|
|
10
|
-
support_dict = [
|
|
11
|
+
# Metric class names exposed by this package (presumably the whitelist
# consulted by build_metric -- confirm against its body).
support_dict = [
    'RecMetric',
    'RecMetricLong',
    'RecGTCMetric',
    'RecMPGMetric',
    'CMERMetric',
]
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
def build_metric(config):
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import math
|
|
3
|
+
import collections
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
import datasets
|
|
6
|
+
import evaluate
|
|
7
|
+
from rouge_score import rouge_scorer, scoring
|
|
8
|
+
from Levenshtein import distance as levenshtein_distance
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_ngrams(segment, max_order):
    """Count every n-gram of ``segment`` for n = 1 .. ``max_order``.

    Args:
        segment: a sliceable token sequence (list, tuple or str).
        max_order: the largest n-gram length to count.

    Returns:
        A ``collections.Counter`` mapping n-gram tuples to occurrence counts.
    """
    counts = collections.Counter()
    length = len(segment)
    for n in range(1, max_order + 1):
        # range() is empty when the segment is shorter than n.
        counts.update(
            tuple(segment[start:start + n]) for start in range(length - n + 1))
    return counts
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def compute_bleu(reference_corpus,
                 translation_corpus,
                 max_order=4,
                 smooth=False):
    """Compute corpus-level BLEU for tokenized hypotheses.

    Args:
        reference_corpus: one entry per segment, each a collection of
            tokenized references (token sequences) for that segment.
        translation_corpus: tokenized hypotheses aligned with
            ``reference_corpus``; surplus items in the longer input are
            ignored (``zip`` semantics).
        max_order: highest n-gram order included in the geometric mean.
        smooth: if True, apply add-one smoothing to every n-gram precision.

    Returns:
        ``(bleu, precisions, bp, ratio, translation_length, reference_length)``
        where ``precisions`` holds the clipped n-gram precision per order,
        ``bp`` is the brevity penalty and ``ratio`` is
        ``translation_length / reference_length`` (0.0 when the reference
        length is 0).
    """
    matches = [0] * max_order
    candidates = [0] * max_order
    ref_len = 0
    hyp_len = 0

    for refs, hyp in zip(reference_corpus, translation_corpus):
        # The shortest reference of each segment feeds the brevity penalty.
        ref_len += min(len(r) for r in refs)
        hyp_len += len(hyp)

        # Counter union keeps the per-n-gram maximum over all references;
        # intersecting with it clips the hypothesis counts.
        ref_max_counts = collections.Counter()
        for r in refs:
            ref_max_counts |= _get_ngrams(r, max_order)
        clipped = _get_ngrams(hyp, max_order) & ref_max_counts
        for gram, hits in clipped.items():
            matches[len(gram) - 1] += hits

        for n in range(1, max_order + 1):
            slots = len(hyp) - n + 1
            if slots > 0:
                candidates[n - 1] += slots

    precisions = []
    for hit_count, slot_count in zip(matches, candidates):
        if smooth:
            precisions.append((hit_count + 1.) / (slot_count + 1.))
        elif slot_count > 0:
            precisions.append(float(hit_count) / slot_count)
        else:
            precisions.append(0.0)

    if min(precisions) > 0:
        weight = 1. / max_order
        geo_mean = math.exp(sum(weight * math.log(p) for p in precisions))
    else:
        geo_mean = 0

    ratio = float(hyp_len) / ref_len if ref_len != 0 else 0.0

    if ratio > 1.0:
        bp = 1.
    elif ratio > 0:
        bp = math.exp(1 - 1. / ratio)
    else:
        bp = 0.0

    return (geo_mean * bp, precisions, bp, ratio, hyp_len, ref_len)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class BaseTokenizer:
    """Identity tokenizer: returns its input unchanged.

    Subclasses override ``__call__`` to turn a string into tokens and
    ``signature`` to return a short identifier for their scheme.
    """

    def signature(self):
        """Return the identifier of this tokenization scheme."""
        return 'none'

    def __call__(self, line):
        """Return ``line`` unchanged."""
        return line


class TokenizerRegexp(BaseTokenizer):
    """Pad punctuation with spaces using regexes, then split on whitespace."""

    def signature(self):
        """Return the identifier of this tokenization scheme."""
        return 're'

    def __init__(self):
        self._re = [
            # Isolate ASCII punctuation except apostrophe, comma, hyphen and
            # period. Space is also in the class, which is harmless because
            # the result is whitespace-split.
            (re.compile(r'([\{-\~[-\` -\&\(-\+\:-\@\/])'), r' \1 '),
            # Split period/comma off unless preceded by a digit.
            (re.compile(r'([^0-9])([\.,])'), r'\1 \2 '),
            # Split period/comma off unless followed by a digit.
            (re.compile(r'([\.,])([^0-9])'), r' \1 \2'),
            # Split a hyphen that follows a digit.
            (re.compile(r'([0-9])(-)'), r'\1 \2 '),
        ]

    # NOTE: lru_cache on a method keys on ``self`` and keeps the instance
    # alive while entries are cached. The returned list object is shared
    # between calls with the same input, so callers must not mutate it.
    @lru_cache(maxsize=2**16)
    def __call__(self, line):
        """Return the list of tokens of ``line``."""
        for pattern, replacement in self._re:
            line = pattern.sub(replacement, line)
        return line.split()


class Tokenizer13a(BaseTokenizer):
    """Tokenizer modelled on the "13a" scheme (mteval-v13a / sacrebleu).

    Removes ``<skipped>`` markers and hyphenated line breaks, joins lines,
    decodes the common XML entities, then applies ``TokenizerRegexp``.
    """

    def signature(self):
        """Return the identifier of this tokenization scheme."""
        return '13a'

    def __init__(self):
        self._post_tokenizer = TokenizerRegexp()

    # Same lru_cache caveats as TokenizerRegexp.__call__: the returned list
    # is shared between calls and must not be mutated.
    @lru_cache(maxsize=2**16)
    def __call__(self, line):
        """Return the list of 13a tokens of ``line``."""
        line = line.replace('<skipped>', '')
        line = line.replace('-\n', '')
        line = line.replace('\n', ' ')
        # Decode XML/HTML entities. Previously each replace mapped a
        # character to itself (a no-op), so entities were split into junk
        # tokens such as '&', 'lt', ';'. The order matches sacrebleu's 13a.
        if '&' in line:
            line = line.replace('&quot;', '"')
            line = line.replace('&amp;', '&')
            line = line.replace('&lt;', '<')
            line = line.replace('&gt;', '>')
        return self._post_tokenizer(f' {line} ')
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class CustomBleu(evaluate.Metric):
    """``evaluate`` metric computing BLEU with a pluggable tokenizer.

    ``_compute`` tokenizes predictions and references (``Tokenizer13a`` by
    default) and delegates the n-gram statistics to ``compute_bleu``.
    """

    def _info(self):
        """Describe the metric and its input schema to ``evaluate``."""
        input_schema = datasets.Features({
            'predictions': datasets.Value('string', id='sequence'),
            'references': datasets.Sequence(
                datasets.Value('string', id='sequence'), id='references'),
        })
        return evaluate.MetricInfo(
            description='Custom BLEU implementation',
            citation='',
            inputs_description='',
            features=input_schema,
        )

    def _compute(self,
                 predictions,
                 references,
                 tokenizer=None,
                 max_order=4,
                 smooth=False):
        """Score ``predictions`` against ``references`` with BLEU.

        Args:
            predictions: hypothesis strings.
            references: aligned with ``predictions``; each item is either a
                single string or a list of reference strings.
            tokenizer: callable ``str -> tokens``; ``Tokenizer13a()`` if None.
            max_order: maximum n-gram order.
            smooth: whether to apply add-one smoothing.

        Returns:
            Dict with ``bleu``, ``precisions``, ``brevity_penalty``,
            ``length_ratio``, ``translation_length`` and
            ``reference_length``.

        Raises:
            IndexError: if ``references`` is empty.
        """
        tokenize = Tokenizer13a() if tokenizer is None else tokenizer

        # A flat list of strings is shorthand for one reference per item.
        if isinstance(references[0], str):
            references = [[ref] for ref in references]

        tokenized_refs = [list(map(tokenize, refs)) for refs in references]
        tokenized_preds = list(map(tokenize, predictions))

        bleu, precisions, bp, ratio, hyp_len, ref_len = compute_bleu(
            reference_corpus=tokenized_refs,
            translation_corpus=tokenized_preds,
            max_order=max_order,
            smooth=smooth)

        return {
            'bleu': bleu,
            'precisions': precisions,
            'brevity_penalty': bp,
            'length_ratio': ratio,
            'translation_length': hyp_len,
            'reference_length': ref_len,
        }
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class CustomRougeTokenizer:
    """Adapter exposing a plain ``text -> tokens`` callable via ``tokenize``.

    ``rouge_scorer.RougeScorer`` expects a tokenizer object with a
    ``tokenize(text)`` method; this wraps an arbitrary callable in that shape.
    """

    def __init__(self, tokenizer_func):
        self.tokenizer_func = tokenizer_func

    def tokenize(self, text):
        """Tokenize ``text`` with the wrapped callable."""
        tokenize_fn = self.tokenizer_func
        return tokenize_fn(text)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class CustomRouge(evaluate.Metric):
    """``evaluate`` metric computing ROUGE F-measures via ``rouge_score``.

    An optional ``tokenizer`` callable is wrapped in ``CustomRougeTokenizer``
    so it can be handed to ``rouge_scorer.RougeScorer``.
    """

    def _info(self):
        """Describe the metric and its input schema to ``evaluate``."""
        input_schema = datasets.Features({
            'predictions': datasets.Value('string', id='sequence'),
            'references': datasets.Sequence(
                datasets.Value('string', id='sequence')),
        })
        return evaluate.MetricInfo(
            description='Custom ROUGE implementation',
            citation='',
            inputs_description='',
            features=input_schema,
        )

    def _compute(self,
                 predictions,
                 references,
                 rouge_types=None,
                 use_aggregator=True,
                 use_stemmer=False,
                 tokenizer=None):
        """Score ``predictions`` against ``references`` with ROUGE.

        Args:
            predictions: hypothesis strings.
            references: aligned with ``predictions``; each item is a string
                or a list of candidate reference strings.
            rouge_types: ROUGE variants to compute; defaults to rouge1,
                rouge2, rougeL and rougeLsum.
            use_aggregator: if True, aggregate with
                ``scoring.BootstrapAggregator`` and report the mid F-measure;
                otherwise report per-sample F-measures.
            use_stemmer: forwarded to ``RougeScorer``.
            tokenizer: optional ``str -> tokens`` callable.

        Returns:
            ``{rouge_type: float}`` when aggregating, else
            ``{rouge_type: [float, ...]}``.

        Raises:
            IndexError: if ``references`` is empty, or if no pair is scored
                while ``use_aggregator`` is False.
        """
        if rouge_types is None:
            rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

        # Each reference entry may itself be a list of candidate references.
        multi_ref = isinstance(references[0], list)

        rouge_tokenizer = (None if tokenizer is None else
                           CustomRougeTokenizer(tokenizer))
        scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types,
                                          use_stemmer=use_stemmer,
                                          tokenizer=rouge_tokenizer)
        score_pair = scorer.score_multi if multi_ref else scorer.score

        per_sample = [
            score_pair(ref, pred) for ref, pred in zip(references, predictions)
        ]

        if not use_aggregator:
            return {
                rouge_type: [sample[rouge_type].fmeasure for sample in per_sample]
                for rouge_type in per_sample[0]
            }

        aggregator = scoring.BootstrapAggregator()
        for sample in per_sample:
            aggregator.add_scores(sample)
        return {
            rouge_type: interval.mid.fmeasure
            for rouge_type, interval in aggregator.aggregate().items()
        }
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class CMERMetric(object):
    """Accumulating text-similarity metric: ROUGE, BLEU and edit distance.

    Predictions and labels are buffered via ``__call__``. ``get_metric``
    scores every (prediction, label) pair individually, returns the per-pair
    average of ROUGE-1/2/L F-measure, sentence BLEU and Levenshtein
    distance, and then clears the buffers.

    Args:
        main_indicator: name of the key treated as the primary metric.
    """

    def __init__(self, main_indicator='bleu', **kwargs):
        self.main_indicator = main_indicator

        self.tokenizer = Tokenizer13a()
        self.rouge_metric = CustomRouge()
        self.bleu_metric = CustomBleu()
        self.reset()

    def reset(self):
        """Drop all buffered predictions and labels."""
        self.preds_list = []
        self.labels_list = []

    @staticmethod
    def _zero_metrics():
        """Return the metrics reported when there is no pair to score."""
        return {
            'rouge1': 0.0,
            'rouge2': 0.0,
            'rougeL': 0.0,
            'bleu': 0.0,
            'edit_distance': 0.0,
        }

    def _compute_single_pair(self, pred, label):
        """Score one prediction string against one label string."""
        preds = [pred]
        refs_formatted = [[label]]

        # Only one sample is scored, so the per-sample F-measure is exactly
        # what bootstrap aggregation would report. Skipping the aggregator
        # avoids pointless resampling on every pair. Only the rouge types
        # read below are requested.
        rouge_results = self.rouge_metric.compute(
            predictions=preds,
            references=refs_formatted,
            rouge_types=['rouge1', 'rouge2', 'rougeL'],
            use_aggregator=False,
            tokenizer=self.tokenizer)

        bleu_results = self.bleu_metric.compute(predictions=preds,
                                                references=refs_formatted,
                                                tokenizer=self.tokenizer)

        dist = levenshtein_distance(pred, label)

        return {
            'rouge1': rouge_results['rouge1'][0],
            'rouge2': rouge_results['rouge2'][0],
            'rougeL': rouge_results['rougeL'][0],
            'bleu': bleu_results['bleu'],
            'edit_distance': float(dist),
        }

    def _average(self, preds, labels):
        """Average ``_compute_single_pair`` over ``zip(preds, labels)``.

        The divisor is the number of pairs actually scored. Surplus items in
        the longer sequence are ignored instead of deflating the average.
        Returns zero metrics when there is no pair.
        """
        totals = collections.defaultdict(float)
        count = 0
        for pred, label in zip(preds, labels):
            for key, value in self._compute_single_pair(pred, label).items():
                totals[key] += value
            count += 1

        if count == 0:
            return self._zero_metrics()
        return {key: value / count for key, value in totals.items()}

    def __call__(self, preds, labels, **kwargs):
        """Buffer predictions and labels; a single string is wrapped in a list."""
        if isinstance(preds, str):
            preds = [preds]
        if isinstance(labels, str):
            labels = [labels]
        self.preds_list.extend(preds)
        self.labels_list.extend(labels)

    def compute_single(self, preds, labels):
        """Return averaged metrics for ``preds``/``labels`` without buffering."""
        return self._average(preds, labels)

    def get_metric(self):
        """Return averaged metrics over the buffered pairs, then reset."""
        metrics = self._average(self.preds_list, self.labels_list)
        self.reset()
        return metrics
|