python-doctr 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/__init__.py +1 -0
  29. doctr/models/detection/_utils/__init__.py +1 -0
  30. doctr/models/detection/_utils/base.py +66 -0
  31. doctr/models/detection/differentiable_binarization/base.py +4 -3
  32. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  33. doctr/models/detection/differentiable_binarization/tensorflow.py +14 -18
  34. doctr/models/detection/fast/__init__.py +6 -0
  35. doctr/models/detection/fast/base.py +257 -0
  36. doctr/models/detection/fast/pytorch.py +442 -0
  37. doctr/models/detection/fast/tensorflow.py +428 -0
  38. doctr/models/detection/linknet/base.py +4 -3
  39. doctr/models/detection/predictor/pytorch.py +15 -1
  40. doctr/models/detection/predictor/tensorflow.py +15 -1
  41. doctr/models/detection/zoo.py +21 -4
  42. doctr/models/factory/hub.py +3 -12
  43. doctr/models/kie_predictor/base.py +9 -3
  44. doctr/models/kie_predictor/pytorch.py +41 -20
  45. doctr/models/kie_predictor/tensorflow.py +36 -16
  46. doctr/models/modules/layers/pytorch.py +89 -10
  47. doctr/models/modules/layers/tensorflow.py +88 -10
  48. doctr/models/modules/transformer/pytorch.py +2 -2
  49. doctr/models/predictor/base.py +77 -50
  50. doctr/models/predictor/pytorch.py +31 -20
  51. doctr/models/predictor/tensorflow.py +27 -17
  52. doctr/models/preprocessor/pytorch.py +4 -4
  53. doctr/models/preprocessor/tensorflow.py +3 -2
  54. doctr/models/recognition/master/pytorch.py +2 -2
  55. doctr/models/recognition/parseq/pytorch.py +4 -3
  56. doctr/models/recognition/parseq/tensorflow.py +4 -3
  57. doctr/models/recognition/sar/pytorch.py +7 -6
  58. doctr/models/recognition/sar/tensorflow.py +3 -9
  59. doctr/models/recognition/vitstr/pytorch.py +1 -1
  60. doctr/models/recognition/zoo.py +1 -1
  61. doctr/models/zoo.py +2 -2
  62. doctr/py.typed +0 -0
  63. doctr/transforms/functional/base.py +1 -1
  64. doctr/transforms/functional/pytorch.py +4 -4
  65. doctr/transforms/modules/base.py +37 -15
  66. doctr/transforms/modules/pytorch.py +66 -8
  67. doctr/transforms/modules/tensorflow.py +63 -7
  68. doctr/utils/fonts.py +7 -5
  69. doctr/utils/geometry.py +35 -12
  70. doctr/utils/metrics.py +33 -174
  71. doctr/utils/reconstitution.py +126 -0
  72. doctr/utils/visualization.py +5 -118
  73. doctr/version.py +1 -1
  74. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/METADATA +96 -91
  75. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/RECORD +79 -75
  76. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  77. doctr/models/artefacts/__init__.py +0 -2
  78. doctr/models/artefacts/barcode.py +0 -74
  79. doctr/models/artefacts/face.py +0 -63
  80. doctr/models/obj_detection/__init__.py +0 -1
  81. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  82. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  83. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  84. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  85. {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
@@ -10,10 +10,10 @@ import torch
10
10
  from torch import nn
11
11
 
12
12
  from doctr.io.elements import Document
13
- from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
13
+ from doctr.models._utils import get_language, invert_data_structure
14
14
  from doctr.models.detection.predictor import DetectionPredictor
15
15
  from doctr.models.recognition.predictor import RecognitionPredictor
16
- from doctr.utils.geometry import rotate_image
16
+ from doctr.utils.geometry import detach_scores
17
17
 
18
18
  from .base import _KIEPredictor
19
19
 
@@ -55,7 +55,13 @@ class KIEPredictor(nn.Module, _KIEPredictor):
55
55
  self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
56
56
  self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
57
57
  _KIEPredictor.__init__(
58
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
58
+ self,
59
+ assume_straight_pages,
60
+ straighten_pages,
61
+ preserve_aspect_ratio,
62
+ symmetric_pad,
63
+ detect_orientation,
64
+ **kwargs,
59
65
  )
60
66
  self.detect_orientation = detect_orientation
61
67
  self.detect_language = detect_language
@@ -83,29 +89,31 @@ class KIEPredictor(nn.Module, _KIEPredictor):
83
89
  for out_map in out_maps
84
90
  ]
