python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/detection/differentiable_binarization/tensorflow.py

@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+from tensorflow.keras import Model, Sequential, layers, losses
 from tensorflow.keras.applications import ResNet50
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from ...classification import mobilenet_v3_large
@@ -29,13 +34,13 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
     },
     "db_mobilenet_v3_large": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
     },
 }
 
@@ -81,7 +86,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         if dilation_factor > 1:
             _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
 
-        module = keras.Sequential(_layers)
+        module = Sequential(_layers)
 
         return module
 
@@ -104,7 +109,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         return layers.concatenate(results)
 
 
-class DBNet(_DBNet, keras.Model, NestedObject):
+class DBNet(_DBNet, Model, NestedObject):
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
@@ -147,14 +152,14 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential([
+        self.probability_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
             layers.Activation("relu"),
             layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
         ])
-        self.threshold_head = keras.Sequential([
+        self.threshold_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
@@ -206,7 +211,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
 
         # Focal loss
         focal_scale = 10.0
-        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
 
         # Convert logits to prob, compute gamma factor
         p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
@@ -305,9 +310,16 @@ def _db_resnet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
@@ -326,6 +338,10 @@ def _db_mobilenet(
     # Patch the config
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+    if not kwargs.get("class_names", None):
+        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+    else:
+        kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Feature extractor
     feat_extractor = IntermediateLayerGetter(
@@ -339,9 +355,15 @@ def _db_mobilenet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
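The new skip_mismatch path makes fine-tuning on custom detection classes straightforward. A minimal sketch of what it enables, assuming the builder normalizes class_names the same way _db_mobilenet does above (the label "my_class" is a made-up example):

from doctr.models import db_resnet50

# Passing class_names that differ from the default CLASS_NAME makes
# load_pretrained_params skip the mismatching head layers, so the backbone
# stays pretrained while the output layer is re-initialized for fine-tuning
model = db_resnet50(pretrained=True, class_names=["my_class"])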
doctr/models/detection/fast/base.py

@@ -31,7 +31,7 @@ class FASTPostProcessor(DetectionPostProcessor):
 
     def __init__(
         self,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
     ) -> None:
@@ -111,7 +111,7 @@ class FASTPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -138,10 +138,11 @@ class FASTPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
@@ -153,7 +154,7 @@ class _FAST(BaseModel):
 
     min_size_box: int = 3
     assume_straight_pages: bool = True
-    shrink_ratio = 0.1
+    shrink_ratio = 0.4
 
     def build_target(
        self,
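In the rotated-pages branch, each polygon now carries its confidence as a fifth (0, score) row, hence the (0, 5, 2) empty shape above. A small sketch of how a consumer would split corners from scores (the random array is purely illustrative):

import numpy as np

# polys: post-processor output with assume_straight_pages=False,
# shaped (num_boxes, 5, 2): four (x, y) corners plus a final (0, score) row
polys = np.random.rand(3, 5, 2)
corners = polys[:, :4]   # (num_boxes, 4, 2) corner coordinates
scores = polys[:, 4, 1]  # (num_boxes,) objectness scores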
doctr/models/detection/fast/pytorch.py

@@ -26,19 +26,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-1acac421.pt&src=0",
     },
     "fast_small": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-10952cc1.pt&src=0",
     },
     "fast_base": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-688a8b34.pt&src=0",
     },
 }
 
@@ -119,7 +119,7 @@ class FAST(_FAST, nn.Module):
     def __init__(
         self,
         feat_extractor: IntermediateLayerGetter,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         dropout_prob: float = 0.1,
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
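With the URLs now populated, the FAST family can be instantiated with pretrained weights. A quick sketch, assuming the PyTorch backend:

from doctr.models import fast_tiny

# Downloads fast_tiny-1acac421.pt on first use; note bin_thresh now defaults to 0.1
model = fast_tiny(pretrained=True).eval()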
doctr/models/detection/fast/tensorflow.py

@@ -10,11 +10,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import Sequential, layers
+from tensorflow.keras import Model, Sequential, layers
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
 from doctr.utils.repr import NestedObject
 
 from ...classification import textnet_base, textnet_small, textnet_tiny
@@ -29,19 +28,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
     },
     "fast_small": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
     },
     "fast_base": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
     },
 }
 
@@ -100,7 +99,7 @@ class FastHead(Sequential):
         super().__init__(_layers)
 
 
-class FAST(_FAST, keras.Model, NestedObject):
+class FAST(_FAST, Model, NestedObject):
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_.
 
@@ -122,7 +121,7 @@ class FAST(_FAST, Model, NestedObject):
     def __init__(
         self,
         feature_extractor: IntermediateLayerGetter,
-        bin_thresh: float = 0.3,
+        bin_thresh: float = 0.1,
         box_thresh: float = 0.1,
         dropout_prob: float = 0.1,
         pooling_size: int = 4,  # different from paper performs better on close text-rich images
@@ -334,12 +333,16 @@ def _fast(
 
     # Build the model
     model = FAST(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
-
-        # Build the model for reparameterization to access the layers
-        _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
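The ad-hoc dummy forward pass is replaced by a shared _build_model helper. A hedged sketch of what such a helper typically does for subclassed Keras models; the real implementation lives in doctr/models/utils/tensorflow.py and may differ:

import tensorflow as tf
from tensorflow.keras import Model

def build_model(model: Model) -> None:
    # One zero-tensor forward pass so every layer creates its variables;
    # a subclassed Keras model must be built before load_weights() can run
    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)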
doctr/models/detection/linknet/base.py

@@ -111,7 +111,7 @@ class LinkNetPostProcessor(DetectionPostProcessor):
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             # Check whether smallest enclosing bounding box is not too small
-            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                 continue
             # Compute objectness
             if self.assume_straight_pages:
@@ -138,10 +138,11 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             # compute relative box to get rid of img shape
             _box[:, 0] /= width
             _box[:, 1] /= height
-            boxes.append(_box)
+            # Add score to box as (0, score)
+            boxes.append(np.vstack([_box, np.array([0.0, score])]))
 
         if not self.assume_straight_pages:
-            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
 
doctr/models/detection/linknet/tensorflow.py

@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import Model, Sequential, layers
+from tensorflow.keras import Model, Sequential, layers, losses
 
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification import resnet18, resnet34, resnet50
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from .base import LinkNetPostProcessor, _LinkNet
@@ -27,19 +32,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
     },
     "linknet_resnet34": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
     },
     "linknet_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
     },
 }
 
@@ -80,17 +85,17 @@ class LinkNetFPN(Model, NestedObject):
             for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
         ]
 
-    def call(self, x: List[tf.Tensor]) -> tf.Tensor:
+    def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
         out = 0
         for decoder, fmap in zip(self.decoders, x[::-1]):
-            out = decoder(out + fmap)
+            out = decoder(out + fmap, **kwargs)
         return out
 
     def extra_repr(self) -> str:
         return f"out_chans={self.out_chans}"
 
 
-class LinkNet(_LinkNet, keras.Model):
+class LinkNet(_LinkNet, Model):
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
@@ -187,7 +192,7 @@ class LinkNet(_LinkNet, Model):
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
         seg_mask = tf.cast(seg_mask, tf.float32)
 
-        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
         proba_map = tf.sigmoid(out_map)
 
         # Focal loss
@@ -275,9 +280,16 @@ def _linknet(
 
     # Build the model
     model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
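Both the DBNet and LinkNet losses now call losses.binary_crossentropy from the module-level import. For context, a minimal sketch of the focal weighting these hunks feed into; the gamma value and the final reduction are illustrative, not taken from the diff:

import tensorflow as tf
from tensorflow.keras import losses

def focal_bce(seg_target: tf.Tensor, out_map: tf.Tensor, gamma: float = 2.0) -> tf.Tensor:
    # Per-pixel BCE on logits, as in the hunks above
    bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
    # Convert logits to prob, compute gamma factor
    prob_map = tf.sigmoid(out_map)
    p_t = seg_target * prob_map + (1 - seg_target) * (1 - prob_map)
    # Down-weight easy pixels so the loss focuses on hard ones
    return tf.reduce_mean(bce_loss * (1 - p_t) ** gamma)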
doctr/models/detection/predictor/pytorch.py

@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from torch import nn
 
+from doctr.models.detection._utils import _remove_padding
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.utils import set_device_and_dtype
 
@@ -40,6 +41,11 @@ class DetectionPredictor(nn.Module):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -52,7 +58,15 @@ class DetectionPredictor(nn.Module):
         predicted_batches = [
             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
         ]
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,  # type: ignore[arg-type]
+            [pred for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [
                 pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
doctr/models/detection/predictor/tensorflow.py

@@ -7,8 +7,9 @@ from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
+from tensorflow.keras import Model
 
+from doctr.models.detection._utils import _remove_padding
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject
 
@@ -29,7 +30,7 @@ class DetectionPredictor(NestedObject):
     def __init__(
         self,
         pre_processor: PreProcessor,
-        model: keras.Model,
+        model: Model,
     ) -> None:
         self.pre_processor = pre_processor
         self.model = model
@@ -40,6 +41,11 @@ class DetectionPredictor(NestedObject):
         return_maps: bool = False,
         **kwargs: Any,
     ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+        # Extract parameters from the preprocessor
+        preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+        symmetric_pad = self.pre_processor.resize.symmetric_pad
+        assume_straight_pages = self.model.assume_straight_pages
+
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
@@ -50,7 +56,15 @@ class DetectionPredictor(NestedObject):
             for batch in processed_batches
         ]
 
-        preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+        # Remove padding from loc predictions
+        preds = _remove_padding(
+            pages,
+            [pred for batch in predicted_batches for pred in batch["preds"]],
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            symmetric_pad=symmetric_pad,
+            assume_straight_pages=assume_straight_pages,
+        )
+
         if return_maps:
             seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
             return preds, seg_maps
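Both backends now route predictions through _remove_padding (added in doctr/models/detection/_utils/base.py, file 43 above), which maps box coordinates from the padded model input back onto the original page. A hedged toy version of the idea for straight boxes with right/bottom padding; the real helper also handles symmetric padding and rotated polygons, and unpad_straight_boxes is an illustrative name:

import numpy as np

def unpad_straight_boxes(loc_preds: np.ndarray, page_h: int, page_w: int) -> np.ndarray:
    # loc_preds: (N, 5) relative (xmin, ymin, xmax, ymax, score) boxes predicted
    # on a square, aspect-ratio-preserving resize padded at the right/bottom
    out = loc_preds.copy()
    if page_h > page_w:  # portrait page -> padding was added along x
        out[:, [0, 2]] = np.clip(out[:, [0, 2]] * page_h / page_w, 0, 1)
    elif page_w > page_h:  # landscape page -> padding was added along y
        out[:, [1, 3]] = np.clip(out[:, [1, 3]] * page_w / page_h, 0, 1)
    return out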
doctr/models/detection/zoo.py

@@ -8,6 +8,7 @@ from typing import Any, List
 from doctr.file_utils import is_tf_available, is_torch_available
 
 from .. import detection
+from ..detection.fast import reparameterize
 from ..preprocessor import PreProcessor
 from .predictor import DetectionPredictor
 
@@ -51,18 +52,22 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
             pretrained_backbone=kwargs.get("pretrained_backbone", True),
             assume_straight_pages=assume_straight_pages,
         )
+        # Reparameterize FAST models by default to lower inference latency and memory usage
+        if isinstance(_model, detection.FAST):
+            _model = reparameterize(_model)
     else:
         if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
             raise ValueError(f"unknown architecture: {type(arch)}")
 
         _model = arch
         _model.assume_straight_pages = assume_straight_pages
+        _model.postprocessor.assume_straight_pages = assume_straight_pages
 
     kwargs.pop("pretrained_backbone", None)
 
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 1)
+    kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
         PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
         _model,
@@ -71,7 +76,7 @@
 
 
 def detection_predictor(
-    arch: Any = "db_resnet50",
+    arch: Any = "fast_base",
     pretrained: bool = False,
     assume_straight_pages: bool = True,
     **kwargs: Any,
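The default detection architecture therefore moves from db_resnet50 to fast_base, and pretrained FAST models are reparameterized automatically for faster inference. A quick sketch of the new default path (the zero image is just a placeholder input):

import numpy as np
from doctr.models import detection_predictor

predictor = detection_predictor(pretrained=True)  # arch now defaults to "fast_base"
out = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])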
doctr/models/factory/hub.py

@@ -20,7 +20,6 @@ from huggingface_hub import (
     get_token_permission,
     hf_hub_download,
     login,
-    snapshot_download,
 )
 
 from doctr import models
@@ -33,10 +32,9 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config
 
 
 AVAILABLE_ARCHS = {
-    "classification": models.classification.zoo.ARCHS,
+    "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
     "detection": models.detection.zoo.ARCHS,
     "recognition": models.recognition.zoo.ARCHS,
-    "obj_detection": ["fasterrcnn_mobilenet_v3_large_fpn"] if is_torch_available() else None,
 }
 
 
@@ -75,7 +73,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
         weights_path = save_directory / "pytorch_model.bin"
         torch.save(model.state_dict(), weights_path)
     elif is_tf_available():
-        weights_path = save_directory / "tf_model" / "weights"
+        weights_path = save_directory / "tf_model.weights.h5"
         model.save_weights(str(weights_path))
 
     config_path = save_directory / "config.json"
@@ -110,8 +108,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
 
     if run_config is None and arch is None:
         raise ValueError("run_config or arch must be specified")
-    if task not in ["classification", "detection", "recognition", "obj_detection"]:
-        raise ValueError("task must be one of classification, detection, recognition, obj_detection")
+    if task not in ["classification", "detection", "recognition"]:
+        raise ValueError("task must be one of classification, detection, recognition")
 
     # default readme
     readme = textwrap.dedent(
@@ -165,7 +163,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
         \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
     )
 
-    if arch not in AVAILABLE_ARCHS[task]:  # type: ignore
+    if arch not in AVAILABLE_ARCHS[task]:
         raise ValueError(
             f"Architecture: {arch} for task: {task} not found.\
             \nAvailable architectures: {AVAILABLE_ARCHS}"
@@ -175,7 +173,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
 
     local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
     repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
-    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
+    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
 
     with repo.commit(commit_message):
         _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
@@ -217,14 +215,6 @@ def from_hub(repo_id: str, **kwargs: Any):
         model = models.detection.__dict__[arch](pretrained=False)
     elif task == "recognition":
         model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
-    elif task == "obj_detection" and is_torch_available():
-        model = models.obj_detection.__dict__[arch](
-            pretrained=False,
-            image_mean=cfg["mean"],
-            image_std=cfg["std"],
-            max_size=cfg["input_shape"][-1],
-            num_classes=len(cfg["classes"]),
-        )
 
     # update model cfg
     model.cfg = cfg
@@ -234,7 +224,7 @@ def from_hub(repo_id: str, **kwargs: Any):
         state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
         model.load_state_dict(state_dict)
     else:  # tf
-        repo_path = snapshot_download(repo_id, **kwargs)
-        model.load_weights(os.path.join(repo_path, "tf_model", "weights"))
+        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
+        model.load_weights(weights)
 
     return model
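TF weights on the Hub are now a single tf_model.weights.h5 file fetched via hf_hub_download instead of a snapshot of the whole repo, so the loading interface is unchanged for users. A brief sketch ("mindee/my-model" is a placeholder repo id):

from doctr.models import from_hub

model = from_hub("mindee/my-model")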
doctr/models/kie_predictor/base.py

@@ -7,7 +7,7 @@ from typing import Any, Optional
 
 from doctr.models.builder import KIEDocumentBuilder
 
-from ..classification.predictor import CropOrientationPredictor
+from ..classification.predictor import OrientationPredictor
 from ..predictor.base import _OCRPredictor
 
 __all__ = ["_KIEPredictor"]
@@ -25,10 +25,13 @@ class _KIEPredictor(_OCRPredictor):
         accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
+        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+            page. Doing so will slightly deteriorate the overall latency.
         kwargs: keyword args of `DocumentBuilder`
     """
 
-    crop_orientation_predictor: Optional[CropOrientationPredictor]
+    crop_orientation_predictor: Optional[OrientationPredictor]
+    page_orientation_predictor: Optional[OrientationPredictor]
 
     def __init__(
         self,
@@ -36,8 +39,15 @@ class _KIEPredictor(_OCRPredictor):
         straighten_pages: bool = False,
         preserve_aspect_ratio: bool = True,
         symmetric_pad: bool = True,
+        detect_orientation: bool = False,
         **kwargs: Any,
     ) -> None:
-        super().__init__(assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs)
+        super().__init__(
+            assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
+        )
+
+        # Remove the following arguments from kwargs after initialization of the parent class
+        kwargs.pop("disable_page_orientation", None)
+        kwargs.pop("disable_crop_orientation", None)
 
         self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
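Page orientation detection is now threaded through the KIE predictor alongside the crop-level predictor. A minimal sketch, assuming the public kie_predictor factory forwards detect_orientation as this base class suggests:

from doctr.models import kie_predictor

# Adds the estimated page orientation to each page's predictions,
# at a small latency cost
predictor = kie_predictor(pretrained=True, detect_orientation=True)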