python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (61)
  1. doctr/datasets/cord.py +10 -1
  2. doctr/datasets/funsd.py +11 -1
  3. doctr/datasets/ic03.py +11 -1
  4. doctr/datasets/ic13.py +10 -1
  5. doctr/datasets/iiit5k.py +26 -16
  6. doctr/datasets/imgur5k.py +10 -1
  7. doctr/datasets/sroie.py +11 -1
  8. doctr/datasets/svhn.py +11 -1
  9. doctr/datasets/svt.py +11 -1
  10. doctr/datasets/synthtext.py +11 -1
  11. doctr/datasets/utils.py +7 -2
  12. doctr/datasets/vocabs.py +6 -2
  13. doctr/datasets/wildreceipt.py +12 -1
  14. doctr/file_utils.py +19 -0
  15. doctr/io/elements.py +12 -4
  16. doctr/models/builder.py +2 -2
  17. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  18. doctr/models/classification/mobilenet/pytorch.py +2 -0
  19. doctr/models/classification/mobilenet/tensorflow.py +14 -8
  20. doctr/models/classification/predictor/pytorch.py +11 -7
  21. doctr/models/classification/predictor/tensorflow.py +10 -6
  22. doctr/models/classification/resnet/tensorflow.py +21 -8
  23. doctr/models/classification/textnet/tensorflow.py +11 -5
  24. doctr/models/classification/vgg/tensorflow.py +9 -3
  25. doctr/models/classification/vit/tensorflow.py +10 -4
  26. doctr/models/classification/zoo.py +22 -10
  27. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  28. doctr/models/detection/fast/tensorflow.py +14 -11
  29. doctr/models/detection/linknet/tensorflow.py +23 -11
  30. doctr/models/detection/predictor/tensorflow.py +2 -2
  31. doctr/models/factory/hub.py +5 -6
  32. doctr/models/kie_predictor/base.py +4 -0
  33. doctr/models/kie_predictor/pytorch.py +4 -0
  34. doctr/models/kie_predictor/tensorflow.py +8 -1
  35. doctr/models/modules/transformer/tensorflow.py +0 -2
  36. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  37. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  38. doctr/models/predictor/base.py +24 -12
  39. doctr/models/predictor/pytorch.py +4 -0
  40. doctr/models/predictor/tensorflow.py +8 -1
  41. doctr/models/preprocessor/tensorflow.py +1 -1
  42. doctr/models/recognition/crnn/tensorflow.py +8 -6
  43. doctr/models/recognition/master/tensorflow.py +9 -4
  44. doctr/models/recognition/parseq/tensorflow.py +10 -8
  45. doctr/models/recognition/sar/tensorflow.py +7 -3
  46. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  47. doctr/models/utils/pytorch.py +1 -1
  48. doctr/models/utils/tensorflow.py +15 -15
  49. doctr/transforms/functional/pytorch.py +1 -1
  50. doctr/transforms/modules/pytorch.py +7 -6
  51. doctr/transforms/modules/tensorflow.py +15 -12
  52. doctr/utils/geometry.py +106 -19
  53. doctr/utils/metrics.py +1 -1
  54. doctr/utils/reconstitution.py +151 -65
  55. doctr/version.py +1 -1
  56. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
  57. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
  58. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  59. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  60. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  61. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/classification/magc_resnet/tensorflow.py
@@ -9,12 +9,12 @@ from functools import partial
 from typing import Any, Dict, List, Optional, Tuple
 
 import tensorflow as tf
-from tensorflow.keras import layers
+from tensorflow.keras import activations, layers
 from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 from ..resnet.tensorflow import ResNet
 
 __all__ = ["magc_resnet31"]
@@ -26,7 +26,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
     },
 }
 
@@ -57,6 +57,7 @@ class MAGC(layers.Layer):
         self.headers = headers  # h
         self.inplanes = inplanes  # C
         self.attn_scale = attn_scale
+        self.ratio = ratio
         self.planes = int(inplanes * ratio)
 
         self.single_header_inplanes = int(inplanes / headers)  # C / h
@@ -97,7 +98,7 @@ class MAGC(layers.Layer):
         if self.attn_scale and self.headers > 1:
             context_mask = context_mask / math.sqrt(self.single_header_inplanes)
         # B*h, 1, H*W, 1
-        context_mask = tf.keras.activations.softmax(context_mask, axis=2)
+        context_mask = activations.softmax(context_mask, axis=2)
 
         # Compute context
         # B*h, 1, C/h, 1
@@ -114,7 +115,7 @@ class MAGC(layers.Layer):
         # Context modeling: B, H, W, C -> B, 1, 1, C
         context = self.context_modeling(inputs)
         # Transform: B, 1, 1, C -> B, 1, 1, C
-        transformed = self.transform(context)
+        transformed = self.transform(context, **kwargs)
         return inputs + transformed
 
 
@@ -151,9 +152,15 @@ def _magc_resnet(
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
    # Load pretrained parameters
    if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
    return model
 
doctr/models/classification/mobilenet/pytorch.py
@@ -9,12 +9,14 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from torchvision.models import mobilenetv3
+from torchvision.models.mobilenetv3 import MobileNetV3
 
 from doctr.datasets import VOCABS
 
 from ...utils import load_pretrained_params
 
 __all__ = [
+    "MobileNetV3",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
doctr/models/classification/mobilenet/tensorflow.py
@@ -13,7 +13,7 @@ from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
 
 from ....datasets import VOCABS
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = [
     "MobileNetV3",
@@ -32,42 +32,42 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
     },
     "mobilenet_v3_large_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
     },
     "mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
     },
     "mobilenet_v3_small_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
     },
     "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
     },
     "mobilenet_v3_small_page_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (512, 512, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
     },
 }
 
@@ -295,9 +295,15 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
doctr/models/classification/predictor/pytorch.py
@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -27,12 +27,12 @@ class OrientationPredictor(nn.Module):
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: nn.Module,
+        pre_processor: Optional[PreProcessor],
+        model: Optional[nn.Module],
     ) -> None:
         super().__init__()
-        self.pre_processor = pre_processor
-        self.model = model.eval()
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model.eval() if isinstance(model, nn.Module) else None
 
     @torch.inference_mode()
     def forward(
@@ -43,12 +43,16 @@ class OrientationPredictor(nn.Module):
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
         processed_batches = self.pre_processor(inputs)
         _params = next(self.model.parameters())
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
-        predicted_batches = [self.model(batch) for batch in processed_batches]
+        predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
         # confidence
         probs = [
             torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
@@ -57,7 +61,7 @@ class OrientationPredictor(nn.Module):
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
 
         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]  # type: ignore[union-attr]
         confs = [round(float(p), 2) for prob in probs for p in prob]
 
         return [class_idxs, classes, confs]
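With both the model and the pre-processor set to None, the predictor short-circuits and reports a neutral orientation for every input: class index 0, angle 0, confidence 1.0. A small sketch of that behaviour, assuming uint8 HWC crops as elsewhere in doctr:

    import numpy as np

    predictor = OrientationPredictor(None, None)  # disabled predictor
    crops = [np.zeros((64, 64, 3), dtype=np.uint8)] * 2
    class_idxs, angles, confs = predictor(crops)
    # -> [0, 0], [0, 0], [1.0, 1.0]: every crop is treated as upright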
doctr/models/classification/predictor/tensorflow.py
@@ -3,11 +3,11 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
+from tensorflow.keras import Model
 
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject
@@ -29,11 +29,11 @@ class OrientationPredictor(NestedObject):
 
     def __init__(
         self,
-        pre_processor: PreProcessor,
-        model: keras.Model,
+        pre_processor: Optional[PreProcessor],
+        model: Optional[Model],
     ) -> None:
-        self.pre_processor = pre_processor
-        self.model = model
+        self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
+        self.model = model if isinstance(model, Model) else None
 
     def __call__(
         self,
@@ -43,6 +43,10 @@ class OrientationPredictor(NestedObject):
         if any(input.ndim != 3 for input in inputs):
             raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")
 
+        if self.model is None or self.pre_processor is None:
+            # predictor is disabled
+            return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]
+
         processed_batches = self.pre_processor(inputs)
         predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
 
doctr/models/classification/resnet/tensorflow.py
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
 
@@ -24,35 +24,35 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
     },
     "resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
     },
     "resnet34": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
     },
     "resnet50": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
     },
     "resnet34_wide": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
     },
 }
 
@@ -210,9 +210,15 @@ def _resnet(
     model = ResNet(
         num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -354,10 +360,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     )
 
     model.cfg = _cfg
+    _build_model(model)
 
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs["resnet50"]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            default_cfgs["resnet50"]["url"],
+            skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
+        )
 
     return model
 
doctr/models/classification/textnet/tensorflow.py
@@ -12,7 +12,7 @@ from tensorflow.keras import Sequential, layers
 from doctr.datasets import VOCABS
 
 from ...modules.layers.tensorflow import FASTConvLayer
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
     },
 }
 
@@ -111,9 +111,15 @@ def _textnet(
 
     # Build the model
     model = TextNet(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
doctr/models/classification/vgg/tensorflow.py
@@ -11,7 +11,7 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["VGG", "vgg16_bn_r"]
 
@@ -22,7 +22,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (1.0, 1.0, 1.0),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
     },
 }
 
@@ -81,9 +81,15 @@ def _vgg(
 
     # Build the model
     model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
doctr/models/classification/vit/tensorflow.py
@@ -14,7 +14,7 @@ from doctr.models.modules.transformer import EncoderBlock
 from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
 from doctr.utils.repr import NestedObject
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
@@ -25,14 +25,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
     },
     "vit_b": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
     },
 }
 
@@ -121,9 +121,15 @@ def _vit(
 
     # Build the model
     model = VisionTransformer(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
doctr/models/classification/zoo.py
@@ -34,15 +34,27 @@ ARCHS: List[str] = [
 ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
-def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
-    if arch not in ORIENTATION_ARCHS:
-        raise ValueError(f"unknown architecture '{arch}'")
+def _orientation_predictor(
+    arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
+) -> OrientationPredictor:
+    if disabled:
+        # Case where the orientation predictor is disabled
+        return OrientationPredictor(None, None)
+
+    if isinstance(arch, str):
+        if arch not in ORIENTATION_ARCHS:
+            raise ValueError(f"unknown architecture '{arch}'")
+
+        # Load directly classifier from backbone
+        _model = classification.__dict__[arch](pretrained=pretrained)
+    else:
+        if not isinstance(arch, classification.MobileNetV3):
+            raise ValueError(f"unknown architecture: {type(arch)}")
+        _model = arch
 
-    # Load directly classifier from backbone
-    _model = classification.__dict__[arch](pretrained=pretrained)
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
+    kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
@@ -51,7 +63,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
 
 
 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
 ) -> OrientationPredictor:
     """Crop orientation classification architecture.
 
@@ -71,11 +83,11 @@ def crop_orientation_predictor(
     -------
     OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
 
 
 def page_orientation_predictor(
-    arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
+    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
 ) -> OrientationPredictor:
     """Page orientation classification architecture.
 
@@ -95,4 +107,4 @@ def page_orientation_predictor(
     -------
     OrientationPredictor
     """
-    return _orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
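Taken together, `crop_orientation_predictor` and `page_orientation_predictor` now accept an architecture name, a ready-made MobileNetV3 classifier, or `disabled=True` (forwarded through `**kwargs` to `_orientation_predictor`). A sketch of the two new call styles, assuming the usual re-exports from `doctr.models.classification`:

    from doctr.models import classification
    from doctr.models.classification import crop_orientation_predictor

    # Disable the predictor: yields OrientationPredictor(None, None),
    # i.e. the neutral outputs shown earlier.
    no_op = crop_orientation_predictor(disabled=True)

    # Pass a custom (e.g. fine-tuned) classifier instead of an arch name.
    custom = classification.mobilenet_v3_small_crop_orientation(pretrained=True)
    predictor = crop_orientation_predictor(arch=custom)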
doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+from tensorflow.keras import Model, Sequential, layers, losses
 from tensorflow.keras.applications import ResNet50
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from ...classification import mobilenet_v3_large
@@ -29,13 +34,13 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
     },
     "db_mobilenet_v3_large": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
     },
 }
 
@@ -81,7 +86,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         if dilation_factor > 1:
             _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
 
-        module = keras.Sequential(_layers)
+        module = Sequential(_layers)
 
         return module
 
@@ -104,7 +109,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
         return layers.concatenate(results)
 
 
-class DBNet(_DBNet, keras.Model, NestedObject):
+class DBNet(_DBNet, Model, NestedObject):
     """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
@@ -147,14 +152,14 @@ class DBNet(_DBNet, keras.Model, NestedObject):
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential([
+        self.probability_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
             layers.Activation("relu"),
             layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
         ])
-        self.threshold_head = keras.Sequential([
+        self.threshold_head = Sequential([
             *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
             layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
             layers.BatchNormalization(),
@@ -206,7 +211,7 @@ class DBNet(_DBNet, keras.Model, NestedObject):
 
         # Focal loss
         focal_scale = 10.0
-        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
 
         # Convert logits to prob, compute gamma factor
         p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
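The hunk above touches only the `bce_loss` line; per the context comment, the gamma factor is computed just below, outside the diff. For reference, a textbook focal-loss completion consistent with these names (an illustrative sketch, not a quote of the unchanged code):

    # Down-weight easy examples by (1 - p_t) ** gamma, then rescale by focal_scale.
    gamma = 2.0
    focal_loss = focal_scale * tf.reduce_mean((1.0 - p_t) ** gamma * bce_loss)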
@@ -305,9 +310,16 @@ def _db_resnet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
 
@@ -326,6 +338,10 @@ def _db_mobilenet(
     # Patch the config
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+    if not kwargs.get("class_names", None):
+        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+    else:
+        kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Feature extractor
     feat_extractor = IntermediateLayerGetter(
@@ -339,9 +355,15 @@ def _db_mobilenet(
 
     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, _cfg["url"])
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
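For the detection models the `skip_mismatch` test keys off `class_names` rather than `num_classes`, and user-supplied names are sorted first so that ordering is deterministic. A short usage sketch (the class names are illustrative):

    from doctr.models import detection

    # Default single class: the full checkpoint is restored.
    model = detection.db_mobilenet_v3_large(pretrained=True)

    # Custom classes: the head no longer matches, so its weights are skipped,
    # and sorting makes ["words", "artefacts"] equivalent to ["artefacts", "words"].
    model = detection.db_mobilenet_v3_large(pretrained=True, class_names=["artefacts", "words"])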