openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -15,22 +15,27 @@ name_to_module = {
15
15
  'MGPLoss': '.mgp_loss',
16
16
  'PARSeqLoss': '.parseq_loss',
17
17
  'RobustScannerLoss': '.robustscanner_loss',
18
+ 'SEEDLoss': '.seed_loss',
19
+ 'SMTRLoss': '.smtr_loss',
18
20
  'SRNLoss': '.srn_loss',
19
21
  'VisionLANLoss': '.visionlan_loss',
20
22
  'CAMLoss': '.cam_loss',
21
- 'SEEDLoss': '.seed_loss',
23
+ 'MDiffLoss': '.mdiff_loss',
24
+ 'UniRecLoss': '.unirec_loss',
25
+ 'CMERLoss': '.cmer_loss',
22
26
  }
23
27
 
24
28
 
25
29
  def build_loss(config):
26
30
  config = copy.deepcopy(config)
27
31
  module_name = config.pop('name')
28
- assert module_name in name_to_module, Exception(
29
- 'loss only support {}'.format(list(name_to_module.keys())))
30
32
 
31
33
  if module_name in globals():
32
34
  module_class = globals()[module_name]
33
35
  else:
36
+ assert module_name in name_to_module, Exception(
37
+ '{} is not supported. The losses in {} are supportes'.format(
38
+ module_name, list(name_to_module.keys())))
34
39
  module_path = name_to_module[module_name]
35
40
  module = import_module(module_path, package=__package__)
36
41
  module_class = getattr(module, module_name)
@@ -0,0 +1,12 @@
1
+ from torch import nn
2
+
3
+
4
class CMERLoss(nn.Module):
    """Loss adapter that forwards the loss carried by the model output.

    The module computes nothing itself: ``forward`` reads the ``loss``
    attribute of ``pred`` and wraps it in a dict under the key ``'loss'``.
    """

    def __init__(self, label_smoothing=0.1, **kwargs):
        # ``label_smoothing`` and any extra keyword arguments are accepted
        # but not used by this class.
        super().__init__()

    def forward(self, pred, batch):
        """Return ``{'loss': pred.loss}``; ``batch`` is not read."""
        return {'loss': pred.loss}
