python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. doctr/datasets/cord.py +10 -1
  2. doctr/datasets/funsd.py +11 -1
  3. doctr/datasets/ic03.py +11 -1
  4. doctr/datasets/ic13.py +10 -1
  5. doctr/datasets/iiit5k.py +26 -16
  6. doctr/datasets/imgur5k.py +10 -1
  7. doctr/datasets/sroie.py +11 -1
  8. doctr/datasets/svhn.py +11 -1
  9. doctr/datasets/svt.py +11 -1
  10. doctr/datasets/synthtext.py +11 -1
  11. doctr/datasets/utils.py +7 -2
  12. doctr/datasets/vocabs.py +6 -2
  13. doctr/datasets/wildreceipt.py +12 -1
  14. doctr/file_utils.py +19 -0
  15. doctr/io/elements.py +12 -4
  16. doctr/models/builder.py +2 -2
  17. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  18. doctr/models/classification/mobilenet/pytorch.py +2 -0
  19. doctr/models/classification/mobilenet/tensorflow.py +14 -8
  20. doctr/models/classification/predictor/pytorch.py +11 -7
  21. doctr/models/classification/predictor/tensorflow.py +10 -6
  22. doctr/models/classification/resnet/tensorflow.py +21 -8
  23. doctr/models/classification/textnet/tensorflow.py +11 -5
  24. doctr/models/classification/vgg/tensorflow.py +9 -3
  25. doctr/models/classification/vit/tensorflow.py +10 -4
  26. doctr/models/classification/zoo.py +22 -10
  27. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  28. doctr/models/detection/fast/tensorflow.py +14 -11
  29. doctr/models/detection/linknet/tensorflow.py +23 -11
  30. doctr/models/detection/predictor/tensorflow.py +2 -2
  31. doctr/models/factory/hub.py +5 -6
  32. doctr/models/kie_predictor/base.py +4 -0
  33. doctr/models/kie_predictor/pytorch.py +4 -0
  34. doctr/models/kie_predictor/tensorflow.py +8 -1
  35. doctr/models/modules/transformer/tensorflow.py +0 -2
  36. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  37. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  38. doctr/models/predictor/base.py +24 -12
  39. doctr/models/predictor/pytorch.py +4 -0
  40. doctr/models/predictor/tensorflow.py +8 -1
  41. doctr/models/preprocessor/tensorflow.py +1 -1
  42. doctr/models/recognition/crnn/tensorflow.py +8 -6
  43. doctr/models/recognition/master/tensorflow.py +9 -4
  44. doctr/models/recognition/parseq/tensorflow.py +10 -8
  45. doctr/models/recognition/sar/tensorflow.py +7 -3
  46. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  47. doctr/models/utils/pytorch.py +1 -1
  48. doctr/models/utils/tensorflow.py +15 -15
  49. doctr/transforms/functional/pytorch.py +1 -1
  50. doctr/transforms/modules/pytorch.py +7 -6
  51. doctr/transforms/modules/tensorflow.py +15 -12
  52. doctr/utils/geometry.py +106 -19
  53. doctr/utils/metrics.py +1 -1
  54. doctr/utils/reconstitution.py +151 -65
  55. doctr/version.py +1 -1
  56. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
  57. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
  58. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  59. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  60. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  61. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
@@ -10,11 +10,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
13
- from tensorflow import keras
14
- from tensorflow.keras import Sequential, layers
13
+ from tensorflow.keras import Model, Sequential, layers
15
14
 
16
15
  from doctr.file_utils import CLASS_NAME
17
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
16
+ from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
18
17
  from doctr.utils.repr import NestedObject
19
18
 
20
19
  from ...classification import textnet_base, textnet_small, textnet_tiny
@@ -29,19 +28,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
29
28
  "input_shape": (1024, 1024, 3),
30
29
  "mean": (0.798, 0.785, 0.772),
31
30
  "std": (0.264, 0.2749, 0.287),
32
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0",
31
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
33
32
  },
34
33
  "fast_small": {
35
34
  "input_shape": (1024, 1024, 3),
36
35
  "mean": (0.798, 0.785, 0.772),
37
36
  "std": (0.264, 0.2749, 0.287),
38
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0",
37
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
39
38
  },