85
91
  if self.detect_orientation:
86
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
92
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) # type: ignore[arg-type]
87
93
  orientations = [
88
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
94
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
89
95
  ]
90
96
  else:
91
97
  orientations = None
98
+ general_pages_orientations = None
99
+ origin_pages_orientations = None
92
100
  if self.straighten_pages:
93
- origin_page_orientations = (
94
- origin_page_orientations
95
- if self.detect_orientation
96
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
97
- )
98
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
101
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) # type: ignore
99
102
  # Forward again to get predictions on straight pages
100
103
  loc_preds = self.det_predictor(pages, **kwargs)
101
104
 
102
105
  dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
106
+
107
+ # Detach objectness scores from loc_preds
108
+ objectness_scores = {}
109
+ for class_name, det_preds in dict_loc_preds.items():
110
+ _loc_preds, _scores = detach_scores(det_preds)
111
+ dict_loc_preds[class_name] = _loc_preds
112
+ objectness_scores[class_name] = _scores
113
+
103
114
  # Check whether crop mode should be switched to channels first
104
115
  channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
105
116
 
106
- # Rectify crops if aspect ratio
107
- dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
108
-
109
117
  # Apply hooks to loc_preds if any
110
118
  for hook in self.hooks:
111
119
  dict_loc_preds = hook(dict_loc_preds)
@@ -114,32 +122,43 @@ class KIEPredictor(nn.Module, _KIEPredictor):
114
122
  crops = {}
115
123
  for class_name in dict_loc_preds.keys():
116
124
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
117
- pages,
125
+ pages, # type: ignore[arg-type]
118
126
  dict_loc_preds[class_name],
119
127
  channels_last=channels_last,
120
128
  assume_straight_pages=self.assume_straight_pages,
121
129
  )
122
130
  # Rectify crop orientation
131
+ crop_orientations: Any = {}
123
132
  if not self.assume_straight_pages:
124
133
  for class_name in dict_loc_preds.keys():
125
- crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
134
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
126
135
  crops[class_name], dict_loc_preds[class_name]
127
136
  )
137
+ crop_orientations[class_name] = [
138
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
139
+ ]
140
+
128
141
  # Identify character sequences
129
142
  word_preds = {
130
143
  k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
131
144
  for k, crop_value in crops.items()
132
145
  }
146
+ if not crop_orientations:
147
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
133
148
 
134
149
  boxes: Dict = {}
135
150
  text_preds: Dict = {}
151
+ word_crop_orientations: Dict = {}
136
152
  for class_name in dict_loc_preds.keys():
137
- boxes[class_name], text_preds[class_name] = self._process_predictions(
138
- dict_loc_preds[class_name], word_preds[class_name]
153
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
154
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
139
155
  )
140
156
 
141
157
  boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
158
+ objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
142
159
  text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
160
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
161
+
143
162
  if self.detect_language:
144
163
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
145
164
  languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
@@ -147,10 +166,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
147
166
  languages_dict = None
148
167
 
149
168
  out = self.doc_builder(
150
- pages,
169
+ pages, # type: ignore[arg-type]
151
170
  boxes_per_page,
171
+ objectness_scores_per_page,
152
172
  text_preds_per_page,
153
- origin_page_shapes,
173
+ origin_page_shapes, # type: ignore[arg-type]
174
+ crop_orientations_per_page,
154
175
  orientations,
155
176
  languages_dict,
156
177
  )
@@ -9,10 +9,10 @@ import numpy as np
9
9
  import tensorflow as tf
10
10
 
11
11
  from doctr.io.elements import Document
12
- from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
12
+ from doctr.models._utils import get_language, invert_data_structure
13
13
  from doctr.models.detection.predictor import DetectionPredictor
14
14
  from doctr.models.recognition.predictor import RecognitionPredictor
