python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137) hide show
  1. doctr/datasets/__init__.py +2 -0
  2. doctr/datasets/cord.py +6 -4
  3. doctr/datasets/datasets/base.py +3 -2
  4. doctr/datasets/datasets/pytorch.py +4 -2
  5. doctr/datasets/datasets/tensorflow.py +4 -2
  6. doctr/datasets/detection.py +6 -3
  7. doctr/datasets/doc_artefacts.py +2 -1
  8. doctr/datasets/funsd.py +7 -8
  9. doctr/datasets/generator/base.py +3 -2
  10. doctr/datasets/generator/pytorch.py +3 -1
  11. doctr/datasets/generator/tensorflow.py +3 -1
  12. doctr/datasets/ic03.py +3 -2
  13. doctr/datasets/ic13.py +2 -1
  14. doctr/datasets/iiit5k.py +6 -4
  15. doctr/datasets/iiithws.py +2 -1
  16. doctr/datasets/imgur5k.py +3 -2
  17. doctr/datasets/loader.py +4 -2
  18. doctr/datasets/mjsynth.py +2 -1
  19. doctr/datasets/ocr.py +2 -1
  20. doctr/datasets/orientation.py +40 -0
  21. doctr/datasets/recognition.py +3 -2
  22. doctr/datasets/sroie.py +2 -1
  23. doctr/datasets/svhn.py +2 -1
  24. doctr/datasets/svt.py +3 -2
  25. doctr/datasets/synthtext.py +2 -1
  26. doctr/datasets/utils.py +27 -11
  27. doctr/datasets/vocabs.py +26 -1
  28. doctr/datasets/wildreceipt.py +111 -0
  29. doctr/file_utils.py +3 -1
  30. doctr/io/elements.py +52 -35
  31. doctr/io/html.py +5 -3
  32. doctr/io/image/base.py +5 -4
  33. doctr/io/image/pytorch.py +12 -7
  34. doctr/io/image/tensorflow.py +11 -6
  35. doctr/io/pdf.py +5 -4
  36. doctr/io/reader.py +13 -5
  37. doctr/models/_utils.py +30 -53
  38. doctr/models/artefacts/barcode.py +4 -3
  39. doctr/models/artefacts/face.py +4 -2
  40. doctr/models/builder.py +58 -43
  41. doctr/models/classification/__init__.py +1 -0
  42. doctr/models/classification/magc_resnet/pytorch.py +5 -2
  43. doctr/models/classification/magc_resnet/tensorflow.py +5 -2
  44. doctr/models/classification/mobilenet/pytorch.py +16 -4
  45. doctr/models/classification/mobilenet/tensorflow.py +29 -20
  46. doctr/models/classification/predictor/pytorch.py +3 -2
  47. doctr/models/classification/predictor/tensorflow.py +2 -1
  48. doctr/models/classification/resnet/pytorch.py +23 -13
  49. doctr/models/classification/resnet/tensorflow.py +33 -26
  50. doctr/models/classification/textnet/__init__.py +6 -0
  51. doctr/models/classification/textnet/pytorch.py +275 -0
  52. doctr/models/classification/textnet/tensorflow.py +267 -0
  53. doctr/models/classification/vgg/pytorch.py +4 -2
  54. doctr/models/classification/vgg/tensorflow.py +5 -2
  55. doctr/models/classification/vit/pytorch.py +9 -3
  56. doctr/models/classification/vit/tensorflow.py +9 -3
  57. doctr/models/classification/zoo.py +7 -2
  58. doctr/models/core.py +1 -1
  59. doctr/models/detection/__init__.py +1 -0
  60. doctr/models/detection/_utils/pytorch.py +7 -1
  61. doctr/models/detection/_utils/tensorflow.py +7 -3
  62. doctr/models/detection/core.py +9 -3
  63. doctr/models/detection/differentiable_binarization/base.py +37 -25
  64. doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
  65. doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
  66. doctr/models/detection/fast/__init__.py +6 -0
  67. doctr/models/detection/fast/base.py +256 -0
  68. doctr/models/detection/fast/pytorch.py +442 -0
  69. doctr/models/detection/fast/tensorflow.py +428 -0
  70. doctr/models/detection/linknet/base.py +12 -5
  71. doctr/models/detection/linknet/pytorch.py +28 -15
  72. doctr/models/detection/linknet/tensorflow.py +68 -88
  73. doctr/models/detection/predictor/pytorch.py +16 -6
  74. doctr/models/detection/predictor/tensorflow.py +13 -5
  75. doctr/models/detection/zoo.py +19 -16
  76. doctr/models/factory/hub.py +20 -10
  77. doctr/models/kie_predictor/base.py +2 -1
  78. doctr/models/kie_predictor/pytorch.py +28 -36
  79. doctr/models/kie_predictor/tensorflow.py +27 -27
  80. doctr/models/modules/__init__.py +1 -0
  81. doctr/models/modules/layers/__init__.py +6 -0
  82. doctr/models/modules/layers/pytorch.py +166 -0
  83. doctr/models/modules/layers/tensorflow.py +175 -0
  84. doctr/models/modules/transformer/pytorch.py +24 -22
  85. doctr/models/modules/transformer/tensorflow.py +6 -4
  86. doctr/models/modules/vision_transformer/pytorch.py +2 -4
  87. doctr/models/modules/vision_transformer/tensorflow.py +2 -4
  88. doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
  89. doctr/models/predictor/base.py +14 -3
  90. doctr/models/predictor/pytorch.py +26 -29
  91. doctr/models/predictor/tensorflow.py +25 -22
  92. doctr/models/preprocessor/pytorch.py +14 -9
  93. doctr/models/preprocessor/tensorflow.py +10 -5
  94. doctr/models/recognition/core.py +4 -1
  95. doctr/models/recognition/crnn/pytorch.py +23 -16
  96. doctr/models/recognition/crnn/tensorflow.py +25 -17
  97. doctr/models/recognition/master/base.py +4 -1
  98. doctr/models/recognition/master/pytorch.py +20 -9
  99. doctr/models/recognition/master/tensorflow.py +20 -8
  100. doctr/models/recognition/parseq/base.py +4 -1
  101. doctr/models/recognition/parseq/pytorch.py +28 -22
  102. doctr/models/recognition/parseq/tensorflow.py +22 -11
  103. doctr/models/recognition/predictor/_utils.py +3 -2
  104. doctr/models/recognition/predictor/pytorch.py +3 -2
  105. doctr/models/recognition/predictor/tensorflow.py +2 -1
  106. doctr/models/recognition/sar/pytorch.py +14 -7
  107. doctr/models/recognition/sar/tensorflow.py +23 -14
  108. doctr/models/recognition/utils.py +5 -1
  109. doctr/models/recognition/vitstr/base.py +4 -1
  110. doctr/models/recognition/vitstr/pytorch.py +22 -13
  111. doctr/models/recognition/vitstr/tensorflow.py +21 -10
  112. doctr/models/recognition/zoo.py +4 -2
  113. doctr/models/utils/pytorch.py +24 -6
  114. doctr/models/utils/tensorflow.py +22 -3
  115. doctr/models/zoo.py +21 -3
  116. doctr/transforms/functional/base.py +8 -3
  117. doctr/transforms/functional/pytorch.py +23 -6
  118. doctr/transforms/functional/tensorflow.py +25 -5
  119. doctr/transforms/modules/base.py +12 -5
  120. doctr/transforms/modules/pytorch.py +10 -12
  121. doctr/transforms/modules/tensorflow.py +17 -9
  122. doctr/utils/common_types.py +1 -1
  123. doctr/utils/data.py +4 -2
  124. doctr/utils/fonts.py +3 -2
  125. doctr/utils/geometry.py +95 -26
  126. doctr/utils/metrics.py +36 -22
  127. doctr/utils/multithreading.py +5 -3
  128. doctr/utils/repr.py +3 -1
  129. doctr/utils/visualization.py +31 -8
  130. doctr/version.py +1 -1
  131. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
  132. python_doctr-0.8.1.dist-info/RECORD +173 -0
  133. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
  134. python_doctr-0.7.0.dist-info/RECORD +0 -161
  135. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
  136. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
  137. {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/builder.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,6 +20,7 @@ class DocumentBuilder(NestedObject):
20
20
  """Implements a document builder
21
21
 
22
22
  Args:
23
+ ----
23
24
  resolve_lines: whether words should be automatically grouped into lines
24
25
  resolve_blocks: whether lines should be automatically grouped into blocks
25
26
  paragraph_break: relative length of the minimum space separating paragraphs
@@ -44,9 +45,11 @@ class DocumentBuilder(NestedObject):
44
45
  """Sort bounding boxes from top to bottom, left to right
45
46
 
46
47
  Args:
48
+ ----
47
49
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
48
50
 
49
51
  Returns:
52
+ -------
50
53
  tuple: indices of ordered boxes of shape (N,), boxes
51
54
  If straight boxes are passed to the function, boxes are unchanged
52
55
  else: boxes returned are straight boxes fitted to the straightened rotated boxes
@@ -66,10 +69,12 @@ class DocumentBuilder(NestedObject):
66
69
  """Split a line in sub_lines
67
70
 
68
71
  Args:
72
+ ----
69
73
  boxes: bounding boxes of shape (N, 4)
70
74
  word_idcs: list of indexes for the words of the line
71
75
 
72
76
  Returns:
77
+ -------
73
78
  A list of (sub-)lines computed from the original line (words)
74
79
  """
75
80
  lines = []
@@ -104,12 +109,13 @@ class DocumentBuilder(NestedObject):
104
109
  """Order boxes to group them in lines
105
110
 
106
111
  Args:
112
+ ----
107
113
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
108
114
 
109
115
  Returns:
116
+ -------
110
117
  nested list of box indices
111
118
  """
112
-
113
119
  # Sort boxes, and straighten the boxes if they are rotated
114
120
  idxs, boxes = self._sort_boxes(boxes)
115
121
 
@@ -151,25 +157,23 @@ class DocumentBuilder(NestedObject):
151
157
  """Order lines to group them in blocks
152
158
 
153
159
  Args:
160
+ ----
154
161
  boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
155
162
  lines: list of lines, each line is a list of idx
156
163
 
157
164
  Returns:
165
+ -------
158
166
  nested list of box indices
159
167
  """
160
168
  # Resolve enclosing boxes of lines
161
169
  if boxes.ndim == 3:
162
- box_lines: np.ndarray = np.asarray(
163
- [
164
- resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc]
165
- for line in lines
166
- ]
167
- )
170
+ box_lines: np.ndarray = np.asarray([
171
+ resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc]
172
+ for line in lines
173
+ ])
168
174
  else:
