python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -45,10 +45,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
45
45
 
46
46
 
47
47
  class CTCPostProcessor(RecognitionPostProcessor):
48
- """
49
- Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
48
+ """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
50
49
 
51
50
  Args:
51
+ ----
52
52
  vocab: string containing the ordered sequence of supported characters
53
53
  """
54
54
 
@@ -62,14 +62,15 @@ class CTCPostProcessor(RecognitionPostProcessor):
62
62
  <https://github.com/githubharald/CTCDecoder>`_.
63
63
 
64
64
  Args:
65
+ ----
65
66
  logits: model output, shape: N x T x C
66
67
  vocab: vocabulary to use
67
68
  blank: index of blank label
68
69
 
69
70
  Returns:
71
+ -------
70
72
  A list of tuples: (word, confidence)
71
73
  """
72
-
73
74
  # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
74
75
  probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
75
76
 
@@ -82,14 +83,15 @@ class CTCPostProcessor(RecognitionPostProcessor):
82
83
  return list(zip(words, probs.tolist()))
83
84
 
84
85
  def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
85
- """
86
- Performs decoding of raw output with CTC and decoding of CTC predictions
86
+ """Performs decoding of raw output with CTC and decoding of CTC predictions
87
87
  with label_to_idx mapping dictionary
88
88
 
89
89
  Args:
90
+ ----
90
91
  logits: raw output of the model, shape (N, C + 1, seq_len)
91
92
 
92
93
  Returns:
94
+ -------
93
95
  A list of tuples: (word, confidence)
94
96
 
95
97
  """
@@ -102,6 +104,7 @@ class CRNN(RecognitionModel, nn.Module):
102
104
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
103
105
 
104
106
  Args:
107
+ ----
105
108
  feature_extractor: the backbone serving as feature extractor
106
109
  vocab: vocabulary used for encoding
107
110
  rnn_units: number of units in the LSTM layers
@@ -128,12 +131,9 @@ class CRNN(RecognitionModel, nn.Module):
128
131
  self.feat_extractor = feature_extractor
129
132
 
130
133
  # Resolve the input_size of the LSTM
131
- self.feat_extractor.eval()
132
- with torch.no_grad():
134
+ with torch.inference_mode():
133
135
  out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape
134
136
  lstm_in = out_shape[1] * out_shape[2]
135
- # Switch back to original mode
136
- self.feat_extractor.train()
137
137
 
138
138
  self.decoder = nn.LSTM(
139
139
  input_size=lstm_in,
@@ -168,11 +168,12 @@ class CRNN(RecognitionModel, nn.Module):
168
168
  """Compute CTC loss for the model.
169
169
 
170
170
  Args:
171
- gt: the encoded tensor with gt labels
171
+ ----
172
172
  model_output: predicted logits of the model
173
- seq_len: lengths of each gt word inside the batch
173
+ target: list of target strings
174
174
 
175
175
  Returns:
176
+ -------
176
177
  The loss of the model on the batch
177
178
  """
178
179
  gt, seq_len = self.build_target(target)
@@ -249,7 +250,7 @@ def _crnn(
249
250
  _cfg["input_shape"] = kwargs["input_shape"]
250
251
 
251
252
  # Build the model
252
- model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # type: ignore[arg-type]
253
+ model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
253
254
  # Load pretrained parameters
254
255
  if pretrained:
255
256
  # The number of classes is not the same as the number of classes in the pretrained model =>
@@ -271,12 +272,14 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
271
272
  >>> out = model(input_tensor)
272
273
 
273
274
  Args:
275
+ ----
274
276
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
277
+ **kwargs: keyword arguments of the CRNN architecture
275
278
 
276
279
  Returns:
280
+ -------
277
281
  text recognition architecture
278
282
  """
279
-
280
283
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
281
284
 
282
285
 
@@ -291,12 +294,14 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
291
294
  >>> out = model(input_tensor)
292
295
 
293
296
  Args:
297
+ ----
294
298
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
299
+ **kwargs: keyword arguments of the CRNN architecture
295
300
 
296
301
  Returns:
302
+ -------
297
303
  text recognition architecture
298
304
  """
299
-
300
305
  return _crnn(
301
306
  "crnn_mobilenet_v3_small",
302
307
  pretrained,
@@ -317,12 +322,14 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
317
322
  >>> out = model(input_tensor)
318
323
 
319
324
  Args:
325
+ ----
320
326
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
327
+ **kwargs: keyword arguments of the CRNN architecture
321
328
 
322
329
  Returns:
330
+ -------
323
331
  text recognition architecture
324
332
  """