40
39
  "fast_base": {
41
40
  "input_shape": (1024, 1024, 3),
42
41
  "mean": (0.798, 0.785, 0.772),
43
42
  "std": (0.264, 0.2749, 0.287),
44
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0",
43
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
45
44
  },
46
45
  }
47
46
 
@@ -100,7 +99,7 @@ class FastHead(Sequential):
100
99
  super().__init__(_layers)
101
100
 
102
101
 
103
- class FAST(_FAST, keras.Model, NestedObject):
102
+ class FAST(_FAST, Model, NestedObject):
104
103
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
105
104
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
106
105
 
@@ -334,12 +333,16 @@ def _fast(
334
333
 
335
334
  # Build the model
336
335
  model = FAST(feat_extractor, cfg=_cfg, **kwargs)
336
+ _build_model(model)
337
+
337
338
  # Load pretrained parameters
338
339
  if pretrained:
339
- load_pretrained_params(model, _cfg["url"])
340
-
341
- # Build the model for reparameterization to access the layers
342
- _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
340
+ # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
341
+ load_pretrained_params(
342
+ model,
343
+ _cfg["url"],
344
+ skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
345
+ )
343
346
 
344
347
  return model
345
348
 
@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
13
- from tensorflow import keras
14
- from tensorflow.keras import Model, Sequential, layers
13
+ from tensorflow.keras import Model, Sequential, layers, losses
15
14
 
16
15
  from doctr.file_utils import CLASS_NAME
17
16
  from doctr.models.classification import resnet18, resnet34, resnet50
18
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
17
+ from doctr.models.utils import (
18
+ IntermediateLayerGetter,
19
+ _bf16_to_float32,
20
+ _build_model,
21
+ conv_sequence,
22
+ load_pretrained_params,
23
+ )
19
24
  from doctr.utils.repr import NestedObject
20
25
 
21
26
  from .base import LinkNetPostProcessor, _LinkNet
@@ -27,19 +32,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
27
32
  "mean": (0.798, 0.785, 0.772),
28
33
  "std": (0.264, 0.2749, 0.287),
29
34
  "input_shape": (1024, 1024, 3),
30
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
35
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
31
36
  },
32
37
  "linknet_resnet34": {
33
38
  "mean": (0.798, 0.785, 0.772),
34
39
  "std": (0.264, 0.2749, 0.287),
35
40
  "input_shape": (1024, 1024, 3),
36
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
41
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
37
42
  },
38
43
  "linknet_resnet50": {
39
44
  "mean": (0.798, 0.785, 0.772),
40
45
  "std": (0.264, 0.2749, 0.287),
41
46
  "input_shape": (1024, 1024, 3),
42
- "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
47
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
43
48
  },
44
49
  }
45
50
 
@@ -80,17 +85,17 @@ class LinkNetFPN(Model, NestedObject):
80
85
  for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
81
86
  ]
82
87
 
83
- def call(self, x: List[tf.Tensor]) -> tf.Tensor:
88
+ def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
84
89
  out = 0
85
90
  for decoder, fmap in zip(self.decoders, x[::-1]):
86
- out = decoder(out + fmap)
91
+ out = decoder(out + fmap, **kwargs)
87
92
  return out
88
93
 
89
94
  def extra_repr(self) -> str:
90
95
  return f"out_chans={self.out_chans}"
91
96
 
92
97
 
93
- class LinkNet(_LinkNet, keras.Model):
98
+ class LinkNet(_LinkNet, Model):
94
99
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
95
100
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
96
101
 
@@ -187,7 +192,7 @@ class LinkNet(_LinkNet, keras.Model):
187
192
  seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
188
193
  seg_mask = tf.cast(seg_mask, tf.float32)
189
194
 
190
- bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
195
+ bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
191
196
  proba_map = tf.sigmoid(out_map)
192
197
 
193
198
  # Focal loss