169
175
  _box_lines = [
170
- resolve_enclosing_bbox(
171
- [(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line] # type: ignore[misc]
172
- )
176
+ resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line])
173
177
  for line in lines
174
178
  ]
175
179
  box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines])
@@ -220,13 +224,14 @@ class DocumentBuilder(NestedObject):
220
224
  """Gather independent words in structured blocks
221
225
 
222
226
  Args:
227
+ ----
223
228
  boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
224
229
  word_preds: list of all detected words of the page, of shape N
225
230
 
226
231
  Returns:
232
+ -------
227
233
  list of block elements
228
234
  """
229
-
230
235
  if boxes.shape[0] != len(word_preds):
231
236
  raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
232
237
 
@@ -248,24 +253,18 @@ class DocumentBuilder(NestedObject):
248
253
  _blocks = [lines]
249
254
 
250
255
  blocks = [
251
- Block(
252
- [
253
- Line(
254
- [
255
- Word(
256
- *word_preds[idx],
257
- tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
258
- )
259
- if boxes.ndim == 3
260
- else Word(
261
- *word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3]))
262
- )
263
- for idx in line
264
- ]
256
+ Block([
257
+ Line([
258
+ Word(
259
+ *word_preds[idx],
260
+ tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
265
261
  )
266
- for line in lines
267
- ]
268
- )
262
+ if boxes.ndim == 3
263
+ else Word(*word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])))
264
+ for idx in line
265
+ ])
266
+ for line in lines
267
+ ])
269
268
  for lines in _blocks