325
-
326
333
  return _crnn(
327
334
  "crnn_mobilenet_v3_large",
328
335
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
13
13
  from doctr.datasets import VOCABS
14
14
 
15
15
  from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
16
- from ...utils.tensorflow import load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
17
17
  from ..core import RecognitionModel, RecognitionPostProcessor
18
18
 
19
19
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -44,10 +44,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
44
44
 
45
45
 
46
46
  class CTCPostProcessor(RecognitionPostProcessor):
47
- """
48
- Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
47
+ """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
49
48
 
50
49
  Args:
50
+ ----
51
51
  vocab: string containing the ordered sequence of supported characters
52
52
  ignore_case: if True, ignore case of letters
53
53
  ignore_accents: if True, ignore accents of letters
@@ -59,16 +59,17 @@ class CTCPostProcessor(RecognitionPostProcessor):
59
59
  beam_width: int = 1,
60
60
  top_paths: int = 1,
61
61
  ) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]:
62
- """
63
- Performs decoding of raw output with CTC and decoding of CTC predictions
62
+ """Performs decoding of raw output with CTC and decoding of CTC predictions
64
63
  with label_to_idx mapping dictionary
65
64
 
66
65
  Args:
66
+ ----
67
67
  logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
68
68
  beam_width: An int scalar >= 0 (beam search beam width).
69
69
  top_paths: An int scalar >= 0, <= beam_width (controls output size).
70
70
 
71
71
  Returns:
72
+ -------
72
73
  A list of decoded words of length BATCH_SIZE
73
74
 
74
75
 
@@ -113,6 +114,7 @@ class CRNN(RecognitionModel, Model):
113
114
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
114
115
 
115
116
  Args:
117
+ ----
116
118
  feature_extractor: the backbone serving as feature extractor
117
119
  vocab: vocabulary used for encoding
118
120
  rnn_units: number of units in the LSTM layers
@@ -144,13 +146,11 @@ class CRNN(RecognitionModel, Model):
144
146
  self.exportable = exportable
145
147
  self.feat_extractor = feature_extractor
146
148
 
147
- self.decoder = Sequential(
148
- [
149
- layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
150
- layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
151
- layers.Dense(units=len(vocab) + 1),
152
- ]
153
- )
149
+ self.decoder = Sequential([
150
+ layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
151
+ layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
152
+ layers.Dense(units=len(vocab) + 1),
153
+ ])
154
154
  self.decoder.build(input_shape=(None, w, h * c))
155
155
 
156
156
  self.postprocessor = CTCPostProcessor(vocab=vocab)
@@ -166,10 +166,12 @@ class CRNN(RecognitionModel, Model):
166
166
  """Compute CTC loss for the model.
167
167
 
168
168
  Args:
169
+ ----
169
170
  model_output: predicted logits of the model
170
171
  target: list of target strings
171
172
 
172
173
  Returns:
174
+ -------
173
175
  The loss of the model on the batch
174
176
  """
175
177
  gt, seq_len = self.build_target(target)
@@ -199,7 +201,7 @@ class CRNN(RecognitionModel, Model):
199
201
  w, h, c = transposed_feat.get_shape().as_list()[1:]
200
202
  # B x W x H x C --> B x W x H * C
201
203
  features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
202
- logits = self.decoder(features_seq, **kwargs)
204
+ logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
203
205
 
204
206
  out: Dict[str, tf.Tensor] = {}
205
207
  if self.exportable:
@@ -261,12 +263,14 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
261
263
  >>> out = model(input_tensor)
262
264
 
263
265
  Args:
266
+ ----
264
267
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
268
+ **kwargs: keyword arguments of the CRNN architecture
265
269
 
266
270
  Returns:
271
+ -------
267
272
  text recognition architecture
268
273
  """
269
-
270
274
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
271
275
 
272
276
 
@@ -281,12 +285,14 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
281
285
  >>> out = model(input_tensor)
282
286
 
283
287
  Args:
288
+ ----
284
289
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
290
+ **kwargs: keyword arguments of the CRNN architecture
285
291
 
286
292
  Returns:
293
+ -------
287
294
  text recognition architecture
288
295
  """
289
-
290
296
  return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
291
297
 
292
298
 
@@ -301,10 +307,12 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
301
307
  >>> out = model(input_tensor)
302
308
 
303
309
  Args:
310
+ ----
304
311
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
312
+ **kwargs: keyword arguments of the CRNN architecture
305
313
 
306
314
  Returns:
315
+ -------
307
316
  text recognition architecture