15
- from doctr.utils.geometry import rotate_image
15
+ from doctr.utils.geometry import detach_scores
16
16
  from doctr.utils.repr import NestedObject
17
17
 
18
18
  from .base import _KIEPredictor
@@ -56,7 +56,13 @@ class KIEPredictor(NestedObject, _KIEPredictor):
56
56
  self.det_predictor = det_predictor
57
57
  self.reco_predictor = reco_predictor
58
58
  _KIEPredictor.__init__(
59
- self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
59
+ self,
60
+ assume_straight_pages,
61
+ straighten_pages,
62
+ preserve_aspect_ratio,
63
+ symmetric_pad,
64
+ detect_orientation,
65
+ **kwargs,
60
66
  )
61
67
  self.detect_orientation = detect_orientation
62
68
  self.detect_language = detect_language
@@ -83,25 +89,27 @@ class KIEPredictor(NestedObject, _KIEPredictor):
83
89
  for out_map in out_maps
84
90
  ]
85
91
  if self.detect_orientation:
86
- origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
92
+ general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
87
93
  orientations = [
88
- {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
94
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
89
95
  ]
90
96
  else:
91
97
  orientations = None
98
+ general_pages_orientations = None
99
+ origin_pages_orientations = None
92
100
  if self.straighten_pages:
93
- origin_page_orientations = (
94
- origin_page_orientations
95
- if self.detect_orientation
96
- else [estimate_orientation(seq_map) for seq_map in seg_maps]
97
- )
98
- pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
101
+ pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
99
102
  # Forward again to get predictions on straight pages
100
103
  loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
101
104
 
102
105
  dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
103
- # Rectify crops if aspect ratio
104
- dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
106
+
107
+ # Detach objectness scores from loc_preds
108
+ objectness_scores = {}
109
+ for class_name, det_preds in dict_loc_preds.items():
110
+ _loc_preds, _scores = detach_scores(det_preds)
111
+ dict_loc_preds[class_name] = _loc_preds
112
+ objectness_scores[class_name] = _scores
105
113
 
106
114
  # Apply hooks to loc_preds if any
107
115
  for hook in self.hooks:
@@ -113,28 +121,38 @@ class KIEPredictor(NestedObject, _KIEPredictor):
113
121
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
114
122
  pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
115
123
  )
124
+
116
125
  # Rectify crop orientation
126
+ crop_orientations: Any = {}
117
127
  if not self.assume_straight_pages:
118
128
  for class_name in dict_loc_preds.keys():
119
- crops[class_name], dict_loc_preds[class_name] = self._rectify_crops(
129
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
120
130
  crops[class_name], dict_loc_preds[class_name]
121
131
  )
132
+ crop_orientations[class_name] = [
133
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
134
+ ]
122
135
 
123
136
  # Identify character sequences
124
137
  word_preds = {
125
138
  k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
126
139
  for k, crop_value in crops.items()
127
140
  }
141
+ if not crop_orientations:
142
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
128
143
 
129
144
  boxes: Dict = {}
130
145
  text_preds: Dict = {}
146
+ word_crop_orientations: Dict = {}
131
147
  for class_name in dict_loc_preds.keys():
132
- boxes[class_name], text_preds[class_name] = self._process_predictions(
133
- dict_loc_preds[class_name], word_preds[class_name]
148
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
149
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
134
150
  )
135
151
 
136
152
  boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
153
+ objectness_scores_per_page: List[Dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
137
154
  text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
155
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
138
156
 
139
157
  if self.detect_language:
140
158
  languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
@@ -145,8 +163,10 @@ class KIEPredictor(NestedObject, _KIEPredictor):
145
163
  out = self.doc_builder(
146
164
  pages,
147
165
  boxes_per_page,
166
+ objectness_scores_per_page,
148
167
  text_preds_per_page,
149
168
  origin_page_shapes, # type: ignore[arg-type]
169
+ crop_orientations_per_page,
150
170
  orientations,
151
171
  languages_dict,
152
172
  )
@@ -5,6 +5,7 @@
5
5
 
6
6
  from typing import Tuple, Union
7
7
 
8
+ import numpy as np
8
9
  import torch
9
10
  import torch.nn as nn
10
11
 
@@ -26,18 +27,20 @@ class FASTConvLayer(nn.Module):
26
27
  ) -> None:
