python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0

doctr/models/classification/predictor/pytorch.py

@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import List, Union
+ from typing import List, Optional, Union

  import numpy as np
  import torch
@@ -12,12 +12,12 @@ from torch import nn
  from doctr.models.preprocessor import PreProcessor
  from doctr.models.utils import set_device_and_dtype

- __all__ = ["CropOrientationPredictor"]
+ __all__ = ["OrientationPredictor"]


- class CropOrientationPredictor(nn.Module):
-     """Implements an object able to detect the reading direction of a text box.
-     4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+ class OrientationPredictor(nn.Module):
+     """Implements an object able to detect the reading direction of a text box or a page.
+     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.

      Args:
      ----
@@ -27,30 +27,41 @@ class CropOrientationPredictor(nn.Module):

      def __init__(
          self,
-         pre_processor: PreProcessor,
-         model: nn.Module,
+         pre_processor: Optional[PreProcessor],
+         model: Optional[nn.Module],
      ) -> None:
          super().__init__()
-         self.pre_processor = pre_processor
-         self.model = model.eval()
+         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+         self.model = model.eval() if isinstance(model, nn.Module) else None

      @torch.inference_mode()
      def forward(
          self,
-         crops: List[Union[np.ndarray, torch.Tensor]],
-     ) -> List[int]:
+         inputs: List[Union[np.ndarray, torch.Tensor]],
+     ) -> List[Union[List[int], List[float]]]:
          # Dimension check
-         if any(crop.ndim != 3 for crop in crops):
-             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+         if any(input.ndim != 3 for input in inputs):
+             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

-         processed_batches = self.pre_processor(crops)
+         if self.model is None or self.pre_processor is None:
+             # predictor is disabled
+             return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
+         processed_batches = self.pre_processor(inputs)
          _params = next(self.model.parameters())
          self.model, processed_batches = set_device_and_dtype(
              self.model, processed_batches, _params.device, _params.dtype
          )
-         predicted_batches = [self.model(batch) for batch in processed_batches]
-
+         predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
+         # confidence
+         probs = [
+             torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
+         ]
          # Postprocess predictions
          predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

-         return [int(pred) for batch in predicted_batches for pred in batch]
+         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+         classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore[union-attr]
+         confs = [round(float(p), 2) for prob in probs for p in prob]
+
+         return [class_idxs, classes, confs]
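
The predictor's return contract changes here from a flat List[int] of class indices to three parallel lists: class indices, decoded orientation classes, and softmax confidences. A minimal sketch of the new contract with the PyTorch backend, using the disabled fast path so no pretrained weights are needed (the import path follows the package layout above):

    import numpy as np
    from doctr.models.classification.predictor import OrientationPredictor

    # Disabled predictor: both components are None, so forward() takes the
    # early-return branch and yields neutral predictions for every input.
    predictor = OrientationPredictor(None, None)
    crops = [(255 * np.random.rand(256, 256, 3)).astype(np.uint8) for _ in range(2)]
    class_idxs, classes, confs = predictor(crops)
    # class_idxs == [0, 0], classes == [0, 0], confs == [1.0, 1.0]

The TensorFlow implementation below mirrors this contract, so downstream code can unpack the triple identically under either backend.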

doctr/models/classification/predictor/tensorflow.py

@@ -3,21 +3,21 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import List, Union
+ from typing import List, Optional, Union

  import numpy as np
  import tensorflow as tf
- from tensorflow import keras
+ from tensorflow.keras import Model

  from doctr.models.preprocessor import PreProcessor
  from doctr.utils.repr import NestedObject

- __all__ = ["CropOrientationPredictor"]
+ __all__ = ["OrientationPredictor"]


- class CropOrientationPredictor(NestedObject):
-     """Implements an object able to detect the reading direction of a text box.
-     4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+ class OrientationPredictor(NestedObject):
+     """Implements an object able to detect the reading direction of a text box or a page.
+     4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.

      Args:
      ----
@@ -29,24 +29,34 @@ class CropOrientationPredictor(NestedObject):

      def __init__(
          self,
-         pre_processor: PreProcessor,
-         model: keras.Model,
+         pre_processor: Optional[PreProcessor],
+         model: Optional[Model],
      ) -> None:
-         self.pre_processor = pre_processor
-         self.model = model
+         self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+         self.model = model if isinstance(model, Model) else None

      def __call__(
          self,
-         crops: List[Union[np.ndarray, tf.Tensor]],
-     ) -> List[int]:
+         inputs: List[Union[np.ndarray, tf.Tensor]],
+     ) -> List[Union[List[int], List[float]]]:
          # Dimension check
-         if any(crop.ndim != 3 for crop in crops):
-             raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+         if any(input.ndim != 3 for input in inputs):
+             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

-         processed_batches = self.pre_processor(crops)
+         if self.model is None or self.pre_processor is None:
+             # predictor is disabled
+             return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
+         processed_batches = self.pre_processor(inputs)
          predicted_batches = [self.model(batch, training=False) for batch in processed_batches]

+         # confidence
+         probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches]
          # Postprocess predictions
          predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]

-         return [int(pred) for batch in predicted_batches for pred in batch]
+         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+         classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+         confs = [round(float(p), 2) for prob in probs for p in prob]
+
+         return [class_idxs, classes, confs]

doctr/models/classification/resnet/tensorflow.py

@@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential

  from doctr.datasets import VOCABS

- from ...utils import conv_sequence, load_pretrained_params
+ from ...utils import _build_model, conv_sequence, load_pretrained_params

  __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]

@@ -24,35 +24,35 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
      },
      "resnet31": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
      },
      "resnet34": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
      },
      "resnet50": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
      },
      "resnet34_wide": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
      },
  }

@@ -210,9 +210,15 @@ def _resnet(
      model = ResNet(
          num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
      )
+     _build_model(model)
+
      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs[arch]["url"])
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+         )

      return model

@@ -354,10 +360,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
      )

      model.cfg = _cfg
+     _build_model(model)

      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs["resnet50"]["url"])
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model,
+             default_cfgs["resnet50"]["url"],
+             skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
+         )

      return model
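
The same pattern recurs across the TensorFlow classification backbones in this release: _build_model materializes the Keras weights before loading, and skip_mismatch only becomes True when the requested head size differs from the pretrained French-vocab classifier. A hedged fine-tuning sketch, assuming (as in the surrounding factory helpers, not shown in this hunk) that a num_classes default is filled into kwargs before the comparison runs:

    from doctr.models.classification import resnet18

    # Head matches the checkpoint: skip_mismatch evaluates to False and
    # all pretrained weights load.
    model = resnet18(pretrained=True)

    # Custom 10-class head: skip_mismatch evaluates to True, so the backbone
    # weights load and the mismatching classifier layer stays randomly
    # initialized, ready for fine-tuning.
    finetune = resnet18(pretrained=True, num_classes=10)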

doctr/models/classification/textnet/pytorch.py

@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "std": (0.299, 0.296, 0.301),
          "input_shape": (3, 32, 32),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0",
      },
      "textnet_small": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (3, 32, 32),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0",
      },
      "textnet_base": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (3, 32, 32),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0",
      },
  }

doctr/models/classification/textnet/tensorflow.py

@@ -12,7 +12,7 @@ from tensorflow.keras import Sequential, layers
  from doctr.datasets import VOCABS

  from ...modules.layers.tensorflow import FASTConvLayer
- from ...utils import conv_sequence, load_pretrained_params
+ from ...utils import _build_model, conv_sequence, load_pretrained_params

  __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
      },
      "textnet_small": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
      },
      "textnet_base": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
      },
  }

@@ -111,9 +111,15 @@ def _textnet(

      # Build the model
      model = TextNet(cfg=_cfg, **kwargs)
+     _build_model(model)
+
      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs[arch]["url"])
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+         )

      return model

doctr/models/classification/vgg/tensorflow.py

@@ -11,7 +11,7 @@ from tensorflow.keras.models import Sequential

  from doctr.datasets import VOCABS

- from ...utils import conv_sequence, load_pretrained_params
+ from ...utils import _build_model, conv_sequence, load_pretrained_params

  __all__ = ["VGG", "vgg16_bn_r"]

@@ -22,7 +22,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "std": (1.0, 1.0, 1.0),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
      },
  }

@@ -81,9 +81,15 @@ def _vgg(

      # Build the model
      model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+     _build_model(model)
+
      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs[arch]["url"])
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+         )

      return model

doctr/models/classification/vit/tensorflow.py

@@ -14,7 +14,7 @@ from doctr.models.modules.transformer import EncoderBlock
  from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
  from doctr.utils.repr import NestedObject

- from ...utils import load_pretrained_params
+ from ...utils import _build_model, load_pretrained_params

  __all__ = ["vit_s", "vit_b"]

@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "std": (0.299, 0.296, 0.301),
          "input_shape": (3, 32, 32),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
      },
      "vit_b": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 32, 3),
          "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
      },
  }

@@ -121,9 +121,15 @@ def _vit(

      # Build the model
      model = VisionTransformer(cfg=_cfg, **kwargs)
+     _build_model(model)
+
      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs[arch]["url"])
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+         )

      return model

doctr/models/classification/zoo.py

@@ -9,9 +9,9 @@ from doctr.file_utils import is_tf_available

  from .. import classification
  from ..preprocessor import PreProcessor
- from .predictor import CropOrientationPredictor
+ from .predictor import OrientationPredictor

- __all__ = ["crop_orientation_predictor"]
+ __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

  ARCHS: List[str] = [
      "magc_resnet31",
@@ -31,44 +31,80 @@ ARCHS: List[str] = [
      "vit_s",
      "vit_b",
  ]
- ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


- def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
-     if arch not in ORIENTATION_ARCHS:
-         raise ValueError(f"unknown architecture '{arch}'")
+ def _orientation_predictor(
+     arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
+ ) -> OrientationPredictor:
+     if disabled:
+         # Case where the orientation predictor is disabled
+         return OrientationPredictor(None, None)
+
+     if isinstance(arch, str):
+         if arch not in ORIENTATION_ARCHS:
+             raise ValueError(f"unknown architecture '{arch}'")
+
+         # Load directly classifier from backbone
+         _model = classification.__dict__[arch](pretrained=pretrained)
+     else:
+         if not isinstance(arch, classification.MobileNetV3):
+             raise ValueError(f"unknown architecture: {type(arch)}")
+         _model = arch

-     # Load directly classifier from backbone
-     _model = classification.__dict__[arch](pretrained=pretrained)
      kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
      kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-     kwargs["batch_size"] = kwargs.get("batch_size", 64)
+     kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
      input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
-     predictor = CropOrientationPredictor(
+     predictor = OrientationPredictor(
          PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
      )
      return predictor


  def crop_orientation_predictor(
-     arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
- ) -> CropOrientationPredictor:
-     """Orientation classification architecture.
+     arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+ ) -> OrientationPredictor:
+     """Crop orientation classification architecture.

      >>> import numpy as np
      >>> from doctr.models import crop_orientation_predictor
-     >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True)
-     >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+     >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
+     >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
      >>> out = model([input_crop])

      Args:
      ----
-         arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
+         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
+         pretrained: If True, returns a model pre-trained on our recognition crops dataset
+         **kwargs: keyword arguments to be passed to the OrientationPredictor
+
+     Returns:
+     -------
+         OrientationPredictor
+     """
+     return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
+
+
+ def page_orientation_predictor(
+     arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+ ) -> OrientationPredictor:
+     """Page orientation classification architecture.
+
+     >>> import numpy as np
+     >>> from doctr.models import page_orientation_predictor
+     >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
+     >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
+     >>> out = model([input_page])
+
+     Args:
+     ----
+         arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
          pretrained: If True, returns a model pre-trained on our recognition crops dataset
-         **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+         **kwargs: keyword arguments to be passed to the OrientationPredictor

      Returns:
      -------
-         CropOrientationPredictor
+         OrientationPredictor
      """
-     return _crop_orientation_predictor(arch, pretrained, **kwargs)
+     return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
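
A short usage sketch of the two public factories and the triple they now return; the angle set in the comment reflects the 0/90/180/-90 orientations named in the docstrings, while the exact ordering comes from the loaded model's cfg["classes"]:

    import numpy as np
    from doctr.models import crop_orientation_predictor, page_orientation_predictor

    crop_clf = crop_orientation_predictor(pretrained=True)  # batch_size defaults to 128
    page_clf = page_orientation_predictor(pretrained=True)  # batch_size defaults to 4

    page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
    class_idxs, angles, confs = page_clf([page])
    # angles[i] is the predicted rotation in degrees (0, 90, 180 or -90);
    # confs[i] is the softmax confidence rounded to two decimals.

Since both wrappers forward **kwargs to _orientation_predictor, passing disabled=True yields the no-op OrientationPredictor(None, None) shown earlier.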

doctr/models/detection/_utils/__init__.py

@@ -1,4 +1,5 @@
  from doctr.file_utils import is_tf_available
+ from .base import *

  if is_tf_available():
      from .tensorflow import *

doctr/models/detection/_utils/base.py (new file)

@@ -0,0 +1,66 @@
+ # Copyright (C) 2021-2024, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from typing import Dict, List
+
+ import numpy as np
+
+ __all__ = ["_remove_padding"]
+
+
+ def _remove_padding(
+     pages: List[np.ndarray],
+     loc_preds: List[Dict[str, np.ndarray]],
+     preserve_aspect_ratio: bool,
+     symmetric_pad: bool,
+     assume_straight_pages: bool,
+ ) -> List[Dict[str, np.ndarray]]:
+     """Remove padding from the localization predictions
+
+     Args:
+     ----
+         pages: list of pages
+         loc_preds: list of localization predictions
+         preserve_aspect_ratio: whether the aspect ratio was preserved during padding
+         symmetric_pad: whether the padding was symmetric
+         assume_straight_pages: whether the pages are assumed to be straight
+
+     Returns:
+     -------
+         list of unpadded localization predictions
+     """
+     if preserve_aspect_ratio:
+         # Rectify loc_preds to remove padding
+         rectified_preds = []
+         for page, dict_loc_preds in zip(pages, loc_preds):
+             for k, loc_pred in dict_loc_preds.items():
+                 h, w = page.shape[0], page.shape[1]
+                 if h > w:
+                     # y unchanged, dilate x coord
+                     if symmetric_pad:
+                         if assume_straight_pages:
+                             loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5
+                         else:
+                             loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5
+                     else:
+                         if assume_straight_pages:
+                             loc_pred[:, [0, 2]] *= h / w
+                         else:
+                             loc_pred[:, :, 0] *= h / w
+                 elif w > h:
+                     # x unchanged, dilate y coord
+                     if symmetric_pad:
+                         if assume_straight_pages:
+                             loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5
+                         else:
+                             loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5
+                     else:
+                         if assume_straight_pages:
+                             loc_pred[:, [1, 3]] *= w / h
+                         else:
+                             loc_pred[:, :, 1] *= w / h
+                 rectified_preds.append({k: np.clip(loc_pred, 0, 1)})
+         return rectified_preds
+     return loc_preds
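
A worked example of the symmetric-pad branch: for a portrait page twice as tall as wide, the content occupies the central half of the padded square along x, so the affine map (x - 0.5) * h / w + 0.5 stretches relative x coordinates from [0.25, 0.75] back to [0, 1]:

    import numpy as np

    # Portrait page: h / w == 2, padding was added symmetrically along x.
    h, w = 200, 100
    # One straight box in padded relative coords: (xmin, ymin, xmax, ymax)
    box = np.array([[0.25, 0.1, 0.75, 0.2]])
    box[:, [0, 2]] = (box[:, [0, 2]] - 0.5) * h / w + 0.5
    print(box)  # [[0.  0.1 1.  0.2]]: x rescaled to the page, y unchanged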

doctr/models/detection/differentiable_binarization/base.py

@@ -114,7 +114,7 @@ class DBPostProcessor(DetectionPostProcessor):
          contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
          for contour in contours:
              # Check whether smallest enclosing bounding box is not too small
-             if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+             if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):  # type: ignore[index]
                  continue
              # Compute objectness
              if self.assume_straight_pages:
@@ -150,10 +150,11 @@ class DBPostProcessor(DetectionPostProcessor):
                  raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)")
              _box[:, 0] /= width
              _box[:, 1] /= height
-             boxes.append(_box)
+             # Add score to box as (0, score)
+             boxes.append(np.vstack([_box, np.array([0.0, score])]))

          if not self.assume_straight_pages:
-             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
          else:
              return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
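
With this change, a rotated prediction is a (5, 2) array instead of (4, 2): the first four rows are the polygon corners and the appended row stores (0, score). A minimal sketch of how a consumer could split the two (hypothetical values):

    import numpy as np

    box = np.array([[0.1, 0.1], [0.9, 0.1], [0.9, 0.3], [0.1, 0.3], [0.0, 0.87]])
    polygon, score = box[:4], float(box[4, 1])  # corners, then objectness score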

doctr/models/detection/differentiable_binarization/pytorch.py

@@ -39,7 +39,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
          "input_shape": (3, 1024, 1024),
          "mean": (0.798, 0.785, 0.772),
          "std": (0.264, 0.2749, 0.287),
-         "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.8.1/db_mobilenet_v3_large-21748dd0.pt&src=0",
      },
  }

@@ -273,7 +273,7 @@ class DBNet(_DBNet, nn.Module):
              dice_map = torch.softmax(out_map, dim=1)
          else:
              # compute binary map instead
-             dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+             dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
          # Class reduced
          inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
          cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))