python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
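The most substantial additions in this release are the FAST text-detection architecture (doctr/models/detection/fast/), its TextNet backbones (doctr/models/classification/textnet/), and the WildReceipt dataset (doctr/datasets/wildreceipt.py). A minimal usage sketch, assuming the new fast_* architectures are registered in the detection zoo as the changes to doctr/models/detection/zoo.py suggest — the "fast_base" name is inferred from the new module layout, not shown in this diff:

    import numpy as np
    from doctr.models import detection_predictor

    # "fast_base" is an assumed zoo name; pretrained weights may not ship for it yet
    model = detection_predictor(arch="fast_base", pretrained=False)
    page = (np.random.rand(1024, 1024, 3) * 255).astype("uint8")
    out = model([page])  # one result per input page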
doctr/models/recognition/parseq/pytorch.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -18,7 +18,7 @@ from doctr.datasets import VOCABS
  from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

  from ...classification import vit_s
- from ...utils.pytorch import load_pretrained_params
+ from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
  from .base import _PARSeq, _PARSeqPostProcessor

  __all__ = ["PARSeq", "parseq"]
@@ -29,7 +29,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (3, 32, 128),
  "vocab": VOCABS["french"],
- "url": None,
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/parseq-56125471.pt&src=0",
  },
  }

@@ -38,6 +38,7 @@ class CharEmbedding(nn.Module):
  """Implements the character embedding module

  Args:
+ ----
  vocab_size: size of the vocabulary
  d_model: dimension of the model
  """
@@ -55,6 +56,7 @@ class PARSeqDecoder(nn.Module):
  """Implements decoder module of the PARSeq model

  Args:
+ ----
  d_model: dimension of the model
  num_heads: number of attention heads
  ffd: dimension of the feed forward layer
@@ -110,6 +112,7 @@ class PARSeq(_PARSeq, nn.Module):
  Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.

  Args:
+ ----
  feature_extractor: the backbone serving as feature extractor
  vocab: vocabulary used for encoding
  embedding_units: number of embedding units
@@ -197,11 +200,11 @@ class PARSeq(_PARSeq, nn.Module):
  final_perms = torch.stack(perms)
  if len(perm_pool):
  i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
- final_perms = torch.cat([final_perms, perm_pool[i]])  # type: ignore[index]
+ final_perms = torch.cat([final_perms, perm_pool[i]])
  else:
- perms.extend(
- [torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))]
- )
+ perms.extend([
+ torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))
+ ])
  final_perms = torch.stack(perms)

  comp = final_perms.flip(-1)
@@ -209,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):

  sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
  eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
- combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+ combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()  # type: ignore
  if len(combined) > 1:
  combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
  return combined
@@ -237,7 +240,6 @@ class PARSeq(_PARSeq, nn.Module):
  target_query: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """Add positional information to the target sequence and pass it through the decoder."""
-
  batch_size, sequence_length = target.shape
  # apply positional information to the target sequence excluding the SOS token
  null_ctx = self.embed(target[:, :1])
@@ -280,7 +282,7 @@ class PARSeq(_PARSeq, nn.Module):
  ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

  # Stop decoding if all sequences have reached the EOS token
- if max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+ if max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
  break

  logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -295,7 +297,7 @@ class PARSeq(_PARSeq, nn.Module):

  # Create padding mask for refined target input maskes all behind EOS token as False
  # (N, 1, 1, max_length)
- target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
  mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
  logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))

@@ -329,11 +331,9 @@ class PARSeq(_PARSeq, nn.Module):
  gt_out = gt[:, 1:]  # remove SOS token
  # Create padding mask for target input
  # [True, True, True, ..., False, False, False] -> False is masked
- padding_mask = (
- ~(((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0)
- .unsqueeze(1)
- .unsqueeze(1)
- )  # (N, 1, 1, seq_len)
+ padding_mask = ~(
+ ((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0
+ ).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, seq_len)

  loss = torch.tensor(0.0, device=features.device)
  loss_numel: Union[int, float] = 0
@@ -362,6 +362,8 @@ class PARSeq(_PARSeq, nn.Module):
  else:
  logits = self.decode_autoregressive(features)

+ logits = _bf16_to_float32(logits)
+
  out: Dict[str, Any] = {}
  if self.exportable:
  out["logits"] = logits
@@ -384,6 +386,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
  """Post processor for PARSeq architecture

  Args:
+ ----
  vocab: string containing the ordered sequence of supported characters
  """

@@ -393,18 +396,19 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
  ) -> List[Tuple[str, float]]:
  # compute pred with argmax for attention models
  out_idxs = logits.argmax(-1)
- # N x L
- probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
- # Take the minimum confidence of the sequence
- probs = probs.min(dim=1).values.detach().cpu()
+ preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]

  # Manual decoding
  word_values = [
  "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
  for encoded_seq in out_idxs.cpu().numpy()
  ]
+ # compute probabilties for each word up to the EOS token
+ probs = [
+ preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+ ]

- return list(zip(word_values, probs.numpy().tolist()))
+ return list(zip(word_values, probs))


  def _parseq(
@@ -457,12 +461,14 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
  >>> out = model(input_tensor)

  Args:
+ ----
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the PARSeq architecture

  Returns:
+ -------
  text recognition architecture
  """
-
  return _parseq(
  "parseq",
  pretrained,
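The post-processor change in this file reworks how word confidences are reported: the old code gathered the probability of the argmax class at every position and took the minimum over the whole padded sequence, while the new code keeps the per-position best-class probability and averages it over the decoded word only, up to the EOS token (empty words score 0.0). A standalone sketch of the new scheme, with names mirroring the hunk above and random logits standing in for real model output:

    import torch

    logits = torch.randn(2, 32, 127)  # (N, max_length, vocab_size + 1)
    preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]  # best-class prob per position
    word_values = ["hello", ""]  # decoded words, EOS and padding already stripped

    # Mean confidence over the word's own characters, instead of the minimum
    # over the full padded sequence; clip guards against values just above 1.
    probs = [
        preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0
        for i, word in enumerate(word_values)
    ]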
doctr/models/recognition/parseq/tensorflow.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
  from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

  from ...classification import vit_s
- from ...utils.tensorflow import load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
  from .base import _PARSeq, _PARSeqPostProcessor

  __all__ = ["PARSeq", "parseq"]
@@ -36,6 +36,7 @@ class CharEmbedding(layers.Layer):
  """Implements the character embedding module

  Args:
+ ----
  vocab_size: size of the vocabulary
  d_model: dimension of the model
  """
@@ -53,6 +54,7 @@ class PARSeqDecoder(layers.Layer):
  """Implements decoder module of the PARSeq model

  Args:
+ ----
  d_model: dimension of the model
  num_heads: number of attention heads
  ffd: dimension of the feed forward layer
@@ -113,6 +115,7 @@ class PARSeq(_PARSeq, Model):
  Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.

  Args:
+ ----
  feature_extractor: the backbone serving as feature extractor
  vocab: vocabulary used for encoding
  embedding_units: number of embedding units
@@ -191,9 +194,9 @@ class PARSeq(_PARSeq, Model):
  i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
  final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0)
  else:
- perms.extend(
- [tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))]
- )
+ perms.extend([
+ tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
+ ])
  final_perms = tf.stack(perms)

  comp = tf.reverse(final_perms, axis=[-1])
@@ -390,6 +393,8 @@ class PARSeq(_PARSeq, Model):
  else:
  logits = self.decode_autoregressive(features, **kwargs)

+ logits = _bf16_to_float32(logits)
+
  out: Dict[str, tf.Tensor] = {}
  if self.exportable:
  out["logits"] = logits
@@ -412,6 +417,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
  """Post processor for PARSeq architecture

  Args:
+ ----
  vocab: string containing the ordered sequence of supported characters
  """

@@ -421,10 +427,7 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
  ) -> List[Tuple[str, float]]:
  # compute pred with argmax for attention models
  out_idxs = tf.math.argmax(logits, axis=2)
- # N x L
- probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
- # Take the minimum confidence of the sequence
- probs = tf.math.reduce_min(probs, axis=1)
+ preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

  # decode raw output of the model with tf_label_to_idx
  out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -434,7 +437,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
  decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
  word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

- return list(zip(word_values, probs.numpy().tolist()))
+ # compute probabilties for each word up to the EOS token
+ probs = [
+ preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
+ for i, word in enumerate(word_values)
+ ]
+
+ return list(zip(word_values, probs))


  def _parseq(
@@ -484,12 +493,14 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
  >>> out = model(input_tensor)

  Args:
+ ----
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the PARSeq architecture

  Returns:
+ -------
  text recognition architecture
  """
-
  return _parseq(
  "parseq",
  pretrained,
doctr/models/recognition/predictor/_utils.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ def split_crops(
  """Chunk crops horizontally to match a given aspect ratio

  Args:
+ ----
  crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
  max_ratio: the maximum aspect ratio that won't trigger the chunk
  target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
@@ -29,9 +30,9 @@
  channels_last: whether the numpy array has dimensions in channels last order

  Returns:
+ -------
  a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
  """
-
  _remap_required = False
  crop_map: List[Union[int, Tuple[int, int]]] = []
  new_crops: List[np.ndarray] = []
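For context, split_crops is the helper the recognition predictors use to chunk very wide crops before batching. A hedged usage sketch — the full signature beyond the documented parameters is not shown in this hunk, so arguments are passed positionally here, and the dilation value mirrors the predictor's dil_factor = 1.4 seen below:

    import numpy as np
    from doctr.models.recognition.predictor._utils import split_crops

    # A 32x600 crop exceeds a max aspect ratio of 8 and gets chunked toward ratio 6;
    # argument order is an assumption based on the documented parameter list.
    crops = [np.zeros((32, 600, 3), dtype=np.uint8)]
    new_crops, crop_map, remap_required = split_crops(crops, 8, 6, 1.4, True)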
doctr/models/recognition/predictor/pytorch.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -21,6 +21,7 @@ class RecognitionPredictor(nn.Module):
  """Implements an object able to identify character sequences in images

  Args:
+ ----
  pre_processor: transform inputs for easier batched model inference
  model: core detection architecture
  split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -40,7 +41,7 @@ class RecognitionPredictor(nn.Module):
  self.dil_factor = 1.4  # Dilation factor to overlap the crops
  self.target_ar = 6  # Target aspect ratio

- @torch.no_grad()
+ @torch.inference_mode()
  def forward(
  self,
  crops: Sequence[Union[np.ndarray, torch.Tensor]],
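Swapping @torch.no_grad() for @torch.inference_mode() is more than cosmetic: inference_mode additionally disables view tracking and version-counter bookkeeping, so tensors created under it can never re-enter autograd, which makes pure inference slightly faster. A quick illustration:

    import torch

    with torch.inference_mode():
        y = torch.ones(3) * 2

    print(y.requires_grad)   # False
    print(y.is_inference())  # True; such tensors cannot be used in autograd later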
doctr/models/recognition/predictor/tensorflow.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -21,6 +21,7 @@ class RecognitionPredictor(NestedObject):
  """Implements an object able to identify character sequences in images

  Args:
+ ----
  pre_processor: transform inputs for easier batched model inference
  model: core detection architecture
  split_wide_crops: wether to use crop splitting for high aspect ratio crops
doctr/models/recognition/sar/pytorch.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,7 +14,7 @@ from torchvision.models._utils import IntermediateLayerGetter
  from doctr.datasets import VOCABS

  from ...classification import resnet31
- from ...utils.pytorch import load_pretrained_params
+ from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["SAR", "sar_resnet31"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (3, 32, 128),
  "vocab": VOCABS["french"],
- "url": None,
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0",
  },
  }

@@ -80,6 +80,7 @@ class SARDecoder(nn.Module):
  """Implements decoder module of the SAR model

  Args:
+ ----
  rnn_units: number of hidden units in recurrent cells
  max_length: maximum length of a sequence
  vocab_size: number of classes in the model alphabet
@@ -164,6 +165,7 @@ class SAR(nn.Module, RecognitionModel):
  Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

  Args:
+ ----
  feature_extractor: the backbone serving as feature extractor
  vocab: vocabulary used for encoding
  rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -249,7 +251,7 @@ class SAR(nn.Module, RecognitionModel):
  if self.training and target is None:
  raise ValueError("Need to provide labels during training for teacher forcing")

- decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
+ decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))

  out: Dict[str, Any] = {}
  if self.exportable:
@@ -278,17 +280,19 @@ class SAR(nn.Module, RecognitionModel):
  Sequences are masked after the EOS character.

  Args:
+ ----
  model_output: predicted logits of the model
  gt: the encoded tensor with gt labels
  seq_len: lengths of each gt word inside the batch

  Returns:
+ -------
  The loss of the model on the batch
  """
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token
- seq_len = seq_len + 1
+ seq_len = seq_len + 1  # type: ignore[assignment]
  # Compute loss
  # (N, L, vocab_size + 1)
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
@@ -303,6 +307,7 @@ class SARPostProcessor(RecognitionPostProcessor):
  """Post processor for SAR architectures

  Args:
+ ----
  vocab: string containing the ordered sequence of supported characters
  """

@@ -323,7 +328,7 @@ class SARPostProcessor(RecognitionPostProcessor):
  for encoded_seq in out_idxs.detach().cpu().numpy()
  ]

- return list(zip(word_values, probs.numpy().tolist()))
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


  def _sar(
@@ -373,12 +378,14 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
  >>> out = model(input_tensor)

  Args:
+ ----
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the SAR architecture

  Returns:
+ -------
  text recognition architecture
  """
-
  return _sar(
  "sar_resnet31",
  pretrained,
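The .clip(0, 1) added to the SAR post-processor return (and to its TensorFlow counterpart below) keeps the reported confidences valid probabilities: with the reduced-precision paths this release introduces via _bf16_to_float32, a softmax followed by dtype casts can land a few ULPs outside [0, 1]. A tiny illustration:

    import numpy as np

    # Values marginally outside [0, 1] get snapped back to the valid range
    probs = np.array([1.0000001, 0.73, -1e-9], dtype=np.float64)
    print(probs.clip(0, 1).tolist())  # [1.0, 0.73, 0.0]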
doctr/models/recognition/sar/tensorflow.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
  from doctr.utils.repr import NestedObject

  from ...classification import resnet31
- from ...utils.tensorflow import load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["SAR", "sar_resnet31"]
@@ -33,18 +33,17 @@ class SAREncoder(layers.Layer, NestedObject):
  """Implements encoder module of the SAR model

  Args:
+ ----
  rnn_units: number of hidden rnn units
  dropout_prob: dropout probability
  """

  def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
  super().__init__()
- self.rnn = Sequential(
- [
- layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
- layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
- ]
- )
+ self.rnn = Sequential([
+ layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
+ layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
+ ])

  def call(
  self,
@@ -59,6 +58,7 @@ class AttentionModule(layers.Layer, NestedObject):
  """Implements attention module of the SAR model

  Args:
+ ----
  attention_units: number of hidden attention units

  """
@@ -120,6 +120,7 @@ class SARDecoder(layers.Layer, NestedObject):
  """Implements decoder module of the SAR model

  Args:
+ ----
  rnn_units: number of hidden units in recurrent cells
  max_length: maximum length of a sequence
  vocab_size: number of classes in the model alphabet
@@ -147,9 +148,9 @@ class SARDecoder(layers.Layer, NestedObject):
  self.embed = layers.Dense(embedding_units, use_bias=False)
  self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)

- self.lstm_cells = layers.StackedRNNCells(
- [layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)]
- )
+ self.lstm_cells = layers.StackedRNNCells([
+ layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
+ ])
  self.attention_module = AttentionModule(attention_units)
  self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
  self.dropout = layers.Dropout(dropout_prob)
@@ -215,6 +216,7 @@ class SAR(Model, RecognitionModel):
  Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

  Args:
+ ----
  feature_extractor: the backbone serving as feature extractor
  vocab: vocabulary used for encoding
  rnn_units: number of hidden units in both encoder and decoder LSTM
@@ -273,11 +275,13 @@ class SAR(Model, RecognitionModel):
  Sequences are masked after the EOS character.

  Args:
+ ----
  gt: the encoded tensor with gt labels
  model_output: predicted logits of the model
  seq_len: lengths of each gt word inside the batch

  Returns:
+ -------
  The loss of the model on the batch
  """
  # Input length : number of timesteps
@@ -316,7 +320,9 @@ class SAR(Model, RecognitionModel):
  if kwargs.get("training", False) and target is None:
  raise ValueError("Need to provide labels during training for teacher forcing")

- decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
+ decoded_features = _bf16_to_float32(
+ self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
+ )

  out: Dict[str, tf.Tensor] = {}
  if self.exportable:
@@ -340,6 +346,7 @@ class SARPostProcessor(RecognitionPostProcessor):
  """Post processor for SAR architectures

  Args:
+ ----
  vocab: string containing the ordered sequence of supported characters
  """

@@ -362,7 +369,7 @@ class SARPostProcessor(RecognitionPostProcessor):
  decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
  word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

- return list(zip(word_values, probs.numpy().tolist()))
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


  def _sar(
@@ -409,10 +416,12 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
  >>> out = model(input_tensor)

  Args:
+ ----
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the SAR architecture

  Returns:
+ -------
  text recognition architecture
  """
-
  return _sar("sar_resnet31", pretrained, resnet31, **kwargs)
doctr/models/recognition/utils.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -14,12 +14,14 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
  """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.

  Args:
+ ----
  a: first char seq, suffix should be similar to b's prefix.
  b: second char seq, prefix should be similar to a's suffix.
  dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
  only used when the mother sequence is splitted on a character repetition

  Returns:
+ -------
  A merged character sequence.

  Example::
@@ -63,11 +65,13 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
  """Recursively merges consecutive string sequences with overlapping characters.

  Args:
+ ----
  seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
  dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
  only used when the mother sequence is splitted on a character repetition

  Returns:
+ -------
  A merged character sequence

  Example::
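The Example:: blocks above are truncated by the hunks; based on the documented behavior of aligning a's suffix with b's prefix, a call presumably looks like the following (output is illustrative, since the overlap heuristic itself is not shown here):

    from doctr.models.recognition.utils import merge_strings

    # The shared "cd" boundary is aligned and emitted once
    print(merge_strings("abcd", "cdefgh", 1.4))  # expected: 'abcdefgh'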
doctr/models/recognition/vitstr/base.py
@@ -1,4 +1,4 @@
- # Copyright (C) 2021-2023, Mindee.
+ # Copyright (C) 2021-2024, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -23,9 +23,11 @@ class _ViTSTR:
  sequence lengths.

  Args:
+ ----
  gts: list of ground-truth labels

  Returns:
+ -------
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
  """
  encoded = encode_sequences(
@@ -43,6 +45,7 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor):
  """Abstract class to postprocess the raw output of the model

  Args:
+ ----
  vocab: string containing the ordered sequence of supported characters
  """