308
317
  """
309
-
310
318
  return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -23,9 +23,11 @@ class _MASTER:
23
23
  sequence lengths.
24
24
 
25
25
  Args:
26
+ ----
26
27
  gts: list of ground-truth labels
27
28
 
28
29
  Returns:
30
+ -------
29
31
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
30
32
  """
31
33
  encoded = encode_sequences(
@@ -44,6 +46,7 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
44
46
  """Abstract class to postprocess the raw output of the model
45
47
 
46
48
  Args:
49
+ ----
47
50
  vocab: string containing the ordered sequence of supported characters
48
51
  """
49
52
 
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,7 +15,7 @@ from doctr.datasets import VOCABS
15
15
  from doctr.models.classification import magc_resnet31
16
16
  from doctr.models.modules.transformer import Decoder, PositionalEncoding
17
17
 
18
- from ...utils.pytorch import load_pretrained_params
18
+ from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
19
19
  from .base import _MASTER, _MASTERPostProcessor
20
20
 
21
21
  __all__ = ["MASTER", "master"]
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
27
27
  "std": (0.299, 0.296, 0.301),
28
28
  "input_shape": (3, 32, 128),
29
29
  "vocab": VOCABS["french"],
30
- "url": None,
30
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0",
31
31
  },
32
32
  }
33
33
 
@@ -37,6 +37,7 @@ class MASTER(_MASTER, nn.Module):
37
37
  Implementation based on the official Pytorch implementation: `<https://github.com/wenwenyu/MASTER-pytorch>`_.
38
38
 
39
39
  Args:
40
+ ----
40
41
  feature_extractor: the backbone serving as feature extractor
41
42
  vocab: vocabulary, (without EOS, SOS, PAD)
42
43
  d_model: d parameter for the transformer decoder
@@ -105,7 +106,8 @@ class MASTER(_MASTER, nn.Module):
105
106
  # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
106
107
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
107
108
  # [True, True, True, ..., False, False, False] -> False is masked
108
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # (N, 1, 1, max_length)
109
+ # (N, 1, 1, max_length)
110
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
109
111
  target_length = target.size(1)
110
112
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
111
113
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -128,17 +130,19 @@ class MASTER(_MASTER, nn.Module):
128
130
  Sequences are masked after the EOS character.
129
131
 
130
132
  Args:
133
+ ----
131
134
  gt: the encoded tensor with gt labels
132
135
  model_output: predicted logits of the model
133
136
  seq_len: lengths of each gt word inside the batch
134
137
 
135
138
  Returns:
139
+ -------
136
140
  The loss of the model on the batch
137
141
  """
138
142
  # Input length : number of timesteps
139
143
  input_len = model_output.shape[1]
140
144
  # Add one for additional <eos> token (sos disappear in shift!)
141
- seq_len = seq_len + 1
145
+ seq_len = seq_len + 1 # type: ignore[assignment]
142
146
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
143
147
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
144
148
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -159,15 +163,16 @@ class MASTER(_MASTER, nn.Module):
159
163
  """Call function for training
160
164
 
161
165
  Args:
166
+ ----
162
167
  x: images
163
168
  target: list of str labels
164
169
  return_model_output: if True, return logits
165
170
  return_preds: if True, decode logits
166
171
 
167
172
  Returns:
173
+ -------
168
174
  A dictionary that may contain the loss, logits and predictions.
169
175
  """
170
-
171
176
  # Encode
172
177
  features = self.feat_extractor(x)["features"]
173
178
  b, c, h, w = features.shape
@@ -195,6 +200,8 @@ class MASTER(_MASTER, nn.Module):
195
200
  else:
196
201
  logits = self.decode(encoded)
197
202
 
203
+ logits = _bf16_to_float32(logits)
204
+
198
205
  if self.exportable:
199
206
  out["logits"] = logits
200
207
  return out
@@ -214,9 +221,11 @@ class MASTER(_MASTER, nn.Module):
214
221
  """Decode function for prediction
215
222
 
216
223
  Args:
224
+ ----
217
225
  encoded: input tensor
218
226
 
219
- Return:
227
+ Returns:
228
+ -------
220
229
  A Tuple of torch.Tensor: predictions, logits
221
230
  """
222
231
  b = encoded.size(0)
@@ -259,7 +268,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
259
268
  for encoded_seq in out_idxs.cpu().numpy()
260
269
  ]
261
270
 
262
- return list(zip(word_values, probs.numpy().tolist()))
271
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
263
272
 
264
273
 
265
274
  def _master(
@@ -307,12 +316,14 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
307
316
  >>> out = model(input_tensor)
308
317
 
309
318
  Args:
319
+ ----
310
320
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
321
+ **kwargs: keyword arguments passed to the MASTER architecture
311
322
 
312
323
  Returns:
324
+ -------
313
325
  text recognition architecture
314
326
  """
