python-doctr 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. doctr/__init__.py +1 -1
  2. doctr/contrib/__init__.py +0 -0
  3. doctr/contrib/artefacts.py +131 -0
  4. doctr/contrib/base.py +105 -0
  5. doctr/datasets/datasets/pytorch.py +2 -2
  6. doctr/datasets/generator/base.py +6 -5
  7. doctr/datasets/imgur5k.py +1 -1
  8. doctr/datasets/loader.py +1 -6
  9. doctr/datasets/utils.py +2 -1
  10. doctr/datasets/vocabs.py +9 -2
  11. doctr/file_utils.py +26 -12
  12. doctr/io/elements.py +40 -6
  13. doctr/io/html.py +2 -2
  14. doctr/io/image/pytorch.py +6 -8
  15. doctr/io/image/tensorflow.py +1 -1
  16. doctr/io/pdf.py +5 -2
  17. doctr/io/reader.py +6 -0
  18. doctr/models/__init__.py +0 -1
  19. doctr/models/_utils.py +57 -20
  20. doctr/models/builder.py +71 -13
  21. doctr/models/classification/mobilenet/pytorch.py +45 -9
  22. doctr/models/classification/mobilenet/tensorflow.py +38 -7
  23. doctr/models/classification/predictor/pytorch.py +18 -11
  24. doctr/models/classification/predictor/tensorflow.py +16 -10
  25. doctr/models/classification/textnet/pytorch.py +3 -3
  26. doctr/models/classification/textnet/tensorflow.py +3 -3
  27. doctr/models/classification/zoo.py +39 -15
  28. doctr/models/detection/_utils/__init__.py +1 -0
  29. doctr/models/detection/_utils/base.py +66 -0
  30. doctr/models/detection/differentiable_binarization/base.py +4 -3
  31. doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
  32. doctr/models/detection/fast/base.py +6 -5
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/fast/tensorflow.py +4 -4
  35. doctr/models/detection/linknet/base.py +4 -3
  36. doctr/models/detection/predictor/pytorch.py +15 -1
  37. doctr/models/detection/predictor/tensorflow.py +15 -1
  38. doctr/models/detection/zoo.py +7 -2
  39. doctr/models/factory/hub.py +3 -12
  40. doctr/models/kie_predictor/base.py +9 -3
  41. doctr/models/kie_predictor/pytorch.py +41 -20
  42. doctr/models/kie_predictor/tensorflow.py +36 -16
  43. doctr/models/modules/layers/pytorch.py +2 -3
  44. doctr/models/modules/layers/tensorflow.py +6 -8
  45. doctr/models/modules/transformer/pytorch.py +2 -2
  46. doctr/models/predictor/base.py +77 -50
  47. doctr/models/predictor/pytorch.py +31 -20
  48. doctr/models/predictor/tensorflow.py +27 -17
  49. doctr/models/preprocessor/pytorch.py +4 -4
  50. doctr/models/preprocessor/tensorflow.py +3 -2
  51. doctr/models/recognition/master/pytorch.py +2 -2
  52. doctr/models/recognition/parseq/pytorch.py +4 -3
  53. doctr/models/recognition/parseq/tensorflow.py +4 -3
  54. doctr/models/recognition/sar/pytorch.py +7 -6
  55. doctr/models/recognition/sar/tensorflow.py +3 -9
  56. doctr/models/recognition/vitstr/pytorch.py +1 -1
  57. doctr/models/recognition/zoo.py +1 -1
  58. doctr/models/zoo.py +2 -2
  59. doctr/py.typed +0 -0
  60. doctr/transforms/functional/base.py +1 -1
  61. doctr/transforms/functional/pytorch.py +4 -4
  62. doctr/transforms/modules/base.py +37 -15
  63. doctr/transforms/modules/pytorch.py +66 -8
  64. doctr/transforms/modules/tensorflow.py +63 -7
  65. doctr/utils/fonts.py +7 -5
  66. doctr/utils/geometry.py +35 -12
  67. doctr/utils/metrics.py +33 -174
  68. doctr/utils/reconstitution.py +126 -0
  69. doctr/utils/visualization.py +5 -118
  70. doctr/version.py +1 -1
  71. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/METADATA +84 -80
  72. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/RECORD +76 -76
  73. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
  74. doctr/models/artefacts/__init__.py +0 -2
  75. doctr/models/artefacts/barcode.py +0 -74
  76. doctr/models/artefacts/face.py +0 -63
  77. doctr/models/obj_detection/__init__.py +0 -1
  78. doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
  79. doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
  80. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
  81. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
doctr/io/pdf.py CHANGED
@@ -38,5 +38,8 @@ def read_pdf(
         the list of pages decoded as numpy ndarray of shape H x W x C
     """
     # Rasterise pages to numpy ndarrays with pypdfium2
-    pdf = pdfium.PdfDocument(file, password=password, autoclose=True)
-    return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
+    pdf = pdfium.PdfDocument(file, password=password)
+    try:
+        return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
+    finally:
+        pdf.close()
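
Note: the change above swaps pypdfium2's `autoclose=True` for an explicit `try`/`finally`, so the document handle is released even when rendering raises. A minimal usage sketch of the updated helper (the file path is a placeholder):

from doctr.io import read_pdf

# Each page is rasterised to a numpy array of shape H x W x C.
pages = read_pdf("sample.pdf", scale=2)  # "sample.pdf" is a hypothetical path
print(len(pages), pages[0].shape)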
doctr/io/reader.py CHANGED
@@ -8,6 +8,7 @@ from typing import List, Sequence, Union

 import numpy as np

+from doctr.file_utils import requires_package
 from doctr.utils.common_types import AbstractFile

 from .html import read_html
@@ -54,6 +55,11 @@ class DocumentFile:
         -------
             the list of pages decoded as numpy ndarray of shape H x W x 3
         """
+        requires_package(
+            "weasyprint",
+            "`.from_url` requires weasyprint installed.\n"
+            + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation",
+        )
         pdf_stream = read_html(url)
         return cls.from_pdf(pdf_stream, **kwargs)

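Note: `requires_package` fails fast with installation guidance when weasyprint is missing, instead of erroring deep inside `read_html`. A usage sketch (the URL is a placeholder):

from doctr.io import DocumentFile

# Raises an informative error if weasyprint is not installed; otherwise the
# page is rendered to a PDF stream and decoded to numpy arrays.
pages = DocumentFile.from_url("https://example.com")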
doctr/models/__init__.py CHANGED
@@ -1,4 +1,3 @@
-from . import artefacts
 from .classification import *
 from .detection import *
 from .recognition import *
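
Note: this import removal pairs with the deletion of doctr/models/artefacts (entries 74-76 in the file list) and the new doctr/contrib package (entries 2-4). A sanity-check sketch based only on the module paths visible in this diff; class names inside doctr.contrib.artefacts are not shown here, so only the modules are probed:

import importlib.util

# doctr.models.artefacts is gone in 0.9.0; its successor lives in doctr.contrib.
assert importlib.util.find_spec("doctr.models.artefacts") is None
assert importlib.util.find_spec("doctr.contrib.artefacts") is not None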
doctr/models/_utils.py CHANGED
@@ -11,6 +11,8 @@ import cv2
 import numpy as np
 from langdetect import LangDetectException, detect_langs

+from doctr.utils.geometry import rotate_image
+
 __all__ = ["estimate_orientation", "get_language", "invert_data_structure"]


@@ -29,56 +31,91 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
     return max(w / h, h / w)


-def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
+def estimate_orientation(
+    img: np.ndarray,
+    general_page_orientation: Optional[Tuple[int, float]] = None,
+    n_ct: int = 70,
+    ratio_threshold_for_lines: float = 3,
+    min_confidence: float = 0.2,
+    lower_area: int = 100,
+) -> int:
     """Estimate the angle of the general document orientation based on the
     lines of the document and the assumption that they should be horizontal.

     Args:
     ----
         img: the img or bitmap to analyze (H, W, C)
+        general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
+            estimated by a model
         n_ct: the number of contours used for the orientation estimation
         ratio_threshold_for_lines: this is the ratio w/h used to discriminate lines
+        min_confidence: the minimum confidence to consider the general_page_orientation
+        lower_area: the minimum area of a contour to be considered

     Returns:
     -------
-        the angle of the general document orientation
+        the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
     """
     assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
-    max_value = np.max(img)
-    min_value = np.min(img)
-    if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
-        thresh = img.astype(np.uint8)
-    if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+    thresh = None
+    # Convert image to grayscale if necessary
+    if img.shape[-1] == 3:
         gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         gray_img = cv2.medianBlur(gray_img, 5)
-        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]  # type: ignore[assignment]
-
-    # try to merge words in lines
-    (h, w) = img.shape[:2]
-    k_x = max(1, (floor(w / 100)))
-    k_y = max(1, (floor(h / 100)))
-    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
-    thresh = cv2.dilate(thresh, kernel, iterations=1)  # type: ignore[assignment]
+        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+    else:
+        thresh = img.astype(np.uint8)  # type: ignore[assignment]
+
+    page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
+    if page_orientation and orientation_confidence >= min_confidence:
+        # We rotate the image to the general orientation, which improves the detection
+        # No expand needed: the bitmap is already padded
+        thresh = rotate_image(thresh, -page_orientation)  # type: ignore
+    else:  # This is only required if we do not work on the detection model's bin map
+        # try to merge words in lines
+        (h, w) = img.shape[:2]
+        k_x = max(1, (floor(w / 100)))
+        k_y = max(1, (floor(h / 100)))
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+        thresh = cv2.dilate(thresh, kernel, iterations=1)

     # extract contours
     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

-    # Sort contours
-    contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
+    # Filter & sort contours
+    contours = sorted(
+        [contour for contour in contours if cv2.contourArea(contour) > lower_area],
+        key=get_max_width_length_ratio,
+        reverse=True,
+    )

     angles = []
     for contour in contours[:n_ct]:
-        _, (w, h), angle = cv2.minAreaRect(contour)
+        _, (w, h), angle = cv2.minAreaRect(contour)  # type: ignore[assignment]
         if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
             angles.append(angle)
         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, subtract 90 degrees
             angles.append(angle - 90)

     if len(angles) == 0:
-        return 0  # in case no angle is found
+        estimated_angle = 0  # in case no angle is found
     else:
         median = -median_low(angles)
-        return round(median) if abs(median) != 0 else 0
+        estimated_angle = -round(median) if abs(median) != 0 else 0
+
+    # combine the general orientation with the estimated angle
+    if page_orientation and orientation_confidence >= min_confidence:
+        # special case where the estimated angle is mostly wrong:
+        # case 1: - and + swapped
+        # case 2: estimated angle is completely wrong
+        # so in this case we prefer the general page orientation
+        if abs(estimated_angle) == abs(page_orientation):
+            return page_orientation
+        estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
+        if estimated_angle > 180:
+            estimated_angle -= 360
+
+    return estimated_angle  # return the clockwise angle (negative - left side rotation, positive - right side rotation)


 def rectify_crops(
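
Note: `estimate_orientation` now fuses the contour-based estimate with an optional classifier prior passed as `general_page_orientation`, and only falls back to the word-merging dilation step when no confident prior is available. A call sketch with a synthetic image standing in for a real page:

import numpy as np
from doctr.models._utils import estimate_orientation

page = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder page image
# With a confident prior (angle 90, confidence 0.85), the bitmap is pre-rotated
# before contour analysis and the two estimates are combined at the end.
angle = estimate_orientation(page, general_page_orientation=(90, 0.85))
# Without a prior, words are merged into lines via dilation first.
angle_no_prior = estimate_orientation(page)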
doctr/models/builder.py CHANGED
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
     def __init__(
         self,
         resolve_lines: bool = True,
-        resolve_blocks: bool = True,
+        resolve_blocks: bool = False,
         paragraph_break: float = 0.035,
         export_as_straight_boxes: bool = False,
     ) -> None:
@@ -220,13 +220,22 @@ class DocumentBuilder(NestedObject):

         return blocks

-    def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]]) -> List[Block]:
+    def _build_blocks(
+        self,
+        boxes: np.ndarray,
+        objectness_scores: np.ndarray,
+        word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
+    ) -> List[Block]:
         """Gather independent words in structured blocks

         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page, of shape N
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of dictionaries containing
+                the general orientation (orientations + confidences) of the crops

         Returns:
         -------
@@ -258,9 +267,16 @@
                 Word(
                     *word_preds[idx],
                     tuple([tuple(pt) for pt in boxes[idx].tolist()]),  # type: ignore[arg-type]
+                    float(objectness_scores[idx]),
+                    crop_orientations[idx],
                 )
                 if boxes.ndim == 3
-                else Word(*word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])))
+                else Word(
+                    *word_preds[idx],
+                    ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                    float(objectness_scores[idx]),
+                    crop_orientations[idx],
+                )
                 for idx in line
             ])
             for line in lines
@@ -281,8 +297,10 @@
         self,
         pages: List[np.ndarray],
         boxes: List[np.ndarray],
+        objectness_scores: List[np.ndarray],
         text_preds: List[List[Tuple[str, float]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, Any]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> Document:
@@ -291,10 +309,13 @@
         Args:
         ----
             pages: list of N elements, where each element represents the page image
-            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
-                or (*, 6) for all words for a given page
+            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4)
+                or (*, 4, 2) for all words for a given page
+            objectness_scores: list of N elements, where each element represents the objectness scores
             text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N elements, where each element is
+                a dictionary containing the general orientation (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -304,7 +325,9 @@
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")

         _orientations = (
@@ -322,15 +345,25 @@
                 page,
                 self._build_blocks(
                     page_boxes,
+                    loc_scores,
                     word_preds,
+                    word_crop_orientations,
                 ),
                 _idx,
                 shape,
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]

@@ -353,8 +386,10 @@ class KIEDocumentBuilder(DocumentBuilder):
         self,
         pages: List[np.ndarray],
         boxes: List[Dict[str, np.ndarray]],
+        objectness_scores: List[Dict[str, np.ndarray]],
         text_preds: List[Dict[str, List[Tuple[str, float]]]],
         page_shapes: List[Tuple[int, int]],
+        crop_orientations: List[Dict[str, List[Dict[str, Any]]]],
         orientations: Optional[List[Dict[str, Any]]] = None,
         languages: Optional[List[Dict[str, Any]]] = None,
     ) -> KIEDocument:
@@ -365,8 +400,11 @@
             pages: list of N elements, where each element represents the page image
             boxes: list of N dictionaries, where each element represents the localization predictions for a class,
                 of shape (*, 5) or (*, 6) for all predictions
+            objectness_scores: list of N dictionaries, where each element represents the objectness scores for a class
            text_preds: list of N dictionaries, where each element is the list of all word prediction
             page_shapes: shape of each page, of size N
+            crop_orientations: list of N dictionaries, where each element is
+                a list containing the general crop orientations (orientations + confidences) of the crops
             orientations: optional, list of N elements,
                 where each element is a dictionary containing the orientation (orientation + confidence)
             languages: optional, list of N elements,
@@ -376,7 +414,9 @@
         -------
             document object
         """
-        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+        if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len(
+            page_shapes
+        ) != len(crop_orientations) != len(objectness_scores):
             raise ValueError("All arguments are expected to be lists of the same size")
         _orientations = (
             orientations if isinstance(orientations, list) else [None] * len(boxes)  # type: ignore[list-item]
@@ -401,7 +441,9 @@
                 {
                     k: self._build_blocks(
                         page_boxes[k],
+                        loc_scores[k],
                         word_preds[k],
+                        word_crop_orientations[k],
                     )
                     for k in page_boxes.keys()
                 },
@@ -410,8 +452,16 @@
                 orientation,
                 language,
             )
-            for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
-                pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
+            for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language in zip(  # noqa: E501
+                pages,
+                range(len(boxes)),
+                page_shapes,
+                boxes,
+                objectness_scores,
+                text_preds,
+                crop_orientations,
+                _orientations,
+                _languages,
             )
         ]

@@ -420,14 +470,18 @@
     def _build_blocks(  # type: ignore[override]
         self,
         boxes: np.ndarray,
+        objectness_scores: np.ndarray,
         word_preds: List[Tuple[str, float]],
+        crop_orientations: List[Dict[str, Any]],
     ) -> List[Prediction]:
         """Gather independent words in structured blocks

         Args:
         ----
-            boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+            boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2)
+            objectness_scores: objectness scores of all detected words of the page
             word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of orientations for each word crop

         Returns:
         -------
@@ -447,12 +501,16 @@
                 value=word_preds[idx][0],
                 confidence=word_preds[idx][1],
                 geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]),  # type: ignore[arg-type]
+                objectness_score=float(objectness_scores[idx]),
+                crop_orientation=crop_orientations[idx],
             )
             if boxes.ndim == 3
             else Prediction(
                 value=word_preds[idx][0],
                 confidence=word_preds[idx][1],
                 geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+                objectness_score=float(objectness_scores[idx]),
+                crop_orientation=crop_orientations[idx],
             )
             for idx in idxs
         ]
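
Note: the builder now threads two extra per-word signals through to each `Word`/`Prediction`: a detector objectness score and the crop-orientation dictionary. A call sketch with toy values, with shapes following the docstrings above (the dictionary keys are an assumption based on this diff's orientation predictors):

import numpy as np
from doctr.models.builder import DocumentBuilder

builder = DocumentBuilder()  # resolve_blocks now defaults to False
page = np.zeros((64, 64, 3), dtype=np.uint8)           # placeholder page
boxes = [np.array([[0.1, 0.1, 0.4, 0.2]])]             # (N, 4) relative coords
objectness = [np.array([0.99])]                        # one score per box
words = [[("hello", 0.98)]]                            # (text, confidence)
crop_orients = [[{"value": 0, "confidence": 0.95}]]    # one dict per word crop
doc = builder([page], boxes, objectness, words, [(64, 64)], crop_orients)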
doctr/models/classification/mobilenet/pytorch.py CHANGED
@@ -19,7 +19,8 @@ __all__ = [
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]

 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -51,12 +52,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 128, 128),
-        "classes": [0, 90, 180, 270],
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
+        "input_shape": (3, 256, 256),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0",
+    },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 512, 512),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0",
     },
 }

@@ -212,14 +220,42 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     )


-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import torch
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a torch.nn.Module
+    """
+    return _mobilenet_v3(
+        "mobilenet_v3_small_crop_orientation",
+        pretrained,
+        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+        **kwargs,
+    )
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.

     >>> import torch
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)

@@ -233,7 +269,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
         a torch.nn.Module
     """
     return _mobilenet_v3(
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_page_orientation",
         pretrained,
         ignore_keys=["classifier.3.weight", "classifier.3.bias"],
         **kwargs,
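
Note: the former `mobilenet_v3_small_orientation` is split into two task-specific models: crop orientation (256x256 inputs per the new config) and page orientation (512x512), both predicting classes [0, -90, 180, 90]. A quick sketch of the renamed constructors:

import torch
from doctr.models import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation

crop_clf = mobilenet_v3_small_crop_orientation(pretrained=False)
page_clf = mobilenet_v3_small_page_orientation(pretrained=False)
# Input sizes follow the default_cfgs entries above; each forward pass
# returns logits over the four orientation classes.
crop_logits = crop_clf(torch.rand((1, 3, 256, 256)))
page_logits = page_clf(torch.rand((1, 3, 512, 512)))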
doctr/models/classification/mobilenet/tensorflow.py CHANGED
@@ -21,7 +21,8 @@ __all__ = [
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
+    "mobilenet_v3_small_page_orientation",
 ]


@@ -54,13 +55,20 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
-        "classes": [0, 90, 180, 270],
+        "classes": [0, -90, 180, 90],
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
     },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (512, 512, 3),
+        "classes": [0, -90, 180, 90],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0",
+    },
 }


@@ -386,14 +394,37 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)


-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+    """MobileNetV3-Small architecture as described in
+    `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+    ----
+        pretrained: boolean, True if model is pretrained
+        **kwargs: keyword arguments of the MobileNetV3 architecture
+
+    Returns:
+    -------
+        a keras.Model
+    """
+    return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
+
+
+def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

     >>> import tensorflow as tf
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_page_orientation
+    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
     >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
     >>> out = model(input_tensor)

@@ -406,4 +437,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
     -------
         a keras.Model
     """
-    return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
+    return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs)
doctr/models/classification/predictor/pytorch.py CHANGED
@@ -12,12 +12,12 @@ from torch import nn
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.utils import set_device_and_dtype

-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]


-class CropOrientationPredictor(nn.Module):
-    """Implements an object able to detect the reading direction of a text box.
-    4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+class OrientationPredictor(nn.Module):
+    """Implements an object able to detect the reading direction of a text box or a page.
+    4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.

     Args:
     ----
@@ -37,20 +37,27 @@ class CropOrientationPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        crops: List[Union[np.ndarray, torch.Tensor]],
-    ) -> List[int]:
+        inputs: List[Union[np.ndarray, torch.Tensor]],
+    ) -> List[Union[List[int], List[float]]]:
         # Dimension check
-        if any(crop.ndim != 3 for crop in crops):
-            raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+        if any(input.ndim != 3 for input in inputs):
+            raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

-        processed_batches = self.pre_processor(crops)
+        processed_batches = self.pre_processor(inputs)
         _params = next(self.model.parameters())
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
         predicted_batches = [self.model(batch) for batch in processed_batches]
-
+        # confidence
+        probs = [
+            torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
+        ]
         # Postprocess predictions
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

-        return [int(pred) for batch in predicted_batches for pred in batch]
+        class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+        confs = [round(float(p), 2) for prob in probs for p in prob]
+
+        return [class_idxs, classes, confs]
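
Note: `OrientationPredictor` no longer returns a flat list of class indices; it returns three parallel lists: raw class indices, the mapped angles from `model.cfg["classes"]`, and rounded softmax confidences. A consumption sketch, assuming the classification zoo's `crop_orientation_predictor` helper wraps this class:

import numpy as np
from doctr.models import crop_orientation_predictor

predictor = crop_orientation_predictor(pretrained=True)
crops = [np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)]  # toy crop
class_idxs, angles, confs = predictor(crops)
# e.g. angles[0] is one of {0, -90, 180, 90}; confs[0] is a rounded probability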
doctr/models/classification/predictor/tensorflow.py CHANGED
@@ -12,12 +12,12 @@ from tensorflow import keras
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject

-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]


-class CropOrientationPredictor(NestedObject):
-    """Implements an object able to detect the reading direction of a text box.
-    4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
+class OrientationPredictor(NestedObject):
+    """Implements an object able to detect the reading direction of a text box or a page.
+    4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.

     Args:
     ----
@@ -37,16 +37,22 @@ class CropOrientationPredictor(NestedObject):

     def __call__(
         self,
-        crops: List[Union[np.ndarray, tf.Tensor]],
-    ) -> List[int]:
+        inputs: List[Union[np.ndarray, tf.Tensor]],
+    ) -> List[Union[List[int], List[float]]]:
         # Dimension check
-        if any(crop.ndim != 3 for crop in crops):
-            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+        if any(input.ndim != 3 for input in inputs):
+            raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

-        processed_batches = self.pre_processor(crops)
+        processed_batches = self.pre_processor(inputs)
         predicted_batches = [self.model(batch, training=False) for batch in processed_batches]

+        # confidence
+        probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches]
         # Postprocess predictions
         predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]

-        return [int(pred) for batch in predicted_batches for pred in batch]
+        class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
+        confs = [round(float(p), 2) for prob in probs for p in prob]
+
+        return [class_idxs, classes, confs]
doctr/models/classification/textnet/pytorch.py CHANGED
@@ -22,21 +22,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0",
     },
 }