@@ -275,9 +280,16 @@ def _linknet(
275
280
 
276
281
  # Build the model
277
282
  model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
283
+ _build_model(model)
284
+
278
285
  # Load pretrained parameters
279
286
  if pretrained:
280
- load_pretrained_params(model, _cfg["url"])
287
+ # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
288
+ load_pretrained_params(
289
+ model,
290
+ _cfg["url"],
291
+ skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
292
+ )
281
293
 
282
294
  return model
283
295
 
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Union
7
7
 
8
8
  import numpy as np
9
9
  import tensorflow as tf
10
- from tensorflow import keras
10
+ from tensorflow.keras import Model
11
11
 
12
12
  from doctr.models.detection._utils import _remove_padding
13
13
  from doctr.models.preprocessor import PreProcessor
@@ -30,7 +30,7 @@ class DetectionPredictor(NestedObject):
30
30
  def __init__(
31
31
  self,
32
32
  pre_processor: PreProcessor,
33
- model: keras.Model,
33
+ model: Model,
34
34
  ) -> None:
35
35
  self.pre_processor = pre_processor
36
36
  self.model = model
@@ -20,7 +20,6 @@ from huggingface_hub import (
20
20
  get_token_permission,
21
21
  hf_hub_download,
22
22
  login,
23
- snapshot_download,
24
23
  )
25
24
 
26
25
  from doctr import models
@@ -33,7 +32,7 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config
33
32
 
34
33
 
35
34
  AVAILABLE_ARCHS = {
36
- "classification": models.classification.zoo.ARCHS,
35
+ "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
37
36
  "detection": models.detection.zoo.ARCHS,
38
37
  "recognition": models.recognition.zoo.ARCHS,
39
38
  }
@@ -74,7 +73,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
74
73
  weights_path = save_directory / "pytorch_model.bin"
75
74
  torch.save(model.state_dict(), weights_path)
76
75
  elif is_tf_available():
77
- weights_path = save_directory / "tf_model" / "weights"
76
+ weights_path = save_directory / "tf_model.weights.h5"
78
77
  model.save_weights(str(weights_path))
79
78
 
80
79
  config_path = save_directory / "config.json"
@@ -174,7 +173,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
174
173
 
175
174
  local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
176
175
  repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
177
- repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
176
+ repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
178
177
 
179
178
  with repo.commit(commit_message):
180
179
  _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
@@ -225,7 +224,7 @@ def from_hub(repo_id: str, **kwargs: Any):
225
224
  state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
226
225
  model.load_state_dict(state_dict)
227
226
  else: # tf
228
- repo_path = snapshot_download(repo_id, **kwargs)
229
- model.load_weights(os.path.join(repo_path, "tf_model", "weights"))
227
+ weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
228
+ model.load_weights(weights)
230
229
 
231
230
  return model
@@ -46,4 +46,8 @@ class _KIEPredictor(_OCRPredictor):
46
46
  assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
47
47
  )
48
48
 
49
+ # Remove the following arguments from kwargs after initialization of the parent class
50
+ kwargs.pop("disable_page_orientation", None)
51
+ kwargs.pop("disable_crop_orientation", None)
52
+
49
53
  self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
@@ -99,6 +99,9 @@ class KIEPredictor(nn.Module, _KIEPredictor):
99
99
  origin_pages_orientations = None
100
100
  if self.straighten_pages:
101
101
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
102
+ # update page shapes after straightening
103
+ origin_page_shapes = [page.shape[:2] for page in pages]
104
+
102
105
  # Forward again to get predictions on straight pages
103
106
  loc_preds = self.det_predictor(pages, **kwargs)
104
107
 
@@ -126,6 +129,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
126
129
  dict_loc_preds[class_name],
127
130
  channels_last=channels_last,
128
131
  assume_straight_pages=self.assume_straight_pages,
132
+ assume_horizontal=self._page_orientation_disabled,
129
133
  )
130
134
  # Rectify crop orientation
131
135
  crop_orientations: Any = {}
@@ -99,6 +99,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
99
99
  origin_pages_orientations = None
100
100
  if self.straighten_pages:
101
101
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
102
+ # update page shapes after straightening
103
+ origin_page_shapes = [page.shape[:2] for page in pages]
104
+
102
105
  # Forward again to get predictions on straight pages
103
106
  loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
104
107
 
@@ -119,7 +122,11 @@ class KIEPredictor(NestedObject, _KIEPredictor):
119
122
  crops = {}