@@ -0,0 +1,11 @@
1
+ from torch import nn
2
+
3
+
4
class MDiffLoss(nn.Module):
    """Loss adapter that wraps the model output as the loss value.

    ``forward`` returns its ``predicts`` argument unchanged inside a dict
    under the key ``'loss'``.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but not used by this class.
        super().__init__()

    def forward(self, predicts, batch):
        """Return ``{'loss': predicts}``; ``batch`` is not read."""
        wrapped = {'loss': predicts}
        return wrapped
@@ -0,0 +1,12 @@
1
+ from torch import nn
2
+
3
+
4
class UniRecLoss(nn.Module):
    """Loss adapter that forwards the loss carried by the model output.

    The module computes nothing itself: ``forward`` reads the ``loss``
    attribute of ``pred`` and wraps it in a dict under the key ``'loss'``.
    """

    def __init__(self, label_smoothing=0.1, **kwargs):
        # ``label_smoothing`` and any extra keyword arguments are accepted
        # but not used by this class.
        super().__init__()

    def forward(self, pred, batch):
        """Return ``{'loss': pred.loss}``; ``batch`` is not read."""
        return {'loss': pred.loss}
@@ -6,8 +6,11 @@ from .rec_metric import RecMetric
6
6
  from .rec_metric_gtc import RecGTCMetric
7
7
  from .rec_metric_long import RecMetricLong
8
8
  from .rec_metric_mgp import RecMPGMetric
9
+ from .rec_metric_cmer import CMERMetric
9
10
 
10
- support_dict = ['RecMetric', 'RecMetricLong', 'RecGTCMetric', 'RecMPGMetric']
11
+ support_dict = [
12
+ 'RecMetric', 'RecMetricLong', 'RecGTCMetric', 'RecMPGMetric', 'CMERMetric'
13
+ ]
11
14
 
12
15
 
13
16
  def build_metric(config):
@@ -0,0 +1,328 @@
1
+ import re
2
+ import math
3
+ import collections
4
+ from functools import lru_cache
5
+ import datasets
6
+ import evaluate
7
+ from rouge_score import rouge_scorer, scoring
8
+ from Levenshtein import distance as levenshtein_distance
9
+
10
+
11
+ def _get_ngrams(segment, max_order):
12
+ ngram_counts = collections.Counter()
13
+ for order in range(1, max_order + 1):
14
+ for i in range(0, len(segment) - order + 1):
15
+ ngram = tuple(segment[i:i + order])
16
+ ngram_counts[ngram] += 1
17
+ return ngram_counts
18
+
19
+
20
def compute_bleu(reference_corpus,
                 translation_corpus,
                 max_order=4,
                 smooth=False):
    """Compute a corpus-level BLEU score from tokenised inputs.

    Args:
        reference_corpus: One entry per sample; each entry is a collection
            of reference token sequences for that sample.
        translation_corpus: One token sequence per sample, paired with
            ``reference_corpus`` by position.
        max_order: Largest n-gram length used for the precisions.
        smooth: When true, add one to both the matched and the possible
            n-gram counts of every order before dividing.

    Returns:
        Tuple ``(bleu, precisions, bp, ratio, translation_length,
        reference_length)``.
    """
    matched = [0] * max_order
    possible = [0] * max_order
    reference_length = 0
    translation_length = 0

    for references, translation in zip(reference_corpus, translation_corpus):
        # The shortest reference of the sample counts towards the length.
        reference_length += min(len(ref) for ref in references)
        translation_length += len(translation)

        # Counter union keeps, per n-gram, the largest count seen in any
        # single reference.
        ref_counts = collections.Counter()
        for ref in references:
            ref_counts |= _get_ngrams(ref, max_order)

        # Counter intersection clips each candidate count at that maximum.
        clipped = _get_ngrams(translation, max_order) & ref_counts
        for ngram, hits in clipped.items():
            matched[len(ngram) - 1] += hits

        for idx in range(max_order):
            # Number of n-grams of length ``idx + 1`` in the translation.
            candidates = len(translation) - idx
            if candidates > 0:
                possible[idx] += candidates

    if smooth:
        precisions = [(hit + 1.) / (total + 1.)
                      for hit, total in zip(matched, possible)]
    else:
        precisions = [
            float(hit) / total if total > 0 else 0.0
            for hit, total in zip(matched, possible)
        ]

    if min(precisions) > 0:
        geo_mean = math.exp(
            sum((1. / max_order) * math.log(p) for p in precisions))
    else:
        geo_mean = 0

    ratio = (0.0 if reference_length == 0 else
             float(translation_length) / reference_length)

    # Brevity penalty: 1 when the output is longer than the reference,
    # 0 when there is no usable ratio, exponential decay otherwise.
    if ratio > 1.0:
        bp = 1.
    elif ratio <= 0:
        bp = 0.0
    else:
        bp = math.exp(1 - 1. / ratio)

    return (geo_mean * bp, precisions, bp, ratio, translation_length,
            reference_length)
73
+
74
+
75
class BaseTokenizer:
    """Identity tokenizer: returns its input unchanged."""

    def __call__(self, line):
        """Return ``line`` as given."""
        return line

    def signature(self):
        """Return the short identifier of this tokenizer (``'none'``)."""
        return 'none'
82
+
83
+
84
class TokenizerRegexp(BaseTokenizer):
    """Tokenizer that applies a fixed list of regex rewrites, then splits.

    Each rule in ``self._re`` is a ``(compiled_pattern, replacement)`` pair.
    The rules are applied in order and the rewritten line is split on
    whitespace.
    """

    def __init__(self):
        self._re = [
            (re.compile(r'([\{-\~[-\` -\&\(-\+\:-\@\/])'), r' \1 '),
            (re.compile(r'([^0-9])([\.,])'), r'\1 \2 '),
            (re.compile(r'([\.,])([^0-9])'), r' \1 \2'),
            (re.compile(r'([0-9])(-)'), r'\1 \2 '),
        ]

    def signature(self):
        """Return the short identifier of this tokenizer (``'re'``)."""
        return 're'

    @staticmethod
    @lru_cache(maxsize=2**16)
    def _apply_rules(rules, line):
        """Apply ``rules`` to ``line`` and return the tokens as a tuple.

        The cache lives on a static function keyed by ``(rules, line)``
        rather than on a bound method, so it never stores ``self`` and
        cannot keep tokenizer instances alive. The result is a tuple so
        the cached value is immutable.
        """
        for pattern, replacement in rules:
            line = pattern.sub(replacement, line)
        return tuple(line.split())

    def __call__(self, line):
        """Tokenize ``line`` and return a new list of token strings."""
        # Copy out of the cache so callers may mutate the result without
        # affecting later calls with the same input.
        return list(self._apply_rules(tuple(self._re), line))
102
+
103
+
104
class Tokenizer13a(BaseTokenizer):
    """Tokenizer that normalises a line and then runs ``TokenizerRegexp``.

    Normalisation removes ``<skipped>`` markers, joins words hyphenated
    across a line break, turns remaining newlines into spaces and decodes
    four HTML entities. The padded line is then handed to the regex
    tokenizer.
    """

    def __init__(self):
        self._post_tokenizer = TokenizerRegexp()

    def signature(self):
        """Return the short identifier of this tokenizer (``'13a'``)."""
        return '13a'

    def __call__(self, line):
        """Tokenize ``line`` and return a new list of token strings.

        This method is deliberately not wrapped in ``lru_cache``: a cache
        on an instance method stores ``self`` in every key, which keeps
        the instance alive, and it would hand the same mutable list to
        every caller. The regex stage, which is the costly part, is
        already memoised inside the post-tokenizer.
        """
        line = line.replace('<skipped>', '')
        line = line.replace('-\n', '')
        line = line.replace('\n', ' ')
        # Entity decoding is skipped when no '&' is present.
        if '&' in line:
            line = line.replace('&quot;', '"')
            line = line.replace('&amp;', '&')
            line = line.replace('&lt;', '<')
            line = line.replace('&gt;', '>')
        # ``list(...)`` guarantees the caller owns the returned tokens.
        return list(self._post_tokenizer(f' {line} '))
123
+
124
+
125
class CustomBleu(evaluate.Metric):
    """BLEU metric exposed through the ``evaluate.Metric`` interface.

    Inputs are tokenised with ``Tokenizer13a`` unless a tokenizer callable
    is supplied, then scored with ``compute_bleu``.
    """

    def _info(self):
        """Declare the inputs: a string prediction and a list of string references."""
        feature_spec = datasets.Features({
            'predictions':
            datasets.Value('string', id='sequence'),
            'references':
            datasets.Sequence(datasets.Value('string', id='sequence'),
                              id='references'),
        })
        return evaluate.MetricInfo(
            description='Custom BLEU implementation',
            citation='',
            inputs_description='',
            features=feature_spec,
        )

    def _compute(self,
                 predictions,
                 references,
                 tokenizer=None,
                 max_order=4,
                 smooth=False):
        """Tokenise the inputs and return the BLEU statistics as a dict.

        Args:
            predictions: Predicted strings.
            references: Either one reference string per prediction or a
                list of reference strings per prediction.
            tokenizer: Callable mapping a string to tokens; defaults to a
                new ``Tokenizer13a``.
            max_order: Forwarded to ``compute_bleu``.
            smooth: Forwarded to ``compute_bleu``.
        """
        tokenize = Tokenizer13a() if tokenizer is None else tokenizer

        # A flat list of strings means one reference per prediction; wrap
        # each so the rest of the code only sees the multi-reference form.
        if isinstance(references[0], str):
            references = [[single] for single in references]

        ref_tokens = [[tokenize(text) for text in group]
                      for group in references]
        pred_tokens = [tokenize(text) for text in predictions]

        (bleu, precisions, brevity_penalty, length_ratio, translation_length,
         reference_length) = compute_bleu(reference_corpus=ref_tokens,
                                          translation_corpus=pred_tokens,
                                          max_order=max_order,
                                          smooth=smooth)
        return {
            'bleu': bleu,
            'precisions': precisions,
            'brevity_penalty': brevity_penalty,
            'length_ratio': length_ratio,
            'translation_length': translation_length,
            'reference_length': reference_length,
        }
169
+
170
+
171
class CustomRougeTokenizer:
    """Adapter exposing a plain callable through a ``tokenize`` method."""

    def __init__(self, tokenizer_func):
        # Callable invoked with the text on every ``tokenize`` call.
        self.tokenizer_func = tokenizer_func

    def tokenize(self, text):
        """Return whatever the wrapped callable produces for ``text``."""
        tokens = self.tokenizer_func(text)
        return tokens
178
+
179
+
180
class CustomRouge(evaluate.Metric):
    """ROUGE metric exposed through the ``evaluate.Metric`` interface.

    Scoring is delegated to ``rouge_scorer.RougeScorer``; an optional
    tokenizer callable is wrapped in ``CustomRougeTokenizer`` first.
    """

    def _info(self):
        """Declare the inputs: a string prediction and a list of string references."""
        feature_spec = datasets.Features({
            'predictions':
            datasets.Value('string', id='sequence'),
            'references':
            datasets.Sequence(datasets.Value('string', id='sequence')),
        })
        return evaluate.MetricInfo(
            description='Custom ROUGE implementation',
            citation='',
            inputs_description='',
            features=feature_spec,
        )

    def _compute(self,
                 predictions,
                 references,
                 rouge_types=None,
                 use_aggregator=True,
                 use_stemmer=False,
                 tokenizer=None):
        """Score each prediction against its reference(s).

        Args:
            predictions: Predicted strings.
            references: One reference string, or a list of reference
                strings, per prediction.
            rouge_types: ROUGE variants to compute; defaults to
                ``['rouge1', 'rouge2', 'rougeL', 'rougeLsum']``.
            use_aggregator: When true, return one aggregated F-measure per
                type; otherwise a list of per-sample F-measures per type.
            use_stemmer: Forwarded to ``RougeScorer``.
            tokenizer: Optional callable mapping a string to tokens.

        Returns:
            Dict keyed by ROUGE type.
        """
        if rouge_types is None:
            rouge_types = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

        # The first entry decides whether every sample has several references.
        multi_ref = isinstance(references[0], list)

        if tokenizer is not None:
            tokenizer = CustomRougeTokenizer(tokenizer)

        scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types,
                                          use_stemmer=use_stemmer,
                                          tokenizer=tokenizer)
        score_pair = scorer.score_multi if multi_ref else scorer.score
        pair_scores = [
            score_pair(ref, pred)
            for ref, pred in zip(references, predictions)
        ]

        if not use_aggregator:
            # Per-sample F-measures, keyed like the first sample's score.
            first_score = pair_scores[0]
            return {
                key: [sample[key].fmeasure for sample in pair_scores]
                for key in first_score
            }

        aggregator = scoring.BootstrapAggregator()
        for sample in pair_scores:
            aggregator.add_scores(sample)
        aggregated = aggregator.aggregate()
        # Only the F-measure of each aggregate's ``mid`` value is reported.
        for key in aggregated:
            aggregated[key] = aggregated[key].mid.fmeasure
        return aggregated
239
+
240
+
241
class CMERMetric(object):
    """Accumulates prediction/label strings and reports averaged text metrics.

    Calling the instance buffers predictions and labels. ``get_metric``
    scores every buffered pair with ROUGE-1/2/L, BLEU and Levenshtein
    distance, averages the scores, clears the buffers and returns the
    averages.
    """

    # Names of the reported metrics, in output order.
    _METRIC_KEYS = ('rouge1', 'rouge2', 'rougeL', 'bleu', 'edit_distance')

    def __init__(self, main_indicator='bleu', **kwargs):
        """Create the scorers and start with empty buffers.

        Args:
            main_indicator: Name of a metric key, stored on the instance.
            **kwargs: Accepted and ignored.
        """
        self.main_indicator = main_indicator

        self.tokenizer = Tokenizer13a()
        self.rouge_metric = CustomRouge()
        self.bleu_metric = CustomBleu()
        self.reset()

    def reset(self):
        """Drop all buffered predictions and labels."""
        self.preds_list = []
        self.labels_list = []

    def _compute_single_pair(self, pred, label):
        """Score one prediction string against one label string.

        Returns:
            Dict with keys ``rouge1``, ``rouge2``, ``rougeL``, ``bleu`` and
            ``edit_distance`` (the Levenshtein distance between the raw
            strings, as a float).
        """
        preds = [pred]
        # Both metrics take a list of references per prediction.
        refs_formatted = [[label]]

        rouge_results = self.rouge_metric.compute(predictions=preds,
                                                  references=refs_formatted,
                                                  use_aggregator=True,
                                                  tokenizer=self.tokenizer)

        bleu_results = self.bleu_metric.compute(predictions=preds,
                                                references=refs_formatted,
                                                tokenizer=self.tokenizer)

        dist = levenshtein_distance(pred, label)

        return {
            'rouge1': rouge_results['rouge1'],
            'rouge2': rouge_results['rouge2'],
            'rougeL': rouge_results['rougeL'],
            'bleu': bleu_results['bleu'],
            'edit_distance': float(dist),
        }

    def __call__(self, preds, labels, **kwargs):
        """Buffer predictions and labels for a later ``get_metric`` call.

        A bare string is treated as a single sample. Nothing is returned.
        """
        if isinstance(preds, str):
            preds = [preds]
        if isinstance(labels, str):
            labels = [labels]
        self.preds_list.extend(preds)
        self.labels_list.extend(labels)

    def compute_single(self, preds, labels):
        """Return the per-metric average over the given pairs.

        Predictions and labels are paired by position; any surplus items
        on either side are not scored. The average is taken over the pairs
        actually scored. When no pair is scored, every metric is ``0.0``.
        """
        totals = collections.defaultdict(float)
        scored = 0
        for pred, label in zip(preds, labels):
            for key, value in self._compute_single_pair(pred, label).items():
                totals[key] += value
            scored += 1

        if scored == 0:
            return dict.fromkeys(self._METRIC_KEYS, 0.0)
        return {key: total / scored for key, total in totals.items()}

    def get_metric(self):
        """Average the metrics over all buffered pairs, then clear the buffers.

        Delegates to ``compute_single`` so that both entry points use the
        same denominator, namely the number of pairs actually scored.
        """
        averaged = self.compute_single(self.preds_list, self.labels_list)
        self.reset()
        return averaged