315
-
316
327
  return _master(
317
328
  "master",
318
329
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
13
13
  from doctr.models.classification import magc_resnet31
14
14
  from doctr.models.modules.transformer import Decoder, PositionalEncoding
15
15
 
16
- from ...utils.tensorflow import load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
17
17
  from .base import _MASTER, _MASTERPostProcessor
18
18
 
19
19
  __all__ = ["MASTER", "master"]
@@ -31,11 +31,11 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
31
31
 
32
32
 
33
33
  class MASTER(_MASTER, Model):
34
-
35
34
  """Implements MASTER as described in paper: `<https://arxiv.org/pdf/1910.02562.pdf>`_.
36
35
  Implementation based on the official TF implementation: `<https://github.com/jiangxiluning/MASTER-TF>`_.
37
36
 
38
37
  Args:
38
+ ----
39
39
  feature_extractor: the backbone serving as feature extractor
40
40
  vocab: vocabulary, (without EOS, SOS, PAD)
41
41
  d_model: d parameter for the transformer decoder
@@ -115,11 +115,13 @@ class MASTER(_MASTER, Model):
115
115
  Sequences are masked after the EOS character.
116
116
 
117
117
  Args:
118
+ ----
118
119
  gt: the encoded tensor with gt labels
119
120
  model_output: predicted logits of the model
120
121
  seq_len: lengths of each gt word inside the batch
121
122
 
122
123
  Returns:
124
+ -------
123
125
  The loss of the model on the batch
124
126
  """
125
127
  # Input length : number of timesteps
@@ -150,15 +152,17 @@ class MASTER(_MASTER, Model):
150
152
  """Call function for training
151
153
 
152
154
  Args:
155
+ ----
153
156
  x: images
154
157
  target: list of str labels
155
158
  return_model_output: if True, return logits
156
159
  return_preds: if True, decode logits
160
+ **kwargs: keyword arguments passed to the decoder
157
161
 
158
- Return:
162
+ Returns:
163
+ -------
159
164
  A dictionary that may contain the loss, logits and predictions.
160
165
  """
161
-
162
166
  # Encode
163
167
  feature = self.feat_extractor(x, **kwargs)
164
168
  b, h, w, c = feature.get_shape()
@@ -183,6 +187,8 @@ class MASTER(_MASTER, Model):
183
187
  else:
184
188
  logits = self.decode(encoded, **kwargs)
185
189
 
190
+ logits = _bf16_to_float32(logits)
191
+
186
192
  if self.exportable:
187
193
  out["logits"] = logits
188
194
  return out
@@ -203,9 +209,12 @@ class MASTER(_MASTER, Model):
203
209
  """Decode function for prediction
204
210
 
205
211
  Args:
212
+ ----
206
213
  encoded: encoded features
214
+ **kwargs: keyword arguments passed to the decoder
207
215
 
208
- Return:
216
+ Returns:
217
+ -------
209
218
  A Tuple of tf.Tensor: predictions, logits
210
219
  """
211
220
  b = encoded.shape[0]
@@ -238,6 +247,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
238
247
  """Post processor for MASTER architectures
239
248
 
240
249
  Args:
250
+ ----
241
251
  vocab: string containing the ordered sequence of supported characters
242
252
  """
243
253
 
@@ -260,7 +270,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
260
270
  decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
261
271
  word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
262
272
 
263
- return list(zip(word_values, probs.numpy().tolist()))
273
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
264
274
 
265
275
 
266
276
  def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER:
@@ -297,10 +307,12 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
297
307
  >>> out = model(input_tensor)
298
308
 
299
309
  Args:
310
+ ----
300
311
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
312
+ **kwargs: keyword arguments passed to the MASTER architecture
301
313
 
302
314
  Returns:
315
+ -------
303
316
  text recognition architecture
304
317
  """
305
-
306
318
  return _master("master", pretrained, magc_resnet31, **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -23,9 +23,11 @@ class _PARSeq:
23
23
  sequence lengths.
24
24
 
25
25
  Args:
26
+ ----
26
27
  gts: list of ground-truth labels
27
28
 
28
29
  Returns:
30
+ -------
29
31
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
30
32
  """
31
33
  encoded = encode_sequences(
@@ -44,6 +46,7 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
44
46
  """Abstract class to postprocess the raw output of the model
45
47
 
46
48
  Args:
49
+ ----
47
50
  vocab: string containing the ordered sequence of supported characters
48
51
  """
49
52