120
123
  for class_name in dict_loc_preds.keys():
121
124
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
122
- pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
125
+ pages,
126
+ dict_loc_preds[class_name],
127
+ channels_last=True,
128
+ assume_straight_pages=self.assume_straight_pages,
129
+ assume_horizontal=self._page_orientation_disabled,
123
130
  )
124
131
 
125
132
  # Rectify crop orientation
@@ -13,8 +13,6 @@ from doctr.utils.repr import NestedObject
13
13
 
14
14
  __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
15
15
 
16
- tf.config.run_functions_eagerly(True)
17
-
18
16
 
19
17
  class PositionalEncoding(layers.Layer, NestedObject):
20
18
  """Compute positional encoding"""
@@ -20,7 +20,7 @@ class PatchEmbedding(nn.Module):
20
20
  channels, height, width = input_shape
21
21
  self.patch_size = patch_size
22
22
  self.interpolate = True if patch_size[0] == patch_size[1] else False
23
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
23
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
24
24
  self.num_patches = self.grid_size[0] * self.grid_size[1]
25
25
 
26
26
  self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
@@ -22,7 +22,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
22
22
  height, width, _ = input_shape
23
23
  self.patch_size = patch_size
24
24
  self.interpolate = True if patch_size[0] == patch_size[1] else False
25
- self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
25
+ self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
26
26
  self.num_patches = self.grid_size[0] * self.grid_size[1]
27
27
 
28
28
  self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
8
8
  import numpy as np
9
9
 
10
10
  from doctr.models.builder import DocumentBuilder
11
- from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
11
+ from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
12
12
 
13
13
  from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
14
14
  from ..classification import crop_orientation_predictor, page_orientation_predictor
@@ -48,9 +48,15 @@ class _OCRPredictor:
48
48
  ) -> None:
49
49
  self.assume_straight_pages = assume_straight_pages
50
50
  self.straighten_pages = straighten_pages
51
- self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
51
+ self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
52
+ self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
53
+ self.crop_orientation_predictor = (
54
+ None
55
+ if assume_straight_pages
56
+ else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
57
+ )
52
58
  self.page_orientation_predictor = (
53
- page_orientation_predictor(pretrained=True)
59
+ page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
54
60
  if detect_orientation or straighten_pages or not assume_straight_pages
55
61
  else None
56
62
  )
@@ -101,8 +107,8 @@ class _OCRPredictor:
101
107
  ]
102
108
  )
103
109
  return [
104
- # We exapnd if the page is wider than tall and the angle is 90 or -90
105
- rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
110
+ # expand if height and width are not equal, then remove the padding
111
+ remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
106
112
  for page, angle in zip(pages, origin_pages_orientations)
107
113
  ]
108
114
 
@@ -112,13 +118,18 @@ class _OCRPredictor:
112
118
  loc_preds: List[np.ndarray],
113
119
  channels_last: bool,
114
120
  assume_straight_pages: bool = False,
121
+ assume_horizontal: bool = False,
115
122
  ) -> List[List[np.ndarray]]:
116
- extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
117
-
118
- crops = [
119
- extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
120
- for page, _boxes in zip(pages, loc_preds)
121
- ]
123
+ if assume_straight_pages:
124
+ crops = [
125
+ extract_crops(page, _boxes[:, :4], channels_last=channels_last)
126
+ for page, _boxes in zip(pages, loc_preds)
127
+ ]
128
+ else:
129
+ crops = [
130
+ extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
131
+ for page, _boxes in zip(pages, loc_preds)
132
+ ]
122
133
  return crops
123
134
 
124
135
  @staticmethod
@@ -127,8 +138,9 @@ class _OCRPredictor:
127
138
  loc_preds: List[np.ndarray],
128
139
  channels_last: bool,
129
140
  assume_straight_pages: bool = False,
141
+ assume_horizontal: bool = False,
130
142
  ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
131
- crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
143
+ crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
132
144
 
133
145
  # Avoid sending zero-sized crops
134
146
  is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -97,6 +97,9 @@ class OCRPredictor(nn.Module, _OCRPredictor):
97
97
  origin_pages_orientations = None
98
98
  if self.straighten_pages:
99
99
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
100
+ # update page shapes after straightening
101
+ origin_page_shapes = [page.shape[:2] for page in pages]
102
+
100
103
  # Forward again to get predictions on straight pages
101
104
  loc_preds = self.det_predictor(pages, **kwargs)
102
105
 
@@ -120,6 +123,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
120
123
  loc_preds,
121
124
  channels_last=channels_last,
122
125
  assume_straight_pages=self.assume_straight_pages,
126
+ assume_horizontal=self._page_orientation_disabled,
123
127
  )
124
128
  # Rectify crop orientation and get crop orientation predictions
125
129
  crop_orientations: Any = []
@@ -97,6 +97,9 @@ class OCRPredictor(NestedObject, _OCRPredictor):
97
97
  origin_pages_orientations = None
98
98
  if self.straighten_pages:
99
99
  pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
100
+ # update page shapes after straightening
101
+ origin_page_shapes = [page.shape[:2] for page in pages]
102
+
100
103
  # forward again to get predictions on straight pages
101
104
  loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
102
105
 
@@ -113,7 +116,11 @@ class OCRPredictor(NestedObject, _OCRPredictor):
113
116
 
114
117
  # Crop images
115
118
  crops, loc_preds = self._prepare_crops(
116
- pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
119
+ pages,
120
+ loc_preds,
121
+ channels_last=True,
122
+ assume_straight_pages=self.assume_straight_pages,
123
+ assume_horizontal=self._page_orientation_disabled,
117
124
  )
118
125
  # Rectify crop orientation and get crop orientation predictions
119
126
  crop_orientations: Any = []
@@ -41,7 +41,7 @@ class PreProcessor(NestedObject):
41
41
  self.resize = Resize(output_size, **kwargs)
42
42
  # Perform the division by 255 at the same time
43
43
  self.normalize = Normalize(mean, std)
44
- self._runs_on_cuda = tf.test.is_gpu_available()
44
+ self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
45
45
 
46
46
  def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
