python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/detection/differentiable_binarization/base.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -22,6 +22,7 @@ class DBPostProcessor(DetectionPostProcessor):
     <https://github.com/xuannianz/DifferentiableBinarization>`_.
 
     Args:
+    ----
         unclip ratio: ratio used to unshrink polygons
         min_size_box: minimal length (pix) to keep a box
         max_candidates: maximum boxes to consider in a single page
@@ -37,7 +38,7 @@ class DBPostProcessor(DetectionPostProcessor):
         assume_straight_pages: bool = True,
     ) -> None:
         super().__init__(box_thresh, bin_thresh, assume_straight_pages)
-        self.unclip_ratio = 1.5 if assume_straight_pages else 2.2
+        self.unclip_ratio = 1.5
 
     def polygon_to_box(
         self,
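Note: unclip_ratio drives the expansion step in polygon_to_box just below; it is now a single fixed value whether or not pages are assumed straight. A minimal sketch of that unclip step, assuming an illustrative rectangle for poly (the offset-distance formula is the one from the DB paper):

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    poly = np.array([[0, 0], [100, 0], [100, 30], [0, 30]])  # hypothetical shrunk text region
    unclip_ratio = 1.5  # the new fixed value
    shape = Polygon(poly)
    distance = shape.area * unclip_ratio / shape.length  # offset distance, as in the DB paper
    offset = pyclipper.PyclipperOffset()
    offset.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance)[0])  # expanded polygon a box is fitted on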
@@ -46,9 +47,11 @@ class DBPostProcessor(DetectionPostProcessor):
         """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
 
         Args:
+        ----
             points: The first parameter.
 
         Returns:
+        -------
             a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
         """
         if not self.assume_straight_pages:
@@ -80,7 +83,7 @@ class DBPostProcessor(DetectionPostProcessor):
         if len(expanded_points) < 1:
             return None  # type: ignore[return-value]
         return (
-            cv2.boundingRect(expanded_points)
+            cv2.boundingRect(expanded_points)  # type: ignore[return-value]
             if self.assume_straight_pages
             else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
         )
@@ -90,20 +93,22 @@ class DBPostProcessor(DetectionPostProcessor):
         pred: np.ndarray,
         bitmap: np.ndarray,
     ) -> np.ndarray:
-        """Compute boxes from a bitmap/pred_map
+        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
 
         Args:
+        ----
             pred: Pred map from differentiable binarization output
             bitmap: Bitmap map computed from pred (binarized)
             angle_tol: Comparison tolerance of the angle with the median angle across the page
             ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
 
         Returns:
+        -------
             np tensor boxes for the bitmap, each box is a 5-element list
                 containing x, y, w, h, score for the box
         """
         height, width = bitmap.shape[:2]
-        min_size_box = 1 + int(height / 512)
+        min_size_box = 2
         boxes: List[Union[np.ndarray, List[float]]] = []
         # get contours from connected components on the bitmap
         contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
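To make the filtering concrete: candidate boxes come from external contours of the binarized map, and anything with a short side under the new fixed min_size_box = 2 is dropped, independent of page height. A rough sketch with assumed inputs:

    import cv2
    import numpy as np

    prob_map = np.random.rand(1024, 1024).astype(np.float32)  # stand-in for the sigmoid output
    bitmap = (prob_map >= 0.3).astype(np.uint8)  # 0.3 mirrors the default bin_thresh
    contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # keep only candidates at least 2 px on each side (was 1 + int(height / 512))
    kept = [c for c in contours if min(cv2.boundingRect(c)[2:]) >= 2]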
@@ -158,6 +163,7 @@ class _DBNet:
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
+    ----
         feature extractor: the backbone serving as feature extractor
         fpn_channels: number of channels each extracted feature maps is mapped to
     """
@@ -174,17 +180,20 @@ class _DBNet:
         ys: np.ndarray,
         a: np.ndarray,
         b: np.ndarray,
-        eps: float = 1e-7,
+        eps: float = 1e-6,
     ) -> float:
         """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
 
         Args:
+        ----
             xs : map of x coordinates (height, width)
             ys : map of y coordinates (height, width)
             a: first point defining the [ab] segment
             b: second point defining the [ab] segment
+            eps: epsilon to avoid division by zero
 
         Returns:
+        -------
             The computed distance
 
         """
@@ -192,9 +201,10 @@ class _DBNet:
         square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
         square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
         cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
+        cosin = np.clip(cosin, -1.0, 1.0)
         square_sin = 1 - np.square(cosin)
         square_sin = np.nan_to_num(square_sin)
-        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
+        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
         result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
         return result
 
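The new clip (and the eps inside the square root) guards this law-of-cosines computation against floating-point overshoot: for a point collinear with the segment, the cosine can land just past ±1, making 1 - cos² negative and the square root NaN. A tiny illustration:

    import numpy as np

    cosin = np.array([-1.0000000000000002])  # collinear point, tiny fp overshoot
    print(1 - np.square(cosin))  # -4.44e-16, whose square root is NaN
    print(1 - np.square(np.clip(cosin, -1.0, 1.0)))  # 0.0 once clipped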
@@ -207,6 +217,7 @@ class _DBNet:
         """Draw a polygon treshold map on a canvas, as described in the DB paper
 
         Args:
+        ----
             polygon : array of coord., to draw the boundary of the polygon
             canvas : threshold map to fill with polygons
             mask : mask for training on threshold polygons
@@ -223,7 +234,7 @@ class _DBNet:
         padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
 
         # Fill the mask with 1 on the new padded polygon
-        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
         # Get min/max to recover polygon after distance computation
         xmin = padded_polygon[:, 0].min()
@@ -255,7 +266,10 @@
 
         # Fill the canvas with the distances computed inside the valid padded polygon
         canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
-            1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1],
+            1
+            - distance_map[
+                ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
+            ],
             canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
         )
@@ -264,7 +278,7 @@
     def build_target(
         self,
         target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int, int],
+        output_shape: Tuple[int, int, int],
         channels_last: bool = True,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -274,23 +288,24 @@
 
         input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
 
+        h: int
+        w: int
         if channels_last:
-            h, w = output_shape[1:-1]
-            target_shape = (output_shape[0], output_shape[-1], h, w)  # (Batch_size, num_classes, h, w)
+            h, w, num_classes = output_shape
         else:
-            h, w = output_shape[-2:]
-            target_shape = output_shape  # (Batch_size, num_classes, h, w)
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
         seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
         thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
-        thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)
+        thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
 
         for idx, tgt in enumerate(target):
             for class_idx, _tgt in enumerate(tgt.values()):
                 # Draw each polygon on gt
                 if _tgt.shape[0] == 0:
                     # Empty image, full masked
-                    # seg_mask[idx, :, :, class_idx] = False
                     seg_mask[idx, class_idx] = False
 
                 # Absolute bounding boxes
@@ -316,10 +331,9 @@
                 )
                 boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
 
-                for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
                     # Mask boxes that are too small
                     if box_size < self.min_size_box:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
@@ -329,19 +343,17 @@
                     subject = [tuple(coor) for coor in poly]
                     padding = pyclipper.PyclipperOffset()
                     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-                    shrinked = padding.Execute(-distance)
+                    shrunken = padding.Execute(-distance)
 
                     # Draw polygon on gt if it is valid
-                    if len(shrinked) == 0:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    if len(shrunken) == 0:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-                    shrinked = np.array(shrinked[0]).reshape(-1, 2)
-                    if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-                    cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1)
+                    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
                 # Draw on both thresh map and thresh mask
                 poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
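This shrink is the mirror of the unclip sketch earlier: ground-truth polygons are offset inwards before being drawn on the segmentation target, and the guards above catch polygons that collapse under the offset. A hedged sketch, with illustrative values and the shrink formula from the DB paper:

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    poly = np.array([[0, 0], [100, 0], [100, 30], [0, 30]])  # hypothetical ground-truth region
    shrink_ratio = 0.4  # the ratio used in the DB paper
    shape = Polygon(poly)
    distance = shape.area * (1 - shrink_ratio**2) / shape.length
    padding = pyclipper.PyclipperOffset()
    padding.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrunken = padding.Execute(-distance)  # may come back empty for thin polygons, hence the guards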
doctr/models/detection/differentiable_binarization/pytorch.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,10 +16,10 @@ from torchvision.ops.deform_conv import DeformConv2d
 from doctr.file_utils import CLASS_NAME
 
 from ...classification import mobilenet_v3_large
-from ...utils import load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import DBPostProcessor, _DBNet
 
-__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
+__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
 
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -27,25 +27,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_resnet50-ac60cadc.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0",
     },
     "db_resnet34": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": None,
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet34-cb6aed9e.pt&src=0",
     },
     "db_mobilenet_v3_large": {
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_mobilenet_v3_large-fd62154b.pt&src=0",
-    },
-    "db_resnet50_rotation": {
-        "input_shape": (3, 1024, 1024),
-        "mean": (0.798, 0.785, 0.772),
-        "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/db_resnet50-1138863a.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
     },
 }
 
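All three PyTorch DBNet entries now point at retrained v0.7.0 checkpoints, and db_resnet34 gains pretrained weights for the first time (its URL was previously None). A quick smoke test mirroring the docstring examples further down:

    import torch
    from doctr.models import db_resnet34

    model = db_resnet34(pretrained=True)  # fetches the v0.7.0 checkpoint listed above
    input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    out = model(input_tensor)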
@@ -63,28 +57,24 @@ class FeaturePyramidNetwork(nn.Module):
 
         conv_layer = DeformConv2d if deform_conv else nn.Conv2d
 
-        self.in_branches = nn.ModuleList(
-            [
-                nn.Sequential(
-                    conv_layer(chans, out_channels, 1, bias=False),
-                    nn.BatchNorm2d(out_channels),
-                    nn.ReLU(inplace=True),
-                )
-                for idx, chans in enumerate(in_channels)
-            ]
-        )
+        self.in_branches = nn.ModuleList([
+            nn.Sequential(
+                conv_layer(chans, out_channels, 1, bias=False),
+                nn.BatchNorm2d(out_channels),
+                nn.ReLU(inplace=True),
+            )
+            for idx, chans in enumerate(in_channels)
+        ])
         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-        self.out_branches = nn.ModuleList(
-            [
-                nn.Sequential(
-                    conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
-                    nn.BatchNorm2d(out_chans),
-                    nn.ReLU(inplace=True),
-                    nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
-                )
-                for idx, chans in enumerate(in_channels)
-            ]
-        )
+        self.out_branches = nn.ModuleList([
+            nn.Sequential(
+                conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
+                nn.BatchNorm2d(out_chans),
+                nn.ReLU(inplace=True),
+                nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
+            )
+            for idx, chans in enumerate(in_channels)
+        ])
 
     def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
@@ -106,9 +96,12 @@ class DBNet(_DBNet, nn.Module):
     <https://arxiv.org/pdf/1911.08947.pdf>`_.
 
     Args:
+    ----
         feature extractor: the backbone serving as feature extractor
         head_chans: the number of channels in the head
         deform_conv: whether to use deformable convolution
+        bin_thresh: threshold for binarization
+        box_thresh: minimal objectness score to consider a box
         assume_straight_pages: if True, fit straight bounding boxes only
         exportable: onnx exportable returns only logits
         cfg: the configuration dict of the model
@@ -121,6 +114,7 @@
         head_chans: int = 256,
         deform_conv: bool = False,
         bin_thresh: float = 0.3,
+        box_thresh: float = 0.1,
         assume_straight_pages: bool = True,
         exportable: bool = False,
         cfg: Optional[Dict[str, Any]] = None,
@@ -169,7 +163,9 @@
             nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
         )
 
-        self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
+        self.postprocessor = DBPostProcessor(
+            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+        )
 
         for n, m in self.named_modules():
             # Don't override the initialization of the backbone
@@ -203,7 +199,7 @@
             return out
 
         if return_model_output or target is None or return_preds:
-            prob_map = torch.sigmoid(logits)
+            prob_map = _bf16_to_float32(torch.sigmoid(logits))
 
             if return_model_output:
                 out["out_map"] = prob_map
@@ -222,64 +218,72 @@
 
         return out
 
-    def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
+    def compute_loss(
+        self,
+        out_map: torch.Tensor,
+        thresh_map: torch.Tensor,
+        target: List[np.ndarray],
+        gamma: float = 2.0,
+        alpha: float = 0.5,
+        eps: float = 1e-8,
+    ) -> torch.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
         and a list of masks for each image. From there it computes the loss with the model output
 
         Args:
+        ----
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
+            gamma: modulating factor in the focal loss formula
+            alpha: balancing factor in the focal loss formula
+            eps: epsilon factor in dice loss
 
         Returns:
+        -------
             A loss tensor
         """
+        if gamma < 0:
+            raise ValueError("Value of gamma should be greater than or equal to zero.")
 
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, prob_map.shape, False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
         thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
         thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)
 
-        # Compute balanced BCE loss for proba_map
-        bce_scale = 5.0
-        balanced_bce_loss = torch.zeros(1, device=out_map.device)
-        dice_loss = torch.zeros(1, device=out_map.device)
-        l1_loss = torch.zeros(1, device=out_map.device)
         if torch.any(seg_mask):
-            bce_loss = F.binary_cross_entropy_with_logits(
-                out_map,
-                seg_target,
-                reduction="none",
-            )[seg_mask]
-
-            neg_target = 1 - seg_target[seg_mask]
-            positive_count = seg_target[seg_mask].sum()
-            negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count)
-            negative_loss = bce_loss * neg_target
-            negative_loss = negative_loss.sort().values[-int(negative_count.item()) :]
-            sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss)
-            balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
-
-            # Compute dice loss for approxbin_map
-            bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
-
-            bce_min = bce_loss.min()
-            weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0
-            inter = torch.sum(bin_map * seg_target[seg_mask] * weights)
-            union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8
-            dice_loss = 1 - 2.0 * inter / union
+            # Focal loss
+            focal_scale = 10.0
+            bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
+
+            p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
+            alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
+            # Unreduced version
+            focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+            # Class reduced
+            focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+
+            # Compute dice loss for each class or for approx binary_map
+            if len(self.class_names) > 1:
+                dice_map = torch.softmax(out_map, dim=1)
+            else:
+                # compute binary map instead
+                dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+            # Class reduced
+            inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+            cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+            dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
 
         # Compute l1 loss for thresh_map
-        l1_scale = 10.0
         if torch.any(thresh_mask):
-            l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
+            l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)
 
-        return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
+        return l1_loss + focal_scale * focal_loss + dice_loss
 
 
 def _dbnet(
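Net effect: the detection objective moves from balanced BCE + weighted dice + scaled L1 to focal + dice + mean L1, with the 10.0 scale shifting from the L1 term to the focal term. A self-contained sketch of the new single-class formulation (a standalone function for illustration, not the method itself; shapes and inputs are assumptions):

    import torch
    import torch.nn.functional as F

    def focal_dice_l1_sketch(out_map, thresh_logits, seg_target, seg_mask,
                             thresh_target, thresh_mask, gamma=2.0, alpha=0.5, eps=1e-8):
        prob_map = torch.sigmoid(out_map)
        thresh_map = torch.sigmoid(thresh_logits)
        # Focal term: BCE modulated by (1 - p_t) ** gamma and balanced by alpha_t
        bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
        p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
        alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
        focal_loss = (seg_mask * alpha_t * (1 - p_t) ** gamma * bce).sum() / seg_mask.sum()
        # Dice term on the approximate binary map (the single-class branch above)
        dice_map = torch.sigmoid(50.0 * (prob_map - thresh_map))  # same as 1 / (1 + exp(-50 * x))
        inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
        cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
        dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
        # Mean L1 term on the threshold map
        l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)
        return l1_loss + 10.0 * focal_loss + dice_loss

    # Illustrative shapes: batch of 2, one class, 64x64 maps
    out_map, thresh_logits = torch.randn(2, 1, 64, 64), torch.randn(2, 1, 64, 64)
    seg_target = (torch.rand(2, 1, 64, 64) > 0.9).float()
    seg_mask = torch.ones_like(seg_target)
    thresh_target, thresh_mask = torch.rand(2, 1, 64, 64), torch.ones_like(seg_target)
    loss = focal_dice_l1_sketch(out_map, thresh_logits, seg_target, seg_mask, thresh_target, thresh_mask)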
@@ -337,12 +341,14 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_resnet34",
         pretrained,
@@ -370,12 +376,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_resnet50",
         pretrained,
@@ -403,12 +411,14 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
     >>> out = model(input_tensor)
 
     Args:
+    ----
         pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+        **kwargs: keyword arguments of the DBNet architecture
 
     Returns:
+    -------
         text detection architecture
     """
-
     return _dbnet(
         "db_mobilenet_v3_large",
         pretrained,
@@ -423,37 +433,3 @@
         ],
         **kwargs,
     )
-
-
-def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
-    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
-    <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
-    This model is trained with rotated documents
-
-    >>> import torch
-    >>> from doctr.models import db_resnet50_rotation
-    >>> model = db_resnet50_rotation(pretrained=True)
-    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
-
-    Returns:
-        text detection architecture
-    """
-
-    return _dbnet(
-        "db_resnet50_rotation",
-        pretrained,
-        resnet50,
-        ["layer1", "layer2", "layer3", "layer4"],
-        None,
-        ignore_keys=[
-            "prob_head.6.weight",
-            "prob_head.6.bias",
-            "thresh_head.6.weight",
-            "thresh_head.6.bias",
-        ],
-        **kwargs,
-    )