python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/predictor/pytorch.py
@@ -10,10 +10,10 @@ import torch
  from torch import nn

  from doctr.io.elements import Document
- from doctr.models._utils import estimate_orientation, get_language
+ from doctr.models._utils import get_language
  from doctr.models.detection.predictor import DetectionPredictor
  from doctr.models.recognition.predictor import RecognitionPredictor
- from doctr.utils.geometry import rotate_image
+ from doctr.utils.geometry import detach_scores

  from .base import _OCRPredictor

@@ -55,7 +55,13 @@ class OCRPredictor(nn.Module, _OCRPredictor):
  self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
  self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
  _OCRPredictor.__init__(
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ self,
+ assume_straight_pages,
+ straighten_pages,
+ preserve_aspect_ratio,
+ symmetric_pad,
+ detect_orientation,
+ **kwargs,
  )
  self.detect_orientation = detect_orientation
  self.detect_language = detect_language
@@ -81,19 +87,19 @@ class OCRPredictor(nn.Module, _OCRPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
  orientations = [
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
  else:
  orientations = None
+ general_pages_orientations = None
+ origin_pages_orientations = None
  if self.straighten_pages:
- origin_page_orientations = (
- origin_page_orientations
- if self.detect_orientation
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
- )
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
+ # update page shapes after straightening
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
  # Forward again to get predictions on straight pages
  loc_preds = self.det_predictor(pages, **kwargs)

@@ -102,30 +108,37 @@ class OCRPredictor(nn.Module, _OCRPredictor):
  ), "Detection Model in ocr_predictor should output only one class"

  loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
+ # Detach objectness scores from loc_preds
+ loc_preds, objectness_scores = detach_scores(loc_preds)
  # Check whether crop mode should be switched to channels first
  channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

- # Rectify crops if aspect ratio
- loc_preds = self._remove_padding(pages, loc_preds)
-
  # Apply hooks to loc_preds if any
  for hook in self.hooks:
  loc_preds = hook(loc_preds)

  # Crop images
  crops, loc_preds = self._prepare_crops(
- pages,
+ pages, # type: ignore[arg-type]
  loc_preds,
  channels_last=channels_last,
  assume_straight_pages=self.assume_straight_pages,
+ assume_horizontal=self._page_orientation_disabled,
  )
- # Rectify crop orientation
+ # Rectify crop orientation and get crop orientation predictions
+ crop_orientations: Any = []
  if not self.assume_straight_pages:
- crops, loc_preds = self._rectify_crops(crops, loc_preds)
+ crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+ crop_orientations = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+ ]
+
  # Identify character sequences
  word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+ if not crop_orientations:
+ crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

- boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+ boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)

  if self.detect_language:
  languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
@@ -134,10 +147,12 @@ class OCRPredictor(nn.Module, _OCRPredictor):
  languages_dict = None

  out = self.doc_builder(
- pages,
+ pages, # type: ignore[arg-type]
  boxes,
+ objectness_scores,
  text_preds,
- origin_page_shapes,
+ origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations,
  orientations,
  languages_dict,
  )
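
Note: the predictor hunks above (and their TensorFlow mirror below) change what the predictor attaches to each detected word. A minimal usage sketch, assuming the 0.10.0 public API and that the exported word dictionaries expose the new fields under the keys "objectness_score" and "crop_orientation" (those key names are an assumption, not shown in this diff):

    import numpy as np

    from doctr.models import ocr_predictor

    # A blank page stands in for real input here
    pages = [np.zeros((1024, 768, 3), dtype=np.uint8)]

    predictor = ocr_predictor(pretrained=True, assume_straight_pages=False, detect_orientation=True)
    result = predictor(pages)

    exported = result.export()
    for block in exported["pages"][0]["blocks"]:
        for line in block["lines"]:
            for word in line["words"]:
                # assumed export keys for the objectness_scores / crop_orientations
                # that the forward pass now hands to doc_builder
                print(word["value"], word.get("objectness_score"), word.get("crop_orientation"))
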
doctr/models/predictor/tensorflow.py
@@ -9,10 +9,10 @@ import numpy as np
  import tensorflow as tf

  from doctr.io.elements import Document
- from doctr.models._utils import estimate_orientation, get_language
+ from doctr.models._utils import get_language
  from doctr.models.detection.predictor import DetectionPredictor
  from doctr.models.recognition.predictor import RecognitionPredictor
- from doctr.utils.geometry import rotate_image
+ from doctr.utils.geometry import detach_scores
  from doctr.utils.repr import NestedObject

  from .base import _OCRPredictor
@@ -56,7 +56,13 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  self.det_predictor = det_predictor
  self.reco_predictor = reco_predictor
  _OCRPredictor.__init__(
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ self,
+ assume_straight_pages,
+ straighten_pages,
+ preserve_aspect_ratio,
+ symmetric_pad,
+ detect_orientation,
+ **kwargs,
  )
  self.detect_orientation = detect_orientation
  self.detect_language = detect_language
@@ -81,19 +87,19 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  for out_map in out_maps
  ]
  if self.detect_orientation:
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
  orientations = [
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
  ]
  else:
  orientations = None
+ general_pages_orientations = None
+ origin_pages_orientations = None
  if self.straighten_pages:
- origin_page_orientations = (
- origin_page_orientations
- if self.detect_orientation
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
- )
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+ # update page shapes after straightening
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
  # forward again to get predictions on straight pages
  loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]

@@ -101,9 +107,8 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  len(loc_pred) == 1 for loc_pred in loc_preds_dict
  ), "Detection Model in ocr_predictor should output only one class"
  loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]
-
- # Rectify crops if aspect ratio
- loc_preds = self._remove_padding(pages, loc_preds)
+ # Detach objectness scores from loc_preds
+ loc_preds, objectness_scores = detach_scores(loc_preds)

  # Apply hooks to loc_preds if any
  for hook in self.hooks:
@@ -111,16 +116,26 @@ class OCRPredictor(NestedObject, _OCRPredictor):

  # Crop images
  crops, loc_preds = self._prepare_crops(
- pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
+ pages,
+ loc_preds,
+ channels_last=True,
+ assume_straight_pages=self.assume_straight_pages,
+ assume_horizontal=self._page_orientation_disabled,
  )
- # Rectify crop orientation
+ # Rectify crop orientation and get crop orientation predictions
+ crop_orientations: Any = []
  if not self.assume_straight_pages:
- crops, loc_preds = self._rectify_crops(crops, loc_preds)
+ crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+ crop_orientations = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+ ]

  # Identify character sequences
  word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+ if not crop_orientations:
+ crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

- boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+ boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)

  if self.detect_language:
  languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
@@ -131,8 +146,10 @@ class OCRPredictor(NestedObject, _OCRPredictor):
  out = self.doc_builder(
  pages,
  boxes,
+ objectness_scores,
  text_preds,
  origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations,
  orientations,
  languages_dict,
  )
doctr/models/preprocessor/pytorch.py
@@ -79,7 +79,7 @@ class PreProcessor(nn.Module):
  else:
  x = x.to(dtype=torch.float32) # type: ignore[union-attr]

- return x # type: ignore[return-value]
+ return x

  def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
  """Prepare document data for model forwarding
@@ -103,7 +103,7 @@ class PreProcessor(nn.Module):
  elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
  raise TypeError("unsupported data type for torch.Tensor")
  # Resizing
- if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
+ if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
  x = F.resize(
  x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
  )
@@ -118,11 +118,11 @@ class PreProcessor(nn.Module):
  # Sample transform (to tensor, resize)
  samples = list(multithread_exec(self.sample_transforms, x))
  # Batching
- batches = self.batch_inputs(samples) # type: ignore[assignment]
+ batches = self.batch_inputs(samples)
  else:
  raise TypeError(f"invalid input type: {type(x)}")

  # Batch transforms (normalize)
  batches = list(multithread_exec(self.normalize, batches))

- return batches # type: ignore[return-value]
+ return batches
doctr/models/preprocessor/tensorflow.py
@@ -41,6 +41,7 @@ class PreProcessor(NestedObject):
  self.resize = Resize(output_size, **kwargs)
  # Perform the division by 255 at the same time
  self.normalize = Normalize(mean, std)
+ self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []

  def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
  """Gather samples into batches for inference purposes
@@ -113,13 +114,13 @@ class PreProcessor(NestedObject):

  elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
  # Sample transform (to tensor, resize)
- samples = list(multithread_exec(self.sample_transforms, x))
+ samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
  # Batching
  batches = self.batch_inputs(samples)
  else:
  raise TypeError(f"invalid input type: {type(x)}")

  # Batch transforms (normalize)
- batches = list(multithread_exec(self.normalize, batches))
+ batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))

  return batches
doctr/models/recognition/crnn/tensorflow.py
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
  from doctr.datasets import VOCABS

  from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -24,21 +24,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["legacy_french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
  },
  "crnn_mobilenet_v3_small": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
  },
  "crnn_mobilenet_v3_large": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
  },
  }

@@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model):

  def __init__(
  self,
- feature_extractor: tf.keras.Model,
+ feature_extractor: Model,
  vocab: str,
  rnn_units: int = 128,
  exportable: bool = False,
@@ -245,9 +245,11 @@ def _crnn(

  # Build the model
  model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+ _build_model(model)
  # Load pretrained parameters
  if pretrained:
- load_pretrained_params(model, _cfg["url"])
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+ load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])

  return model

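The _crnn hunk above only skips mismatching layers when the requested vocab differs from the one the pretrained weights were published for. A minimal fine-tuning sketch under that assumption (TensorFlow backend installed; the vocab choice is purely illustrative):

    from doctr.datasets import VOCABS
    from doctr.models import crnn_vgg16_bn

    # The published crnn_vgg16_bn weights use the "legacy_french" vocab (see default_cfgs above).
    # Requesting another vocab should make load_pretrained_params run with skip_mismatch=True, so the
    # classification head with the wrong output size is left freshly initialized for fine tuning.
    model = crnn_vgg16_bn(pretrained=True, vocab=VOCABS["german"])
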
doctr/models/recognition/master/pytorch.py
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
  # [True, True, True, ..., False, False, False] -> False is masked
  # (N, 1, 1, max_length)
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
  target_length = target.size(1)
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -142,7 +142,7 @@ class MASTER(_MASTER, nn.Module):
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token (sos disappear in shift!)
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
doctr/models/recognition/master/tensorflow.py
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
  from doctr.models.classification import magc_resnet31
  from doctr.models.modules.transformer import Decoder, PositionalEncoding

- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
  from .base import _MASTER, _MASTERPostProcessor

  __all__ = ["MASTER", "master"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
  },
  }

@@ -51,7 +51,7 @@ class MASTER(_MASTER, Model):

  def __init__(
  self,
- feature_extractor: tf.keras.Model,
+ feature_extractor: Model,
  vocab: str,
  d_model: int = 512,
  dff: int = 2048,
@@ -290,9 +290,14 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
  cfg=_cfg,
  **kwargs,
  )
+ _build_model(model)
+
  # Load pretrained parameters
  if pretrained:
- load_pretrained_params(model, default_cfgs[arch]["url"])
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+ load_pretrained_params(
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+ )

  return model

doctr/models/recognition/parseq/pytorch.py
@@ -212,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):

  sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
  eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
- combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int() # type: ignore
+ combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
  if len(combined) > 1:
  combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
  return combined
@@ -282,7 +282,8 @@ class PARSeq(_PARSeq, nn.Module):
  ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

  # Stop decoding if all sequences have reached the EOS token
- if max_len is None and (ys == self.vocab_size).any(dim=-1).all(): # type: ignore[attr-defined]
+ # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
+ if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
  break

  logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
@@ -297,7 +298,7 @@ class PARSeq(_PARSeq, nn.Module):

  # Create padding mask for refined target input maskes all behind EOS token as False
  # (N, 1, 1, max_length)
- target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
  mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
  logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))

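The decoding hunk above keeps iterating instead of breaking early when exportable is set, since `break` does not trace cleanly to ONNX. A minimal export sketch for the PyTorch model, assuming doctr's export_model_to_onnx helper keeps this signature (it may differ between versions):

    import torch

    from doctr.models import parseq
    from doctr.models.utils import export_model_to_onnx  # helper name assumed from doctr's utils

    model = parseq(pretrained=True, exportable=True).eval()
    # Dummy batch matching the (32, 128, 3) recognition input_shape declared in default_cfgs
    dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)
    onnx_path = export_model_to_onnx(model, model_name="parseq", dummy_input=dummy_input)
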
doctr/models/recognition/parseq/tensorflow.py
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
  from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

  from ...classification import vit_s
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
  from .base import _PARSeq, _PARSeqPostProcessor

  __all__ = ["PARSeq", "parseq"]
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
  },
  }

@@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer):

  def __init__(self, vocab_size: int, d_model: int):
  super(CharEmbedding, self).__init__()
- self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
+ self.embedding = layers.Embedding(vocab_size, d_model)
  self.d_model = d_model

  def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
@@ -167,7 +167,6 @@ class PARSeq(_PARSeq, Model):

  self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

- @tf.function
  def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
  # Generates permutations of the target sequence.
  # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -214,7 +213,6 @@ class PARSeq(_PARSeq, Model):
  )
  return combined

- @tf.function
  def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  # Generate source and target mask for the decoder attention.
  sz = permutation.shape[0]
@@ -234,11 +232,10 @@ class PARSeq(_PARSeq, Model):
  target_mask = mask[1:, :-1]
  return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)

- @tf.function
  def decode(
  self,
  target: tf.Tensor,
- memory: tf,
+ memory: tf.Tensor,
  target_mask: Optional[tf.Tensor] = None,
  target_query: Optional[tf.Tensor] = None,
  **kwargs: Any,
@@ -288,10 +285,11 @@ class PARSeq(_PARSeq, Model):
  )

  # Stop decoding if all sequences have reached the EOS token
- # We need to check it on True to be compatible with ONNX
+ # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
  if (
- max_len is None
- and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True
+ not self.exportable
+ and max_len is None
+ and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
  ):
  break

@@ -475,9 +473,14 @@ def _parseq(

  # Build the model
  model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+ _build_model(model)
+
  # Load pretrained parameters
  if pretrained:
- load_pretrained_params(model, default_cfgs[arch]["url"])
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+ load_pretrained_params(
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+ )

  return model

doctr/models/recognition/sar/pytorch.py
@@ -125,25 +125,26 @@ class SARDecoder(nn.Module):
  if t == 0:
  # step to init the first states of the LSTMCell
  hidden_state_init = cell_state_init = torch.zeros(
- features.size(0), features.size(1), device=features.device
+ features.size(0), features.size(1), device=features.device, dtype=features.dtype
  )
  hidden_state, cell_state = hidden_state_init, cell_state_init
  prev_symbol = holistic
  elif t == 1:
  # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
  # (N, vocab_size + 1) --> (N, embedding_units)
- prev_symbol = torch.zeros(features.size(0), self.vocab_size + 1, device=features.device)
+ prev_symbol = torch.zeros(
+ features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
+ )
  prev_symbol = self.embed(prev_symbol)
  else:
- if gt is not None:
+ if gt is not None and self.training:
  # (N, embedding_units) -2 because of <bos> and <eos> (same)
  prev_symbol = self.embed(gt_embedding[:, t - 2])
  else:
  # -1 to start at timestep where prev_symbol was initialized
  index = logits_list[t - 1].argmax(-1)
  # update prev_symbol with ones at the index of the previous logit vector
- # (N, embedding_units)
- prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
+ prev_symbol = self.embed(self.embed_tgt(index))

  # (N, C), (N, C) take the last hidden state and cell state from current timestep
  hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
@@ -292,7 +293,7 @@ class SAR(nn.Module, RecognitionModel):
  # Input length : number of timesteps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss
  # (N, L, vocab_size + 1)
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
doctr/models/recognition/sar/tensorflow.py
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
  from doctr.utils.repr import NestedObject

  from ...classification import resnet31
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
  from ..core import RecognitionModel, RecognitionPostProcessor

  __all__ = ["SAR", "sar_resnet31"]
@@ -24,7 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  "std": (0.299, 0.296, 0.301),
  "input_shape": (32, 128, 3),
  "vocab": VOCABS["french"],
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0",
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
  },
  }

@@ -177,23 +177,17 @@ class SARDecoder(layers.Layer, NestedObject):
  elif t == 1:
  # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
  # (N, vocab_size + 1) --> (N, embedding_units)
- prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1])
+ prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
  prev_symbol = self.embed(prev_symbol, **kwargs)
  else:
- if gt is not None:
+ if gt is not None and kwargs.get("training", False):
  # (N, embedding_units) -2 because of <bos> and <eos> (same)
  prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
  else:
  # -1 to start at timestep where prev_symbol was initialized
  index = tf.argmax(logits_list[t - 1], axis=-1)
  # update prev_symbol with ones at the index of the previous logit vector
- # (N, embedding_units)
- index = tf.ones_like(index)
- prev_symbol = tf.scatter_nd(
- tf.expand_dims(index, axis=1),
- prev_symbol,
- tf.constant([features.shape[0], features.shape[-1]], dtype=tf.int64),
- )
+ prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)

  # (N, C), (N, C) take the last hidden state and cell state from current timestep
  _, states = self.lstm_cells(prev_symbol, states, **kwargs)
@@ -398,9 +392,13 @@ def _sar(

  # Build the model
  model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+ _build_model(model)
  # Load pretrained parameters
  if pretrained:
- load_pretrained_params(model, default_cfgs[arch]["url"])
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+ load_pretrained_params(
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+ )

  return model

doctr/models/recognition/vitstr/pytorch.py
@@ -137,7 +137,7 @@ class ViTSTR(_ViTSTR, nn.Module):
  # Input length : number of steps
  input_len = model_output.shape[1]
  # Add one for additional <eos> token (sos disappear in shift!)
- seq_len = seq_len + 1 # type: ignore[assignment]
+ seq_len = seq_len + 1
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
  # The "masked" first gt char is <sos>.
  cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")