47
47
  """Gather samples into batches for inference purposes
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
13
13
  from doctr.datasets import VOCABS
14
14
 
15
15
  from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
16
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
17
17
  from ..core import RecognitionModel, RecognitionPostProcessor
18
18
 
19
19
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -24,21 +24,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
24
24
  "std": (0.299, 0.296, 0.301),
25
25
  "input_shape": (32, 128, 3),
26
26
  "vocab": VOCABS["legacy_french"],
27
- "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0",
27
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
28
28
  },
29
29
  "crnn_mobilenet_v3_small": {
30
30
  "mean": (0.694, 0.695, 0.693),
31
31
  "std": (0.299, 0.296, 0.301),
32
32
  "input_shape": (32, 128, 3),
33
33
  "vocab": VOCABS["french"],
34
- "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0",
34
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
35
35
  },
36
36
  "crnn_mobilenet_v3_large": {
37
37
  "mean": (0.694, 0.695, 0.693),
38
38
  "std": (0.299, 0.296, 0.301),
39
39
  "input_shape": (32, 128, 3),
40
40
  "vocab": VOCABS["french"],
41
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0",
41
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
42
42
  },
43
43
  }
44
44
 
@@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model):
128
128
 
129
129
  def __init__(
130
130
  self,
131
- feature_extractor: tf.keras.Model,
131
+ feature_extractor: Model,
132
132
  vocab: str,
133
133
  rnn_units: int = 128,
134
134
  exportable: bool = False,
@@ -245,9 +245,11 @@ def _crnn(
245
245
 
246
246
  # Build the model
247
247
  model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
248
+ _build_model(model)
248
249
  # Load pretrained parameters
249
250
  if pretrained:
250
- load_pretrained_params(model, _cfg["url"])
251
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
252
+ load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
251
253
 
252
254
  return model
253
255
 
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
13
13
  from doctr.models.classification import magc_resnet31
14
14
  from doctr.models.modules.transformer import Decoder, PositionalEncoding
15
15
 
16
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
17
17
  from .base import _MASTER, _MASTERPostProcessor
18
18
 
19
19
  __all__ = ["MASTER", "master"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
25
25
  "std": (0.299, 0.296, 0.301),
26
26
  "input_shape": (32, 128, 3),
27
27
  "vocab": VOCABS["french"],
28
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0",
28
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
29
29
  },
30
30
  }
31
31
 
@@ -51,7 +51,7 @@ class MASTER(_MASTER, Model):
51
51
 
52
52
  def __init__(
53
53
  self,
54
- feature_extractor: tf.keras.Model,
54
+ feature_extractor: Model,
55
55
  vocab: str,
56
56
  d_model: int = 512,
57
57
  dff: int = 2048,
@@ -290,9 +290,14 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
290
290
  cfg=_cfg,
291
291
  **kwargs,
292
292
  )
293
+ _build_model(model)
294
+
293
295
  # Load pretrained parameters
294
296
  if pretrained:
295
- load_pretrained_params(model, default_cfgs[arch]["url"])
297
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
298
+ load_pretrained_params(
299
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
300
+ )
296
301
 
297
302
  return model
298
303
 
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
16
16
  from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
17
17
 
18
18
  from ...classification import vit_s
19
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
19
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
20
20
  from .base import _PARSeq, _PARSeqPostProcessor
21
21
 
22
22
  __all__ = ["PARSeq", "parseq"]
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
27
27
  "std": (0.299, 0.296, 0.301),
28
28
  "input_shape": (32, 128, 3),
29
29
  "vocab": VOCABS["french"],
30
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0",
30
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
31
31
  },
32
32
  }
33
33
 
@@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer):
43
43
 
44
44
  def __init__(self, vocab_size: int, d_model: int):
45
45
  super(CharEmbedding, self).__init__()
46
- self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
46
+ self.embedding = layers.Embedding(vocab_size, d_model)
47
47
  self.d_model = d_model
48
48
 
49
49
  def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
@@ -167,7 +167,6 @@ class PARSeq(_PARSeq, Model):
167
167
 
168
168
  self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
169
169
 
170
- @tf.function
171
170
  def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
172
171
  # Generates permutations of the target sequence.
173
172
  # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -214,7 +213,6 @@ class PARSeq(_PARSeq, Model):
214
213
  )
215
214
  return combined
216
215
 
217
- @tf.function
218
216
  def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
219
217
  # Generate source and target mask for the decoder attention.
220
218
  sz = permutation.shape[0]
@@ -234,11 +232,10 @@ class PARSeq(_PARSeq, Model):
234
232
  target_mask = mask[1:, :-1]
235
233
  return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
236
234
 
237
- @tf.function
238
235
  def decode(
239
236
  self,
240
237
  target: tf.Tensor,
241
- memory: tf,
238
+ memory: tf.Tensor,
242
239
  target_mask: Optional[tf.Tensor] = None,
243
240
  target_query: Optional[tf.Tensor] = None,
244
241
  **kwargs: Any,
@@ -476,9 +473,14 @@ def _parseq(
476
473
 
477
474
  # Build the model
478
475
  model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
476
+ _build_model(model)
477
+
479
478
  # Load pretrained parameters
480
479
  if pretrained:
481
- load_pretrained_params(model, default_cfgs[arch]["url"])
480
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
481
+ load_pretrained_params(
482
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
483
+ )
482
484
 
483
485
  return model
484
486
 
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
13
13
  from doctr.utils.repr import NestedObject
14
14
 
15
15
  from ...classification import resnet31
16
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
17
17
  from ..core import RecognitionModel, RecognitionPostProcessor
18
18
 
19
19
  __all__ = ["SAR", "sar_resnet31"]
@@ -24,7 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
24
24
  "std": (0.299, 0.296, 0.301),
25
25
  "input_shape": (32, 128, 3),
26
26
  "vocab": VOCABS["french"],
27
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0",
27
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
28
28
  },
29
29
  }
30
30
 
@@ -392,9 +392,13 @@ def _sar(
392
392
 
393
393
  # Build the model
394
394
  model = SAR(feat_extractor, cfg=_cfg, **kwargs)
395
+ _build_model(model)
395
396
  # Load pretrained parameters
396
397
  if pretrained:
397
- load_pretrained_params(model, default_cfgs[arch]["url"])
398
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
399
+ load_pretrained_params(
400
+ model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
401
+ )
398
402
 
399
403
  return model
400
404