27
28
  super().__init__()
28
29
 
29
- converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
30
+ self.groups = groups
31
+ self.in_channels = in_channels
32
+ self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
30
33
 
31
34
  self.hor_conv, self.hor_bn = None, None
32
35
  self.ver_conv, self.ver_bn = None, None
33
36
 
34
- padding = (int(((converted_ks[0] - 1) * dilation) / 2), int(((converted_ks[1] - 1) * dilation) / 2))
37
+ padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
35
38
 
36
39
  self.activation = nn.ReLU(inplace=True)
37
40
  self.conv = nn.Conv2d(
38
41
  in_channels,
39
42
  out_channels,
40
- kernel_size=converted_ks,
43
+ kernel_size=self.converted_ks,
41
44
  stride=stride,
42
45
  padding=padding,
43
46
  dilation=dilation,
@@ -47,12 +50,12 @@ class FASTConvLayer(nn.Module):
47
50
 
48
51
  self.bn = nn.BatchNorm2d(out_channels)
49
52
 
50
- if converted_ks[1] != 1:
53
+ if self.converted_ks[1] != 1:
51
54
  self.ver_conv = nn.Conv2d(
52
55
  in_channels,
53
56
  out_channels,
54
- kernel_size=(converted_ks[0], 1),
55
- padding=(int(((converted_ks[0] - 1) * dilation) / 2), 0),
57
+ kernel_size=(self.converted_ks[0], 1),
58
+ padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
56
59
  stride=stride,
57
60
  dilation=dilation,
58
61
  groups=groups,
@@ -60,12 +63,12 @@ class FASTConvLayer(nn.Module):
60
63
  )
61
64
  self.ver_bn = nn.BatchNorm2d(out_channels)
62
65
 
63
- if converted_ks[0] != 1:
66
+ if self.converted_ks[0] != 1:
64
67
  self.hor_conv = nn.Conv2d(
65
68
  in_channels,
66
69
  out_channels,
67
- kernel_size=(1, converted_ks[1]),
68
- padding=(0, int(((converted_ks[1] - 1) * dilation) / 2)),
70
+ kernel_size=(1, self.converted_ks[1]),
71
+ padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
69
72
  stride=stride,
70
73
  dilation=dilation,
71
74
  groups=groups,
@@ -76,11 +79,87 @@ class FASTConvLayer(nn.Module):
76
79
  self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
77
80
 
78
81
  def forward(self, x: torch.Tensor) -> torch.Tensor:
82
+ if hasattr(self, "fused_conv"):
83
+ return self.activation(self.fused_conv(x))
84
+
79
85
  main_outputs = self.bn(self.conv(x))
80
86
  vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
81
87
  horizontal_outputs = (
82
88
  self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
83
89
  )
84
- id_out = self.rbr_identity(x) if self.rbr_identity is not None and self.ver_bn is not None else 0
90
+ id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
85
91
 
86
92
  return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
93
+
94
+ # The following logic is used to reparametrize the layer
95
+ # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
96
+ def _identity_to_conv(
97
+ self, identity: Union[nn.BatchNorm2d, None]
98
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
99
+ if identity is None or identity.running_var is None:
100
+ return 0, 0
101
+ if not hasattr(self, "id_tensor"):
102
+ input_dim = self.in_channels // self.groups
103
+ kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
104
+ for i in range(self.in_channels):
105
+ kernel_value[i, i % input_dim, 0, 0] = 1
106
+ id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
107
+ self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
108
+ kernel = self.id_tensor
109
+ std = (identity.running_var + identity.eps).sqrt()
110
+ t = (identity.weight / std).reshape(-1, 1, 1, 1)
111
+ return kernel * t, identity.bias - identity.running_mean * identity.weight / std
112
+
113
+ def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
114
+ kernel = conv.weight
115
+ kernel = self._pad_to_mxn_tensor(kernel)
116
+ std = (bn.running_var + bn.eps).sqrt() # type: ignore
117
+ t = (bn.weight / std).reshape(-1, 1, 1, 1)
118
+ return kernel * t, bn.bias - bn.running_mean * bn.weight / std
119
+
120
+ def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
122
+ if self.ver_conv is not None:
123
+ kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
124
+ else:
125
+ kernel_mx1, bias_mx1 = 0, 0 # type: ignore[assignment]
126
+ if self.hor_conv is not None:
127
+ kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) # type: ignore[arg-type]
128
+ else:
129
+ kernel_1xn, bias_1xn = 0, 0 # type: ignore[assignment]
130
+ kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
131
+ kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
132
+ bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
133
+ return kernel_mxn, bias_mxn
134
+
135
+ def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
136
+ kernel_height, kernel_width = self.converted_ks
137
+ height, width = kernel.shape[2:]
138
+ pad_left_right = (kernel_width - width) // 2
139
+ pad_top_down = (kernel_height - height) // 2
140
+ return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
141
+
142
+ def reparameterize_layer(self):
143
+ if hasattr(self, "fused_conv"):
144
+ return
145
+ kernel, bias = self._get_equivalent_kernel_bias()
146
+ self.fused_conv = nn.Conv2d(
147
+ in_channels=self.conv.in_channels,
148
+ out_channels=self.conv.out_channels,
149
+ kernel_size=self.conv.kernel_size, # type: ignore[arg-type]
150
+ stride=self.conv.stride, # type: ignore[arg-type]
151
+ padding=self.conv.padding, # type: ignore[arg-type]
152
+ dilation=self.conv.dilation, # type: ignore[arg-type]
153
+ groups=self.conv.groups,
154
+ bias=True,
155
+ )
156
+ self.fused_conv.weight.data = kernel
157
+ self.fused_conv.bias.data = bias # type: ignore[union-attr]
158
+ for para in self.parameters():
159
+ para.detach_()
160
+ for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
161
+ if hasattr(self, attr):
162
+ self.__delattr__(attr)
163
+
164
+ if hasattr(self, "rbr_identity"):
165
+ self.__delattr__("rbr_identity")
@@ -5,6 +5,7 @@
5
5
 
6
6
  from typing import Any, Tuple, Union
7
7
 
8
+ import numpy as np
8
9
  import tensorflow as tf
9
10
  from tensorflow.keras import layers
10
11
 
@@ -28,18 +29,21 @@ class FASTConvLayer(layers.Layer, NestedObject):
28
29
  ) -> None:
29
30
  super().__init__()
30
31
 
31
- converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
32
+ self.groups = groups
33
+ self.in_channels = in_channels
34
+ self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
32
35
 
33
36
  self.hor_conv, self.hor_bn = None, None
34
37
  self.ver_conv, self.ver_bn = None, None
35
38
 
36
- padding = ((converted_ks[0] - 1) * dilation // 2, (converted_ks[1] - 1) * dilation // 2)
39
+ padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)
37
40
 
38
41
  self.activation = layers.ReLU()
39
42
  self.conv_pad = layers.ZeroPadding2D(padding=padding)
43
+
40
44
  self.conv = layers.Conv2D(
41
45
  filters=out_channels,
42
- kernel_size=converted_ks,
46
+ kernel_size=self.converted_ks,
43
47
  strides=stride,
44
48
  dilation_rate=dilation,
45
49
  groups=groups,
@@ -48,13 +52,13 @@ class FASTConvLayer(layers.Layer, NestedObject):
48
52
 
49
53
  self.bn = layers.BatchNormalization()
50
54
 
51
- if converted_ks[1] != 1:
55
+ if self.converted_ks[1] != 1:
52
56
  self.ver_pad = layers.ZeroPadding2D(
53
- padding=(int(((converted_ks[0] - 1) * dilation) / 2), 0),
57
+ padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
54
58
  )
55
59
  self.ver_conv = layers.Conv2D(
56
60
  filters=out_channels,
57
- kernel_size=(converted_ks[0], 1),
61
+ kernel_size=(self.converted_ks[0], 1),
58
62
  strides=stride,
59
63
  dilation_rate=dilation,
60
64
  groups=groups,
@@ -62,13 +66,13 @@ class FASTConvLayer(layers.Layer, NestedObject):
62
66
  )
63
67
  self.ver_bn = layers.BatchNormalization()
64
68
 
65
- if converted_ks[0] != 1:
69
+ if self.converted_ks[0] != 1:
66
70
  self.hor_pad = layers.ZeroPadding2D(
67
- padding=(0, int(((converted_ks[1] - 1) * dilation) / 2)),
71
+ padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
68
72
  )
69
73
  self.hor_conv = layers.Conv2D(
70
74
  filters=out_channels,
71
- kernel_size=(1, converted_ks[1]),
75
+ kernel_size=(1, self.converted_ks[1]),
72
76
  strides=stride,
73
77
  dilation_rate=dilation,
74
78
  groups=groups,
@@ -79,6 +83,9 @@ class FASTConvLayer(layers.Layer, NestedObject):
79
83
  self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
80
84
 
81
85
  def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
86
+ if hasattr(self, "fused_conv"):
87
+ return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))
88
+
82
89
  main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
83
90
  vertical_outputs = (
84
91
  self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
@@ -90,6 +97,77 @@ class FASTConvLayer(layers.Layer, NestedObject):
90
97
  if self.hor_bn is not None and self.hor_conv is not None
91
98
  else 0
92
99
  )
93
- id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None and self.ver_bn is not None else 0
100
+ id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
94
101
 
95
102
  return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
103
+
104
+ # The following logic is used to reparametrize the layer
105
+ # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
106
+ def _identity_to_conv(
107
+ self, identity: layers.BatchNormalization
108
+ ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
109
+ if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
110
+ return 0, 0
111
+ if not hasattr(self, "id_tensor"):
112
+ input_dim = self.in_channels // self.groups
113
+ kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
114
+ for i in range(self.in_channels):
115
+ kernel_value[0, 0, i % input_dim, i] = 1
116
+ id_tensor = tf.constant(kernel_value, dtype=tf.float32)
117
+ self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
118
+ kernel = self.id_tensor
119
+ std = tf.sqrt(identity.moving_variance + identity.epsilon)
120
+ t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
121
+ return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
122
+
123
+ def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
124
+ kernel = conv.kernel
125
+ kernel = self._pad_to_mxn_tensor(kernel)
126
+ std = tf.sqrt(bn.moving_variance + bn.epsilon)
127
+ t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
128
+ return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std
129
+
130
+ def _get_equivalent_kernel_bias(self):
131
+ kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
132
+ if self.ver_conv is not None:
133
+ kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
134
+ else:
135
+ kernel_mx1, bias_mx1 = 0, 0
136
+ if self.hor_conv is not None:
137
+ kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
138
+ else:
139
+ kernel_1xn, bias_1xn = 0, 0
140
+ kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
141
+ kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
142
+ bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
143
+ return kernel_mxn, bias_mxn
144
+
145
+ def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
146
+ kernel_height, kernel_width = self.converted_ks
147
+ height, width = kernel.shape[:2]
148
+ pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
149
+ pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
150
+ return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
151
+
152
+ def reparameterize_layer(self):
153
+ kernel, bias = self._get_equivalent_kernel_bias()
154
+ self.fused_conv = layers.Conv2D(
155
+ filters=self.conv.filters,
156
+ kernel_size=self.conv.kernel_size,
157
+ strides=self.conv.strides,
158
+ padding=self.conv.padding,
159
+ dilation_rate=self.conv.dilation_rate,
160
+ groups=self.conv.groups,
161
+ use_bias=True,
162
+ )
163
+ # build layer to initialize weights and biases
164
+ self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2]))
165
+ self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
166
+ for para in self.trainable_variables:
167
+ para._trainable = False
168
+ for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
169
+ if hasattr(self, attr):
170
+ delattr(self, attr)
171
+
172
+ if hasattr(self, "rbr_identity"):
173
+ delattr(self, "rbr_identity")
@@ -51,8 +51,8 @@ def scaled_dot_product_attention(
51
51
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
52
52
  if mask is not None:
53
53
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
54
- scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
55
- p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
54
+ scores = scores.masked_fill(mask == 0, float("-inf"))
55
+ p_attn = torch.softmax(scores, dim=-1)
56
56
  return torch.matmul(p_attn, value), p_attn
57
57
 
58
58