python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/cord.py +10 -1
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +11 -1
  8. doctr/datasets/generator/base.py +6 -5
  9. doctr/datasets/ic03.py +11 -1
  10. doctr/datasets/ic13.py +10 -1
  11. doctr/datasets/iiit5k.py +26 -16
  12. doctr/datasets/imgur5k.py +11 -2
  13. doctr/datasets/loader.py +1 -6
  14. doctr/datasets/sroie.py +11 -1
  15. doctr/datasets/svhn.py +11 -1
  16. doctr/datasets/svt.py +11 -1
  17. doctr/datasets/synthtext.py +11 -1
  18. doctr/datasets/utils.py +9 -3
  19. doctr/datasets/vocabs.py +15 -4
  20. doctr/datasets/wildreceipt.py +12 -1
  21. doctr/file_utils.py +45 -12
  22. doctr/io/elements.py +52 -10
  23. doctr/io/html.py +2 -2
  24. doctr/io/image/pytorch.py +6 -8
  25. doctr/io/image/tensorflow.py +1 -1
  26. doctr/io/pdf.py +5 -2
  27. doctr/io/reader.py +6 -0
  28. doctr/models/__init__.py +0 -1
  29. doctr/models/_utils.py +57 -20
  30. doctr/models/builder.py +73 -15
  31. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  32. doctr/models/classification/mobilenet/pytorch.py +47 -9
  33. doctr/models/classification/mobilenet/tensorflow.py +51 -14
  34. doctr/models/classification/predictor/pytorch.py +28 -17
  35. doctr/models/classification/predictor/tensorflow.py +26 -16
  36. doctr/models/classification/resnet/tensorflow.py +21 -8
  37. doctr/models/classification/textnet/pytorch.py +3 -3
  38. doctr/models/classification/textnet/tensorflow.py +11 -5
  39. doctr/models/classification/vgg/tensorflow.py +9 -3
  40. doctr/models/classification/vit/tensorflow.py +10 -4
  41. doctr/models/classification/zoo.py +55 -19
  42. doctr/models/detection/_utils/__init__.py +1 -0
  43. doctr/models/detection/_utils/base.py +66 -0
  44. doctr/models/detection/differentiable_binarization/base.py +4 -3
  45. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  46. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  47. doctr/models/detection/fast/base.py +6 -5
  48. doctr/models/detection/fast/pytorch.py +4 -4
  49. doctr/models/detection/fast/tensorflow.py +15 -12
  50. doctr/models/detection/linknet/base.py +4 -3
  51. doctr/models/detection/linknet/tensorflow.py +23 -11
  52. doctr/models/detection/predictor/pytorch.py +15 -1
  53. doctr/models/detection/predictor/tensorflow.py +17 -3
  54. doctr/models/detection/zoo.py +7 -2
  55. doctr/models/factory/hub.py +8 -18
  56. doctr/models/kie_predictor/base.py +13 -3
  57. doctr/models/kie_predictor/pytorch.py +45 -20
  58. doctr/models/kie_predictor/tensorflow.py +44 -17
  59. doctr/models/modules/layers/pytorch.py +2 -3
  60. doctr/models/modules/layers/tensorflow.py +6 -8
  61. doctr/models/modules/transformer/pytorch.py +2 -2
  62. doctr/models/modules/transformer/tensorflow.py +0 -2
  63. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  64. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  65. doctr/models/predictor/base.py +97 -58
  66. doctr/models/predictor/pytorch.py +35 -20
  67. doctr/models/predictor/tensorflow.py +35 -18
  68. doctr/models/preprocessor/pytorch.py +4 -4
  69. doctr/models/preprocessor/tensorflow.py +3 -2
  70. doctr/models/recognition/crnn/tensorflow.py +8 -6
  71. doctr/models/recognition/master/pytorch.py +2 -2
  72. doctr/models/recognition/master/tensorflow.py +9 -4
  73. doctr/models/recognition/parseq/pytorch.py +4 -3
  74. doctr/models/recognition/parseq/tensorflow.py +14 -11
  75. doctr/models/recognition/sar/pytorch.py +7 -6
  76. doctr/models/recognition/sar/tensorflow.py +10 -12
  77. doctr/models/recognition/vitstr/pytorch.py +1 -1
  78. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  79. doctr/models/recognition/zoo.py +1 -1
  80. doctr/models/utils/pytorch.py +1 -1
  81. doctr/models/utils/tensorflow.py +15 -15
  82. doctr/models/zoo.py +2 -2
  83. doctr/py.typed +0 -0
  84. doctr/transforms/functional/base.py +1 -1
  85. doctr/transforms/functional/pytorch.py +5 -5
  86. doctr/transforms/modules/base.py +37 -15
  87. doctr/transforms/modules/pytorch.py +73 -14
  88. doctr/transforms/modules/tensorflow.py +78 -19
  89. doctr/utils/fonts.py +7 -5
  90. doctr/utils/geometry.py +141 -31
  91. doctr/utils/metrics.py +34 -175
  92. doctr/utils/reconstitution.py +212 -0
  93. doctr/utils/visualization.py +5 -118
  94. doctr/version.py +1 -1
  95. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
  96. python_doctr-0.10.0.dist-info/RECORD +173 -0
  97. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  98. doctr/models/artefacts/__init__.py +0 -2
  99. doctr/models/artefacts/barcode.py +0 -74
  100. doctr/models/artefacts/face.py +0 -63
  101. doctr/models/obj_detection/__init__.py +0 -1
  102. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  103. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  104. python_doctr-0.8.1.dist-info/RECORD +0 -173
  105. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  106. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  107. {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/_utils.py CHANGED
@@ -11,6 +11,8 @@ import cv2
 import numpy as np
 from langdetect import LangDetectException, detect_langs
 
+from doctr.utils.geometry import rotate_image
+
 __all__ = ["estimate_orientation", "get_language", "invert_data_structure"]
 
 
@@ -29,56 +31,91 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
     return max(w / h, h / w)
 
 
-def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
+def estimate_orientation(
+    img: np.ndarray,
+    general_page_orientation: Optional[Tuple[int, float]] = None,
+    n_ct: int = 70,
+    ratio_threshold_for_lines: float = 3,
+    min_confidence: float = 0.2,
+    lower_area: int = 100,
+) -> int:
     """Estimate the angle of the general document orientation based on the
     lines of the document and the assumption that they should be horizontal.
 
     Args:
     ----
         img: the img or bitmap to analyze (H, W, C)
+        general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
+            estimated by a model
         n_ct: the number of contours used for the orientation estimation
         ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
+        min_confidence: the minimum confidence to consider the general_page_orientation
+        lower_area: the minimum area of a contour to be considered
 
     Returns:
     -------
-        the angle of the general document orientation
+        the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
    """
     assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
-    max_value = np.max(img)
-    min_value = np.min(img)
-    if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
-        thresh = img.astype(np.uint8)
-    if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+    thresh = None
+    # Convert image to grayscale if necessary
+    if img.shape[-1] == 3:
         gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         gray_img = cv2.medianBlur(gray_img, 5)
-        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]  # type: ignore[assignment]
-
-    # try to merge words in lines
-    (h, w) = img.shape[:2]
-    k_x = max(1, (floor(w / 100)))
-    k_y = max(1, (floor(h / 100)))
-    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
-    thresh = cv2.dilate(thresh, kernel, iterations=1)  # type: ignore[assignment]
+        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+    else:
+        thresh = img.astype(np.uint8)  # type: ignore[assignment]
+
+    page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
+    if page_orientation and orientation_confidence >= min_confidence:
+        # We rotate the image to the general orientation which improves the detection
+        # No expand needed bitmap is already padded
+        thresh = rotate_image(thresh, -page_orientation)  # type: ignore
+    else:  # That's only required if we do not work on the detection models bin map
+        # try to merge words in lines
+        (h, w) = img.shape[:2]
+        k_x = max(1, (floor(w / 100)))
+        k_y = max(1, (floor(h / 100)))
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+        thresh = cv2.dilate(thresh, kernel, iterations=1)
 
     # extract contours
     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
 
-    # Sort contours
-    contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
+    # Filter & Sort contours
+    contours = sorted(
+        [contour for contour in contours if cv2.contourArea(contour) > lower_area],
+        key=get_max_width_length_ratio,
+        reverse=True,
+    )
 
     angles = []
     for contour in contours[:n_ct]:
-        _, (w, h), angle = cv2.minAreaRect(contour)
+        _, (w, h), angle = cv2.minAreaRect(contour)  # type: ignore[assignment]
         if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
             angles.append(angle)
         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, substract 90 degree
             angles.append(angle - 90)
 
     if len(angles) == 0:
-        return 0  # in case no angles is found
+        estimated_angle = 0  # in case no angles is found
     else:
         median = -median_low(angles)
-        return round(median) if abs(median) != 0 else 0
+        estimated_angle = -round(median) if abs(median) != 0 else 0
+
+    # combine with the general orientation and the estimated angle
+    if page_orientation and orientation_confidence >= min_confidence:
+        # special case where the estimated angle is mostly wrong:
+        # case 1: - and + swapped
+        # case 2: estimated angle is completely wrong
+        # so in this case we prefer the general page orientation
+        if abs(estimated_angle) == abs(page_orientation):
+            return page_orientation
+        estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
+        if estimated_angle > 180:
+            estimated_angle -= 360
+
+    return estimated_angle  # return the clockwise angle (negative - left side rotation, positive - right side rotation)
 
 
 def rectify_crops(
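
The new general_page_orientation argument lets a page-orientation classifier's prediction steer the contour-based estimate. A minimal usage sketch of the new signature (the random image and the (angle, confidence) prior below are invented for illustration):

    import numpy as np
    from doctr.models._utils import estimate_orientation

    page = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)  # stand-in for a real page image

    # 0.8.1-style call: purely contour-based
    angle = estimate_orientation(page)

    # 0.10.0: pass a model prediction as a prior; it is only used when its
    # confidence reaches min_confidence (default 0.2)
    angle = estimate_orientation(page, general_page_orientation=(90, 0.85))
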
doctr/models/builder.py CHANGED
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
     def __init__(
         self,
         resolve_lines: bool = True,
-        resolve_blocks: bool = True,
+        resolve_blocks: bool = False,
         paragraph_break: float = 0.035,
         export_as_straight_boxes: bool = False,
     ) -> None:
@@ -220,13 +220,22 @@ class DocumentBuilder(NestedObject):
 
         return blocks
 
-    def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]]) -> List[Block]:
+    def _build_blocks(
+        self,
+        boxes: np.ndarray,
+        objectness_scores: np.ndarray,
+        word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
+    ) -> List[Block]:
         """Gather independent words in structured blocks
 
         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page, of shape N
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of dictoinaries containing
+                the general orientation (orientations + confidences) of the crops
 
         Returns:
         -------
@@ -257,10 +266,17 @@ class DocumentBuilder(NestedObject):
                 Line([
                     Word(
                         *word_preds[idx],
-                        tuple([tuple(pt) for pt in boxes[idx].tolist()]),  # type: ignore[arg-type]
+                        tuple(tuple(pt) for pt in boxes[idx].tolist()),  # type: ignore[arg-type]
+                        float(objectness_scores[idx]),
+                        crop_orientations[idx],
                     )
                     if boxes.ndim == 3
-                    else Word(*word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])))
+                    else Word(
+                        *word_preds[idx],
+                        ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                        float(objectness_scores[idx]),
+                        crop_orientations[idx],
+                    )
                     for idx in line
                 ])
                 for line in lines
@@ -281,8 +297,10 @@ class DocumentBuilder(NestedObject):
         self,
         pages: List[np.ndarray],
         boxes: List[np.ndarray],
+        objectness_scores: List[np.ndarray],
         text_preds: List[List[Tuple[str, float]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, Any]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> Document:
@@ -291,10 +309,13 @@
         Args:
         ----
             pages: list of N elements, where each element represents the page image
-            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
-                or (*, 6) for all words for a given page
+            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
+                or (*, 4, 2) for all words for a given page
+            objectness_scores: list of N elements, where each element represents the objectness scores
             text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N elements, where each element is
+                a dictionary containing the general orientation (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -304,7 +325,9 @@
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")
 
         _orientations = (
@@ -322,15 +345,25 @@
                 page,
                 self._build_blocks(
                     page_boxes,
+                    loc_scores,
                     word_preds,
+                    word_crop_orientations,
                 ),
                 _idx,
                 shape,
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]
 
@@ -353,8 +386,10 @@ class KIEDocumentBuilder(DocumentBuilder):
         self,
         pages: List[np.ndarray],
         boxes: List[Dict[str, np.ndarray]],
+        objectness_scores: List[Dict[str, np.ndarray]],
         text_preds: List[Dict[str, List[Tuple[str, float]]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, List[Dict[str, Any]]]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> KIEDocument:
@@ -365,8 +400,11 @@
             pages: list of N elements, where each element represents the page image
             boxes: list of N dictionaries, where each element represents the localization predictions for a class,
                 of shape (*, 5) or (*, 6) for all predictions
+            objectness_scores: list of N dictionaries, where each element represents the objectness scores for a class
             text_preds: list of N dictionaries, where each element is the list of all word prediction
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N dictonaries, where each element is
+                a list containing the general crop orientations (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -376,7 +414,9 @@
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")
         _orientations = (
             orientations if isinstance(orientations, list) else [None] * len(boxes)  # type: ignore[list-item]
@@ -401,7 +441,9 @@
                 {
                     k: self._build_blocks(
                         page_boxes[k],
+                        loc_scores[k],
                         word_preds[k],
+                        word_crop_orientations[k],
                     )
                     for k in page_boxes.keys()
                 },
@@ -410,8 +452,16 @@
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]
 
@@ -420,14 +470,18 @@
     def _build_blocks(  # type: ignore[override]
         self,
         boxes: np.ndarray,
+        objectness_scores: np.ndarray,
         word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
    ) -> List[Prediction]:
         """Gather independent words in structured blocks
 
         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of orientations for each word crop
 
         Returns:
         -------
@@ -446,13 +500,17 @@
                 Prediction(
                     value=word_preds[idx][0],
                     confidence=word_preds[idx][1],
-                    geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]),  # type: ignore[arg-type]
+                    geometry=tuple(tuple(pt) for pt in boxes[idx].tolist()),  # type: ignore[arg-type]
+                    objectness_score=float(objectness_scores[idx]),
+                    crop_orientation=crop_orientations[idx],
                 )
                 if boxes.ndim == 3
                 else Prediction(
                     value=word_preds[idx][0],
                     confidence=word_preds[idx][1],
                     geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                    objectness_score=float(objectness_scores[idx]),
+                    crop_orientation=crop_orientations[idx],
                 )
                 for idx in idxs
             ]
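
With these changes, DocumentBuilder.__call__ expects objectness_scores after boxes and crop_orientations after page_shapes. A minimal sketch of the new call shape (all values below are dummy single-word inputs, invented for illustration):

    import numpy as np
    from doctr.models.builder import DocumentBuilder

    builder = DocumentBuilder()  # note: resolve_blocks now defaults to False

    pages = [np.zeros((256, 256, 3), dtype=np.uint8)]
    boxes = [np.array([[0.1, 0.1, 0.4, 0.2]])]          # (*, 4) relative coordinates
    objectness = [np.array([0.99])]                      # new in 0.10.0
    words = [[("hello", 0.98)]]
    crop_orients = [[{"value": 0, "confidence": 1.0}]]   # new in 0.10.0

    doc = builder(pages, boxes, objectness, words, [(256, 256)], crop_orients)
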
doctr/models/classification/magc_resnet/tensorflow.py CHANGED
@@ -9,12 +9,12 @@ from functools import partial
 from typing import Any, Dict, List, Optional, Tuple
 
 import tensorflow as tf
-from tensorflow.keras import layers
+from tensorflow.keras import activations, layers
 from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 from ..resnet.tensorflow import ResNet
 
 __all__ = ["magc_resnet31"]
@@ -26,7 +26,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
     },
 }
 
@@ -57,6 +57,7 @@ class MAGC(layers.Layer):
         self.headers = headers  # h
         self.inplanes = inplanes  # C
         self.attn_scale = attn_scale
+        self.ratio = ratio
         self.planes = int(inplanes * ratio)
 
         self.single_header_inplanes = int(inplanes / headers)  # C / h
@@ -97,7 +98,7 @@ class MAGC(layers.Layer):
         if self.attn_scale and self.headers > 1:
             context_mask = context_mask / math.sqrt(self.single_header_inplanes)
         # B*h, 1, H*W, 1
-        context_mask = tf.keras.activations.softmax(context_mask, axis=2)
+        context_mask = activations.softmax(context_mask, axis=2)
 
         # Compute context
         # B*h, 1, C/h, 1
@@ -114,7 +115,7 @@ class MAGC(layers.Layer):
         # Context modeling: B, H, W, C -> B, 1, 1, C
         context = self.context_modeling(inputs)
         # Transform: B, 1, 1, C -> B, 1, 1, C
-        transformed = self.transform(context)
+        transformed = self.transform(context, **kwargs)
         return inputs + transformed
 
 
@@ -151,9 +152,15 @@ def _magc_resnet(
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
doctr/models/classification/mobilenet/pytorch.py CHANGED
@@ -9,17 +9,20 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from torchvision.models import mobilenetv3
+from torchvision.models.mobilenetv3 import MobileNetV3
 
 from doctr.datasets import VOCABS
 
 from ...utils import load_pretrained_params
 
 __all__ = [
+    "MobileNetV3",
     "mobilenet_v3_small",
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -51,12 +54,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 128, 128),
-        "classes": [0, 90, 180, 270],
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
+        "input_shape": (3, 256, 256),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0",
+    },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 512, 512),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0",
     },
 }
 
@@ -212,14 +222,42 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     )
 
 
-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import torch
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a torch.nn.Module
+    """
+    return _mobilenet_v3(
+        "mobilenet_v3_small_crop_orientation",
+        pretrained,
+        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+        **kwargs,
+    )
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.
 
     >>> import torch
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)
 
@@ -233,7 +271,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
         a torch.nn.Module
     """
     return _mobilenet_v3(
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_page_orientation",
         pretrained,
         ignore_keys=["classifier.3.weight", "classifier.3.bias"],
         **kwargs,
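
The former mobilenet_v3_small_orientation is thus split into a crop-level and a page-level classifier, both predicting over the class set [0, -90, 180, 90]. A short sketch of the renamed PyTorch entry points (input sizes follow the default_cfgs above):

    import torch
    from doctr.models import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation

    crop_clf = mobilenet_v3_small_crop_orientation(pretrained=True)   # expects 256x256 crops
    page_clf = mobilenet_v3_small_page_orientation(pretrained=True)   # expects 512x512 pages

    logits = crop_clf(torch.rand(1, 3, 256, 256))  # one logit per rotation class
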
doctr/models/classification/mobilenet/tensorflow.py CHANGED
@@ -13,7 +13,7 @@ from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
 
 from ....datasets import VOCABS
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = [
     "MobileNetV3",
@@ -21,7 +21,8 @@ __all__ = [
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]
 
 
@@ -31,35 +32,42 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
     },
     "mobilenet_v3_large_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
     },
     "mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
     },
     "mobilenet_v3_small_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
-        "classes": [0, 90, 180, 270],
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
+    },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (512, 512, 3),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
     },
 }
 
@@ -287,9 +295,15 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -386,14 +400,37 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
 
 
-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a keras.Model
+    """
+    return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.
 
     >>> import tensorflow as tf
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
     >>> out = model(input_tensor)
 
@@ -406,4 +443,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
     -------
         a keras.Model
     """
-    return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
+    return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
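
The TensorFlow backend mirrors the same rename and class set, with channels-last inputs per the default_cfgs above (128x128 crops, 512x512 pages). A minimal sketch (assuming the TensorFlow backend is active):

    import tensorflow as tf
    from doctr.models import mobilenet_v3_small_crop_orientation

    crop_clf = mobilenet_v3_small_crop_orientation(pretrained=True)
    out = crop_clf(tf.random.uniform((1, 128, 128, 3), maxval=1, dtype=tf.float32))
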