270
269
  ]
271
270
 
@@ -280,6 +279,7 @@ class DocumentBuilder(NestedObject):
280
279
 
281
280
  def __call__(
282
281
  self,
282
+ pages: List[np.ndarray],
283
283
  boxes: List[np.ndarray],
284
284
  text_preds: List[List[Tuple[str, float]]],
285
285
  page_shapes: List[Tuple[int, int]],
@@ -289,12 +289,19 @@ class DocumentBuilder(NestedObject):
289
289
  """Re-arrange detected words into structured blocks
290
290
 
291
291
  Args:
292
+ ----
293
+ pages: list of N elements, where each element represents the page image
292
294
  boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
293
295
  or (*, 6) for all words for a given page
294
296
  text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
295
- page_shape: shape of each page, of size N
297
+ page_shapes: shape of each page, of size N
298
+ orientations: optional, list of N elements,
299
+ where each element is a dictionary containing the orientation (orientation + confidence)
300
+ languages: optional, list of N elements,
301
+ where each element is a dictionary containing the language (language + confidence)
296
302
 
297
303
  Returns:
304
+ -------
298
305
  document object
299
306
  """
300
307
  if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
@@ -307,15 +314,12 @@ class DocumentBuilder(NestedObject):
307
314
  if self.export_as_straight_boxes and len(boxes) > 0:
308
315
  # If boxes are already straight OK, else fit a bounding rect
309
316
  if boxes[0].ndim == 3:
310
- straight_boxes: List[np.ndarray] = []
311
- # Iterate over pages
312
- for p_boxes in boxes:
313
- # Iterate over boxes of the pages
314
- straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1))
315
- boxes = straight_boxes
317
+ # Iterate over pages and boxes
318
+ boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes]
316
319
 
