python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0

doctr/models/kie_predictor/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 
 from .base import _KIEPredictor
 
@@ -24,6 +24,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
-       kwargs: keyword args of `DocumentBuilder`
+       **kwargs: keyword args of `DocumentBuilder`
    """
 
    def __init__(
@@ -59,7 +60,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         pages: List[Union[np.ndarray, torch.Tensor]],
@@ -71,11 +72,20 @@ class KIEPredictor(nn.Module, _KIEPredictor):
 
         origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
 
+        # Localize text elements
+        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+                np.uint8
+            )
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(page) for page in pages]  # type: ignore[arg-type]
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
@@ -83,29 +93,28 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             origin_page_orientations = (
                 origin_page_orientations
                 if self.detect_orientation
-                else [estimate_orientation(page) for page in pages]  # type: ignore[arg-type]
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [
-                rotate_image(page, -angle, expand=True)  # type: ignore[arg-type]
-                for page, angle in zip(pages, origin_page_orientations)
-            ]
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # Forward again to get predictions on straight pages
+            loc_preds = self.det_predictor(pages, **kwargs)
 
-        # Localize text elements
-        loc_preds = self.det_predictor(pages, **kwargs)
         dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
         # Check whether crop mode should be switched to channels first
         channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
         # Rectify crops if aspect ratio
-        dict_loc_preds = {
-            k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()  # type: ignore[arg-type]
-        }
+        dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
+
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            dict_loc_preds = hook(dict_loc_preds)
 
         # Crop images
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,  # type: ignore[arg-type]
+                pages,
                 dict_loc_preds[class_name],
                 channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
@@ -136,29 +145,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes_per_page = [
-                {
-                    k: rotate_boxes(
-                        page_boxes,
-                        angle,
-                        orig_shape=page.shape[:2]
-                        if isinstance(page, np.ndarray)
-                        else page.shape[1:],  # type: ignore[arg-type]
-                        target_shape=mask,  # type: ignore[arg-type]
-                    )
-                    for k, page_boxes in page_boxes_dict.items()
-                }
-                for page_boxes_dict, page, angle, mask in zip(
-                    boxes_per_page, pages, origin_page_orientations, origin_page_shapes
-                )
-            ]
 
         out = self.doc_builder(
+            pages,
             boxes_per_page,
             text_preds_per_page,
-            [page.shape[:2] if channels_last else page.shape[-2:] for page in pages],  # type: ignore[misc]
+            origin_page_shapes,
             orientations,
             languages_dict,
         )
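
Both backends of the KIE predictor now run detection first with return_maps=True, binarize the returned probability maps into segmentation maps, and estimate page orientation from those maps rather than from the raw page images (with the orientation confidence reported as None instead of a hard-coded 1.0). A minimal sketch of that binarization step, with a random array standing in for a real detection output map; the (H, W, C) layout and the 0.3 default for bin_thresh are taken from the diff above:

import numpy as np

# Stand-in for a detection head output: per-pixel class probabilities in [0, 1]
out_map = np.random.rand(512, 512, 1).astype(np.float32)

# Collapse classes with a per-pixel max, then binarize at bin_thresh:
# pixels above the threshold map to 255, everything else to 0
seg_map = np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > 0.3, 255, 0).astype(np.uint8)

assert seg_map.shape == out_map.shape
assert set(np.unique(seg_map)) <= {0, 255}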

doctr/models/kie_predictor/tensorflow.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,7 +12,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 from doctr.utils.repr import NestedObject
 
 from .base import _KIEPredictor
@@ -24,6 +24,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
@@ -71,27 +72,41 @@ class KIEPredictor(NestedObject, _KIEPredictor):
 
         origin_page_shapes = [page.shape[:2] for page in pages]
 
+        # Localize text elements
+        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+                np.uint8
+            )
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(page) for page in pages]
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
         if self.straighten_pages:
             origin_page_orientations = (
-                origin_page_orientations if self.detect_orientation else [estimate_orientation(page) for page in pages]
+                origin_page_orientations
+                if self.detect_orientation
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)]
-
-        # Localize text elements
-        loc_preds = self.det_predictor(pages, **kwargs)
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # Forward again to get predictions on straight pages
+            loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
-        dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
+        dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
         # Rectify crops if aspect ratio
         dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
 
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            dict_loc_preds = hook(dict_loc_preds)
+
         # Crop images
         crops = {}
         for class_name in dict_loc_preds.keys():
@@ -126,24 +141,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes_per_page = [
-                {
-                    k: rotate_boxes(
-                        page_boxes,
-                        angle,
-                        orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:],
-                        target_shape=mask,  # type: ignore[arg-type]
-                    )
-                    for k, page_boxes in page_boxes_dict.items()
-                }
-                for page_boxes_dict, page, angle, mask in zip(
-                    boxes_per_page, pages, origin_page_orientations, origin_page_shapes
-                )
-            ]
 
         out = self.doc_builder(
+            pages,
            boxes_per_page,
            text_preds_per_page,
            origin_page_shapes,  # type: ignore[arg-type]
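
The TensorFlow variant mirrors the PyTorch changes, including the new hook mechanism: every callable in self.hooks is applied to the localization predictions before cropping, and must return the same structure it received. A hedged sketch of a conforming hook; this diff does not show how hooks get registered, and the per-class layout in the comment is inferred from the dict_loc_preds usage above:

def count_boxes_hook(dict_loc_preds):
    # A KIE hook receives {class_name: [boxes_for_page_0, boxes_for_page_1, ...]},
    # where each entry is an array of relative boxes, and returns the same structure
    for class_name, boxes_per_page in dict_loc_preds.items():
        print(class_name, [len(boxes) for boxes in boxes_per_page])
    return dict_loc_preds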

doctr/models/modules/__init__.py

@@ -1,2 +1,3 @@
+from .layers import *
 from .transformer import *
 from .vision_transformer import *

doctr/models/modules/layers/__init__.py

@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]
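
This is the framework-dispatch pattern used throughout doctr: user code imports a backend-agnostic name and gets whichever implementation matches the installed framework. A one-line sketch, assuming at least one backend is installed and relying on the __all__ exports shown below:

from doctr.models.modules.layers import FASTConvLayer  # TensorFlow class if TF is available, else the PyTorch one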

doctr/models/modules/layers/pytorch.py

@@ -0,0 +1,166 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(nn.Module):
+    """Convolutional layer used in the TextNet and FAST architectures"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.groups = groups
+        self.in_channels = in_channels
+        self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        self.hor_conv, self.hor_bn = None, None
+        self.ver_conv, self.ver_bn = None, None
+
+        padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
+
+        self.activation = nn.ReLU(inplace=True)
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=self.converted_ks,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+        self.bn = nn.BatchNorm2d(out_channels)
+
+        if self.converted_ks[1] != 1:
+            self.ver_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(self.converted_ks[0], 1),
+                padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+            self.ver_bn = nn.BatchNorm2d(out_channels)
+
+        if self.converted_ks[0] != 1:
+            self.hor_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(1, self.converted_ks[1]),
+                padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+            self.hor_bn = nn.BatchNorm2d(out_channels)
+
+        self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "fused_conv"):
+            return self.activation(self.fused_conv(x))
+
+        main_outputs = self.bn(self.conv(x))
+        vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
+        horizontal_outputs = (
+            self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
+        )
+        id_out = self.rbr_identity(x) if self.rbr_identity is not None and self.ver_bn is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+    # The following logic is used to reparametrize the layer
+    # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
+    def _identity_to_conv(
+        self, identity: Union[nn.BatchNorm2d, None]
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+        if identity is None or identity.running_var is None:
+            return 0, 0
+        if not hasattr(self, "id_tensor"):
+            input_dim = self.in_channels // self.groups
+            kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+            for i in range(self.in_channels):
+                kernel_value[i, i % input_dim, 0, 0] = 1
+            id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
+            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+        kernel = self.id_tensor
+        std = (identity.running_var + identity.eps).sqrt()  # type: ignore[attr-defined]
+        t = (identity.weight / std).reshape(-1, 1, 1, 1)
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+
+    def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+        kernel = conv.weight
+        kernel = self._pad_to_mxn_tensor(kernel)
+        std = (bn.running_var + bn.eps).sqrt()  # type: ignore
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+
+    def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+        if self.ver_conv is not None:
+            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)  # type: ignore[arg-type]
+        else:
+            kernel_mx1, bias_mx1 = 0, 0  # type: ignore[assignment]
+        if self.hor_conv is not None:
+            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)  # type: ignore[arg-type]
+        else:
+            kernel_1xn, bias_1xn = 0, 0  # type: ignore[assignment]
+        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+        return kernel_mxn, bias_mxn
+
+    def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
+        kernel_height, kernel_width = self.converted_ks
+        height, width = kernel.shape[2:]
+        pad_left_right = (kernel_width - width) // 2
+        pad_top_down = (kernel_height - height) // 2
+        return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
+
+    def reparameterize_layer(self):
+        if hasattr(self, "fused_conv"):
+            return
+        kernel, bias = self._get_equivalent_kernel_bias()
+        self.fused_conv = nn.Conv2d(
+            in_channels=self.conv.in_channels,
+            out_channels=self.conv.out_channels,
+            kernel_size=self.conv.kernel_size,  # type: ignore[arg-type]
+            stride=self.conv.stride,  # type: ignore[arg-type]
+            padding=self.conv.padding,  # type: ignore[arg-type]
+            dilation=self.conv.dilation,  # type: ignore[arg-type]
+            groups=self.conv.groups,
+            bias=True,
+        )
+        self.fused_conv.weight.data = kernel
+        self.fused_conv.bias.data = bias  # type: ignore[union-attr]
+        self.deploy = True
+        for para in self.parameters():
+            para.detach_()
+        for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+            if hasattr(self, attr):
+                self.__delattr__(attr)
+
+        if hasattr(self, "rbr_identity"):
+            self.__delattr__("rbr_identity")
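
FASTConvLayer is an over-parameterized block in the RepVGG style: during training it sums a main convolution, a vertical kx1 branch, a horizontal 1xk branch, and a BatchNorm identity branch, and reparameterize_layer() folds all of them, BatchNorm statistics included, into one fused convolution for inference. A sketch of an equivalence check, assuming a PyTorch-only environment so the dispatched import resolves to this module; the fusion only matches in eval mode, where BatchNorm uses its running statistics:

import torch

from doctr.models.modules.layers import FASTConvLayer

layer = FASTConvLayer(in_channels=16, out_channels=16, kernel_size=3).eval()
x = torch.rand(1, 16, 32, 32)

with torch.inference_mode():
    y_ref = layer(x)  # multi-branch training-time graph

layer.reparameterize_layer()  # fold all branches into a single fused_conv

with torch.inference_mode():
    y_fused = layer(x)  # single-convolution fast path

# The fusion is exact up to floating-point error
assert torch.allclose(y_ref, y_fused, atol=1e-5)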

doctr/models/modules/layers/tensorflow.py

@@ -0,0 +1,175 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(layers.Layer, NestedObject):
+    """Convolutional layer used in the TextNet and FAST architectures"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.groups = groups
+        self.in_channels = in_channels
+        self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        self.hor_conv, self.hor_bn = None, None
+        self.ver_conv, self.ver_bn = None, None
+
+        padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)
+
+        self.activation = layers.ReLU()
+        self.conv_pad = layers.ZeroPadding2D(padding=padding)
+
+        self.conv = layers.Conv2D(
+            filters=out_channels,
+            kernel_size=self.converted_ks,
+            strides=stride,
+            dilation_rate=dilation,
+            groups=groups,
+            use_bias=bias,
+        )
+
+        self.bn = layers.BatchNormalization()
+
+        if self.converted_ks[1] != 1:
+            self.ver_pad = layers.ZeroPadding2D(
+                padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+            )
+            self.ver_conv = layers.Conv2D(
+                filters=out_channels,
+                kernel_size=(self.converted_ks[0], 1),
+                strides=stride,
+                dilation_rate=dilation,
+                groups=groups,
+                use_bias=bias,
+            )
+            self.ver_bn = layers.BatchNormalization()
+
+        if self.converted_ks[0] != 1:
+            self.hor_pad = layers.ZeroPadding2D(
+                padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+            )
+            self.hor_conv = layers.Conv2D(
+                filters=out_channels,
+                kernel_size=(1, self.converted_ks[1]),
+                strides=stride,
+                dilation_rate=dilation,
+                groups=groups,
+                use_bias=bias,
+            )
+            self.hor_bn = layers.BatchNormalization()
+
+        self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+
+    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+        if hasattr(self, "fused_conv"):
+            return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))
+
+        main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
+        vertical_outputs = (
+            self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
+            if self.ver_conv is not None and self.ver_bn is not None
+            else 0
+        )
+        horizontal_outputs = (
+            self.hor_bn(self.hor_conv(self.hor_pad(x, **kwargs), **kwargs), **kwargs)
+            if self.hor_bn is not None and self.hor_conv is not None
+            else 0
+        )
+        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None and self.ver_bn is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+    # The following logic is used to reparametrize the layer
+    # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
+    def _identity_to_conv(
+        self, identity: layers.BatchNormalization
+    ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+        if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
+            return 0, 0
+        if not hasattr(self, "id_tensor"):
+            input_dim = self.in_channels // self.groups
+            kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+            for i in range(self.in_channels):
+                kernel_value[i, i % input_dim, 0, 0] = 1
+            id_tensor = tf.constant(kernel_value, dtype=tf.float32)
+            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+        kernel = self.id_tensor
+        std = tf.sqrt(identity.moving_variance + identity.epsilon)
+        t = tf.reshape(identity.gamma / std, (-1, 1, 1, 1))
+        return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
+
+    def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+        kernel = conv.kernel
+        kernel = self._pad_to_mxn_tensor(kernel)
+        std = tf.sqrt(bn.moving_variance + bn.epsilon)
+        t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
+        return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std
+
+    def _get_equivalent_kernel_bias(self):
+        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+        if self.ver_conv is not None:
+            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
+        else:
+            kernel_mx1, bias_mx1 = 0, 0
+        if self.hor_conv is not None:
+            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
+        else:
+            kernel_1xn, bias_1xn = 0, 0
+        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+        if not isinstance(kernel_id, int):
+            kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
+        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+        return kernel_mxn, bias_mxn
+
+    def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
+        kernel_height, kernel_width = self.converted_ks
+        height, width = kernel.shape[2:]
+        pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
+        pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
+        return tf.pad(kernel, [[0, 0], [0, 0], [pad_top_down, pad_top_down], [pad_left_right, pad_left_right]])
+
+    def reparameterize_layer(self):
+        kernel, bias = self._get_equivalent_kernel_bias()
+        self.fused_conv = layers.Conv2D(
+            filters=self.conv.filters,
+            kernel_size=self.conv.kernel_size,
+            strides=self.conv.strides,
+            padding=self.conv.padding,
+            dilation_rate=self.conv.dilation_rate,
+            groups=self.conv.groups,
+            use_bias=True,
+        )
+        # build layer to initialize weights and biases
+        self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2]))
+        self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
+        for para in self.trainable_variables:
+            para._trainable = False
+        for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+            if hasattr(self, attr):
+                delattr(self, attr)
+
+        if hasattr(self, "rbr_identity"):
+            delattr(self, "rbr_identity")

doctr/models/modules/transformer/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -30,14 +30,17 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe.unsqueeze(0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
+        """Forward pass
+
         Args:
+        ----
             x: embeddings (batch, max_len, d_model)
 
-        Returns:
+        Returns
+        -------
             positional embeddings (batch, max_len, d_model)
         """
-        x = x + self.pe[:, : x.size(1)]  # type: ignore
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)
 
 
@@ -45,12 +48,11 @@ def scaled_dot_product_attention(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Scaled Dot-Product Attention"""
-
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
+    p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
     return torch.matmul(p_attn, value), p_attn
 
 
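
The mask == 0 comparison noted above is what keeps masked_fill ONNX-exportable, so the contract is an integer (or boolean) mask in which 0 marks positions to hide. A minimal, self-contained restatement of the function with a causal mask, for illustration only:

import math

import torch

def scaled_dot_product_attention(query, key, value, mask=None):
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    p_attn = torch.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

q = k = v = torch.rand(2, 4, 8)  # (batch, seq_len, d_model)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.int64))  # 0 = future position, masked out
out, attn = scaled_dot_product_attention(q, k, v, mask=causal_mask)

# The first query position can only attend to itself
assert torch.all(attn[:, 0, 1:] == 0)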
121
123
  self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5)
122
124
  self.dropout = nn.Dropout(dropout)
123
125
 
124
- self.attention = nn.ModuleList(
125
- [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
126
- )
127
- self.position_feed_forward = nn.ModuleList(
128
- [PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)]
129
- )
126
+ self.attention = nn.ModuleList([
127
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
128
+ ])
129
+ self.position_feed_forward = nn.ModuleList([
130
+ PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
131
+ ])
130
132
 
131
133
  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
132
134
  output = x
@@ -167,15 +169,15 @@ class Decoder(nn.Module):
167
169
  self.embed = nn.Embedding(vocab_size, d_model)
168
170
  self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding)
169
171
 
170
- self.attention = nn.ModuleList(
171
- [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
172
- )
173
- self.source_attention = nn.ModuleList(
174
- [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
175
- )
176
- self.position_feed_forward = nn.ModuleList(
177
- [PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)]
178
- )
172
+ self.attention = nn.ModuleList([
173
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
174
+ ])
175
+ self.source_attention = nn.ModuleList([
176
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
177
+ ])
178
+ self.position_feed_forward = nn.ModuleList([
179
+ PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)
180
+ ])
179
181
 
180
182
  def forward(
181
183
  self,