317
320
  _pages = [
318
321
  Page(
322
+ page,
319
323
  self._build_blocks(
320
324
  page_boxes,
321
325
  word_preds,
@@ -325,8 +329,8 @@ class DocumentBuilder(NestedObject):
325
329
  orientation,
326
330
  language,
327
331
  )
328
- for _idx, shape, page_boxes, word_preds, orientation, language in zip(
329
- range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
332
+ for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
333
+ pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
330
334
  )
331
335
  ]
332
336
 
@@ -337,6 +341,7 @@ class KIEDocumentBuilder(DocumentBuilder):
337
341
  """Implements a KIE document builder
338
342
 
339
343
  Args:
344
+ ----
340
345
  resolve_lines: whether words should be automatically grouped into lines
341
346
  resolve_blocks: whether lines should be automatically grouped into blocks
342
347
  paragraph_break: relative length of the minimum space separating paragraphs
@@ -346,6 +351,7 @@ class KIEDocumentBuilder(DocumentBuilder):
346
351
 
347
352
  def __call__( # type: ignore[override]
348
353
  self,
354
+ pages: List[np.ndarray],
349
355
  boxes: List[Dict[str, np.ndarray]],
350
356
  text_preds: List[Dict[str, List[Tuple[str, float]]]],
351
357
  page_shapes: List[Tuple[int, int]],
@@ -355,12 +361,19 @@ class KIEDocumentBuilder(DocumentBuilder):
355
361
  """Re-arrange detected words into structured predictions
356
362
 
357
363
  Args:
364
+ ----
365
+ pages: list of N elements, where each element represents the page image
358
366
  boxes: list of N dictionaries, where each element represents the localization predictions for a class,
359
- of shape (*, 5) or (*, 6) for all predictions
367
+ of shape (*, 5) or (*, 6) for all predictions
360
368
  text_preds: list of N dictionaries, where each element is the list of all word prediction
361
- page_shape: shape of each page, of size N
369
+ page_shapes: shape of each page, of size N
370
+ orientations: optional, list of N elements,
371
+ where each element is a dictionary containing the orientation (orientation + confidence)
372
+ languages: optional, list of N elements,
373
+ where each element is a dictionary containing the language (language + confidence)
362
374
 
363
375
  Returns:
376
+ -------
364
377
  document object
365
378
  """
366
379
  if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
@@ -384,6 +397,7 @@ class KIEDocumentBuilder(DocumentBuilder):
384
397
 
385
398
  _pages = [
386
399
  KIEPage(
400
+ page,
387
401
  {
388
402
  k: self._build_blocks(
389
403
  page_boxes[k],
@@ -396,8 +410,8 @@ class KIEDocumentBuilder(DocumentBuilder):
396
410
  orientation,
397
411
  language,
398
412
  )
399
- for _idx, shape, page_boxes, word_preds, orientation, language in zip(
400
- range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
413
+ for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
414
+ pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
401
415
  )
402
416
  ]
403
417
 
@@ -411,13 +425,14 @@ class KIEDocumentBuilder(DocumentBuilder):
411
425
  """Gather independent words in structured blocks
412
426
 
413
427
  Args:
428
+ ----
414
429
  boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
415
430
  word_preds: list of all detected words of the page, of shape N
416
431
 
417
432
  Returns:
433
+ -------
418
434
  list of block elements
419
435
  """
420
-
421
436
  if boxes.shape[0] != len(word_preds):
422
437
  raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
423
438
 
@@ -3,4 +3,5 @@ from .resnet import *
3
3
  from .vgg import *
4
4
  from .magc_resnet import *
5
5
  from .vit import *
6
+ from .textnet import *
6
7
  from .zoo import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -36,6 +36,7 @@ class MAGC(nn.Module):
36
36
  <https://arxiv.org/pdf/1910.02562.pdf>`_.
37
37
 
38
38
  Args:
39
+ ----
39
40
  inplanes: input channels
40
41
  headers: number of headers to split channels
41
42
  attn_scale: if True, re-scale attention to counteract the variance distributions
@@ -153,12 +154,14 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
153
154
  >>> out = model(input_tensor)
154
155
 
155
156
  Args:
157
+ ----
156
158
  pretrained: boolean, True if model is pretrained
159
+ **kwargs: keyword arguments of the ResNet architecture
157
160
 
158
161
  Returns:
162
+ -------
159
163
  A feature extractor model
160
164
  """
161
-
162
165
  return _magc_resnet(
163
166
  "magc_resnet31",
164
167
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -36,6 +36,7 @@ class MAGC(layers.Layer):
36
36
  <https://arxiv.org/pdf/1910.02562.pdf>`_.
37
37
 
38
38
  Args:
39
+ ----
39
40
  inplanes: input channels
40
41
  headers: number of headers to split channels
41
42
  attn_scale: if True, re-scale attention to counteract the variance distributions
@@ -169,12 +170,14 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
169
170
  >>> out = model(input_tensor)
170
171
 
171
172
  Args:
173
+ ----
172
174
  pretrained: boolean, True if model is pretrained
175
+ **kwargs: keyword arguments of the ResNet architecture
173
176
 
174
177
  Returns:
178
+ -------
175
179
  A feature extractor model
176
180
  """
177
-
178
181
  return _magc_resnet(
179
182
  "magc_resnet31",
180
183
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -113,12 +113,14 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
113
113
  >>> out = model(input_tensor)
114
114
 
115
115
  Args:
116
+ ----
116
117
  pretrained: boolean, True if model is pretrained
118
+ **kwargs: keyword arguments of the MobileNetV3 architecture
117
119
 
118
120
  Returns:
121
+ -------
119
122
  a torch.nn.Module
120
123
  """
121
-
122
124
  return _mobilenet_v3(
123
125
  "mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs
124
126
  )
@@ -136,12 +138,14 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
136
138
  >>> out = model(input_tensor)
137
139
 
138
140
  Args:
141
+ ----
139
142
  pretrained: boolean, True if model is pretrained
143
+ **kwargs: keyword arguments of the MobileNetV3 architecture
140
144
 
141
145
  Returns:
146
+ -------
142
147
  a torch.nn.Module
143
148
  """
144
-
145
149
  return _mobilenet_v3(
146
150
  "mobilenet_v3_small_r",
147
151
  pretrained,
@@ -163,9 +167,12 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
163
167
  >>> out = model(input_tensor)
164
168
 
165
169
  Args:
170
+ ----
166
171
  pretrained: boolean, True if model is pretrained
172
+ **kwargs: keyword arguments of the MobileNetV3 architecture
167
173
 
168
174
  Returns:
175
+ -------
169
176
  a torch.nn.Module
170
177
  """
171
178
  return _mobilenet_v3(
@@ -188,9 +195,12 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
188
195
  >>> out = model(input_tensor)
189
196
 
190
197
  Args:
198
+ ----
191
199
  pretrained: boolean, True if model is pretrained
200
+ **kwargs: keyword arguments of the MobileNetV3 architecture
192
201
 
193
202
  Returns:
203
+ -------
194
204
  a torch.nn.Module
195
205
  """
196
206
  return _mobilenet_v3(
@@ -214,12 +224,14 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
214
224
  >>> out = model(input_tensor)
215
225
 
216
226
  Args:
227
+ ----
217
228
  pretrained: boolean, True if model is pretrained
229
+ **kwargs: keyword arguments of the MobileNetV3 architecture
218
230
 
219
231
  Returns:
232
+ -------
220
233
  a torch.nn.Module
221
234
  """
222
-
223
235
  return _mobilenet_v3(
224
236
  "mobilenet_v3_small_orientation",
225
237
  pretrained,
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -82,14 +82,12 @@ class SqueezeExcitation(Sequential):
82
82
  """Squeeze and Excitation."""
83
83
 
84
84
  def __init__(self, chan: int, squeeze_factor: int = 4) -> None:
85
- super().__init__(
86
- [
87
- layers.GlobalAveragePooling2D(),
88
- layers.Dense(chan // squeeze_factor, activation="relu"),
89
- layers.Dense(chan, activation="hard_sigmoid"),
90
- layers.Reshape((1, 1, chan)),
91
- ]
92
- )
85
+ super().__init__([
86
+ layers.GlobalAveragePooling2D(),
87
+ layers.Dense(chan // squeeze_factor, activation="relu"),
88
+ layers.Dense(chan, activation="hard_sigmoid"),
89
+ layers.Reshape((1, 1, chan)),
90
+ ])
93
91
 
94
92
  def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
95
93
  x = super().call(inputs, **kwargs)
@@ -126,6 +124,7 @@ class InvertedResidual(layers.Layer):
126
124
  """InvertedResidual for mobilenet
127
125
 
128
126
  Args:
127
+ ----
129
128
  conf: configuration object for inverted residual
130
129
  """
131
130
 
@@ -220,14 +219,12 @@ class MobileNetV3(Sequential):
220
219
  )
221
220
 
222
221
  if include_top:
223
- _layers.extend(
224
- [
225
- layers.GlobalAveragePooling2D(),
226
- layers.Dense(head_chans, activation=hard_swish),
227
- layers.Dropout(0.2),
228
- layers.Dense(num_classes),
229
- ]
230
- )
222
+ _layers.extend([
223
+ layers.GlobalAveragePooling2D(),
224
+ layers.Dense(head_chans, activation=hard_swish),
225
+ layers.Dropout(0.2),
226
+ layers.Dense(num_classes),
227
+ ])
231
228
 
232
229
  super().__init__(_layers)
233
230
  self.cfg = cfg
@@ -309,12 +306,14 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
309
306
  >>> out = model(input_tensor)
310
307
 
311
308
  Args:
309
+ ----
312
310
  pretrained: boolean, True if model is pretrained
311
+ **kwargs: keyword arguments of the MobileNetV3 architecture
313
312
 
314
313
  Returns:
314
+ -------
315
315
  a keras.Model
316
316
  """
317
-
318
317
  return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
319
318
 
320
319
 
@@ -330,12 +329,14 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
330
329
  >>> out = model(input_tensor)
331
330
 
332
331
  Args:
332
+ ----
333
333
  pretrained: boolean, True if model is pretrained
334
+ **kwargs: keyword arguments of the MobileNetV3 architecture
334
335
 
335
336
  Returns:
337
+ -------
336
338
  a keras.Model
337
339
  """
338
-
339
340
  return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
340
341
 
341
342
 
@@ -351,9 +352,12 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
351
352
  >>> out = model(input_tensor)
352
353
 
353
354
  Args:
355
+ ----
354
356
  pretrained: boolean, True if model is pretrained
357
+ **kwargs: keyword arguments of the MobileNetV3 architecture
355
358
 
356
359
  Returns:
360
+ -------
357
361
  a keras.Model
358
362
  """
359
363
  return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
@@ -371,9 +375,12 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
371
375
  >>> out = model(input_tensor)
372
376
 
373
377
  Args:
378
+ ----
374
379
  pretrained: boolean, True if model is pretrained
380
+ **kwargs: keyword arguments of the MobileNetV3 architecture
375
381
 
376
382
  Returns:
383
+ -------
377
384
  a keras.Model
378
385
  """
379
386
  return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
@@ -391,10 +398,12 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
391
398
  >>> out = model(input_tensor)
392
399
 
393
400
  Args:
401
+ ----
394
402
  pretrained: boolean, True if model is pretrained
403
+ **kwargs: keyword arguments of the MobileNetV3 architecture
395
404
 
396
405
  Returns:
406
+ -------
397
407
  a keras.Model
398
408
  """
399
-
400
409
  return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,6 +20,7 @@ class CropOrientationPredictor(nn.Module):
20
20
  4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
21
21
 
22
22
  Args:
23
+ ----
23
24
  pre_processor: transform inputs for easier batched model inference
24
25
  model: core classification architecture (backbone + classification head)
25
26
  """
@@ -33,7 +34,7 @@ class CropOrientationPredictor(nn.Module):
33
34
  self.pre_processor = pre_processor
34
35
  self.model = model.eval()
35
36
 
36
- @torch.no_grad()
37
+ @torch.inference_mode()
37
38
  def forward(
38
39
  self,
39
40
  crops: List[Union[np.ndarray, torch.Tensor]],
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2023, Mindee.
1
+ # Copyright (C) 2021-2024, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -20,6 +20,7 @@ class CropOrientationPredictor(NestedObject):
20
20
  4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
21
21
 
22
22
  Args:
23
+ ----
23
24
  pre_processor: transform inputs for easier batched model inference
24
25
  model: core classification architecture (backbone + classification head)
25
26
  """