python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/builder.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -20,6 +20,7 @@ class DocumentBuilder(NestedObject):
|
|
|
20
20
|
"""Implements a document builder
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
+
----
|
|
23
24
|
resolve_lines: whether words should be automatically grouped into lines
|
|
24
25
|
resolve_blocks: whether lines should be automatically grouped into blocks
|
|
25
26
|
paragraph_break: relative length of the minimum space separating paragraphs
|
|
@@ -44,9 +45,11 @@ class DocumentBuilder(NestedObject):
|
|
|
44
45
|
"""Sort bounding boxes from top to bottom, left to right
|
|
45
46
|
|
|
46
47
|
Args:
|
|
48
|
+
----
|
|
47
49
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
|
|
48
50
|
|
|
49
51
|
Returns:
|
|
52
|
+
-------
|
|
50
53
|
tuple: indices of ordered boxes of shape (N,), boxes
|
|
51
54
|
If straight boxes are passed to the function, boxes are unchanged
|
|
52
55
|
else: boxes returned are straight boxes fitted to the straightened rotated boxes
|
|
@@ -66,10 +69,12 @@ class DocumentBuilder(NestedObject):
|
|
|
66
69
|
"""Split a line in sub_lines
|
|
67
70
|
|
|
68
71
|
Args:
|
|
72
|
+
----
|
|
69
73
|
boxes: bounding boxes of shape (N, 4)
|
|
70
74
|
word_idcs: list of indexes for the words of the line
|
|
71
75
|
|
|
72
76
|
Returns:
|
|
77
|
+
-------
|
|
73
78
|
A list of (sub-)lines computed from the original line (words)
|
|
74
79
|
"""
|
|
75
80
|
lines = []
|
|
@@ -104,12 +109,13 @@ class DocumentBuilder(NestedObject):
|
|
|
104
109
|
"""Order boxes to group them in lines
|
|
105
110
|
|
|
106
111
|
Args:
|
|
112
|
+
----
|
|
107
113
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
|
|
108
114
|
|
|
109
115
|
Returns:
|
|
116
|
+
-------
|
|
110
117
|
nested list of box indices
|
|
111
118
|
"""
|
|
112
|
-
|
|
113
119
|
# Sort boxes, and straighten the boxes if they are rotated
|
|
114
120
|
idxs, boxes = self._sort_boxes(boxes)
|
|
115
121
|
|
|
@@ -151,25 +157,23 @@ class DocumentBuilder(NestedObject):
|
|
|
151
157
|
"""Order lines to group them in blocks
|
|
152
158
|
|
|
153
159
|
Args:
|
|
160
|
+
----
|
|
154
161
|
boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
|
|
155
162
|
lines: list of lines, each line is a list of idx
|
|
156
163
|
|
|
157
164
|
Returns:
|
|
165
|
+
-------
|
|
158
166
|
nested list of box indices
|
|
159
167
|
"""
|
|
160
168
|
# Resolve enclosing boxes of lines
|
|
161
169
|
if boxes.ndim == 3:
|
|
162
|
-
box_lines: np.ndarray = np.asarray(
|
|
163
|
-
[
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
]
|
|
167
|
-
)
|
|
170
|
+
box_lines: np.ndarray = np.asarray([
|
|
171
|
+
resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc]
|
|
172
|
+
for line in lines
|
|
173
|
+
])
|
|
168
174
|
else:
|
|
169
175
|
_box_lines = [
|
|
170
|
-
resolve_enclosing_bbox(
|
|
171
|
-
[(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line] # type: ignore[misc]
|
|
172
|
-
)
|
|
176
|
+
resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line])
|
|
173
177
|
for line in lines
|
|
174
178
|
]
|
|
175
179
|
box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines])
|
|
@@ -220,13 +224,14 @@ class DocumentBuilder(NestedObject):
|
|
|
220
224
|
"""Gather independent words in structured blocks
|
|
221
225
|
|
|
222
226
|
Args:
|
|
227
|
+
----
|
|
223
228
|
boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
|
|
224
229
|
word_preds: list of all detected words of the page, of shape N
|
|
225
230
|
|
|
226
231
|
Returns:
|
|
232
|
+
-------
|
|
227
233
|
list of block elements
|
|
228
234
|
"""
|
|
229
|
-
|
|
230
235
|
if boxes.shape[0] != len(word_preds):
|
|
231
236
|
raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
|
|
232
237
|
|
|
@@ -248,24 +253,18 @@ class DocumentBuilder(NestedObject):
|
|
|
248
253
|
_blocks = [lines]
|
|
249
254
|
|
|
250
255
|
blocks = [
|
|
251
|
-
Block(
|
|
252
|
-
[
|
|
253
|
-
|
|
254
|
-
[
|
|
255
|
-
|
|
256
|
-
*word_preds[idx],
|
|
257
|
-
tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
|
|
258
|
-
)
|
|
259
|
-
if boxes.ndim == 3
|
|
260
|
-
else Word(
|
|
261
|
-
*word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3]))
|
|
262
|
-
)
|
|
263
|
-
for idx in line
|
|
264
|
-
]
|
|
256
|
+
Block([
|
|
257
|
+
Line([
|
|
258
|
+
Word(
|
|
259
|
+
*word_preds[idx],
|
|
260
|
+
tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
|
|
265
261
|
)
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
262
|
+
if boxes.ndim == 3
|
|
263
|
+
else Word(*word_preds[idx], ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])))
|
|
264
|
+
for idx in line
|
|
265
|
+
])
|
|
266
|
+
for line in lines
|
|
267
|
+
])
|
|
269
268
|
for lines in _blocks
|
|
270
269
|
]
|
|
271
270
|
|
|
@@ -280,6 +279,7 @@ class DocumentBuilder(NestedObject):
|
|
|
280
279
|
|
|
281
280
|
def __call__(
|
|
282
281
|
self,
|
|
282
|
+
pages: List[np.ndarray],
|
|
283
283
|
boxes: List[np.ndarray],
|
|
284
284
|
text_preds: List[List[Tuple[str, float]]],
|
|
285
285
|
page_shapes: List[Tuple[int, int]],
|
|
@@ -289,12 +289,19 @@ class DocumentBuilder(NestedObject):
|
|
|
289
289
|
"""Re-arrange detected words into structured blocks
|
|
290
290
|
|
|
291
291
|
Args:
|
|
292
|
+
----
|
|
293
|
+
pages: list of N elements, where each element represents the page image
|
|
292
294
|
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
|
|
293
295
|
or (*, 6) for all words for a given page
|
|
294
296
|
text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
|
|
295
|
-
|
|
297
|
+
page_shapes: shape of each page, of size N
|
|
298
|
+
orientations: optional, list of N elements,
|
|
299
|
+
where each element is a dictionary containing the orientation (orientation + confidence)
|
|
300
|
+
languages: optional, list of N elements,
|
|
301
|
+
where each element is a dictionary containing the language (language + confidence)
|
|
296
302
|
|
|
297
303
|
Returns:
|
|
304
|
+
-------
|
|
298
305
|
document object
|
|
299
306
|
"""
|
|
300
307
|
if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
|
|
@@ -307,15 +314,12 @@ class DocumentBuilder(NestedObject):
|
|
|
307
314
|
if self.export_as_straight_boxes and len(boxes) > 0:
|
|
308
315
|
# If boxes are already straight OK, else fit a bounding rect
|
|
309
316
|
if boxes[0].ndim == 3:
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
for p_boxes in boxes:
|
|
313
|
-
# Iterate over boxes of the pages
|
|
314
|
-
straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1))
|
|
315
|
-
boxes = straight_boxes
|
|
317
|
+
# Iterate over pages and boxes
|
|
318
|
+
boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes]
|
|
316
319
|
|
|
317
320
|
_pages = [
|
|
318
321
|
Page(
|
|
322
|
+
page,
|
|
319
323
|
self._build_blocks(
|
|
320
324
|
page_boxes,
|
|
321
325
|
word_preds,
|
|
@@ -325,8 +329,8 @@ class DocumentBuilder(NestedObject):
|
|
|
325
329
|
orientation,
|
|
326
330
|
language,
|
|
327
331
|
)
|
|
328
|
-
for _idx, shape, page_boxes, word_preds, orientation, language in zip(
|
|
329
|
-
range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
|
|
332
|
+
for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
|
|
333
|
+
pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
|
|
330
334
|
)
|
|
331
335
|
]
|
|
332
336
|
|
|
@@ -337,6 +341,7 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
337
341
|
"""Implements a KIE document builder
|
|
338
342
|
|
|
339
343
|
Args:
|
|
344
|
+
----
|
|
340
345
|
resolve_lines: whether words should be automatically grouped into lines
|
|
341
346
|
resolve_blocks: whether lines should be automatically grouped into blocks
|
|
342
347
|
paragraph_break: relative length of the minimum space separating paragraphs
|
|
@@ -346,6 +351,7 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
346
351
|
|
|
347
352
|
def __call__( # type: ignore[override]
|
|
348
353
|
self,
|
|
354
|
+
pages: List[np.ndarray],
|
|
349
355
|
boxes: List[Dict[str, np.ndarray]],
|
|
350
356
|
text_preds: List[Dict[str, List[Tuple[str, float]]]],
|
|
351
357
|
page_shapes: List[Tuple[int, int]],
|
|
@@ -355,12 +361,19 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
355
361
|
"""Re-arrange detected words into structured predictions
|
|
356
362
|
|
|
357
363
|
Args:
|
|
364
|
+
----
|
|
365
|
+
pages: list of N elements, where each element represents the page image
|
|
358
366
|
boxes: list of N dictionaries, where each element represents the localization predictions for a class,
|
|
359
|
-
|
|
367
|
+
of shape (*, 5) or (*, 6) for all predictions
|
|
360
368
|
text_preds: list of N dictionaries, where each element is the list of all word prediction
|
|
361
|
-
|
|
369
|
+
page_shapes: shape of each page, of size N
|
|
370
|
+
orientations: optional, list of N elements,
|
|
371
|
+
where each element is a dictionary containing the orientation (orientation + confidence)
|
|
372
|
+
languages: optional, list of N elements,
|
|
373
|
+
where each element is a dictionary containing the language (language + confidence)
|
|
362
374
|
|
|
363
375
|
Returns:
|
|
376
|
+
-------
|
|
364
377
|
document object
|
|
365
378
|
"""
|
|
366
379
|
if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
|
|
@@ -384,6 +397,7 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
384
397
|
|
|
385
398
|
_pages = [
|
|
386
399
|
KIEPage(
|
|
400
|
+
page,
|
|
387
401
|
{
|
|
388
402
|
k: self._build_blocks(
|
|
389
403
|
page_boxes[k],
|
|
@@ -396,8 +410,8 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
396
410
|
orientation,
|
|
397
411
|
language,
|
|
398
412
|
)
|
|
399
|
-
for _idx, shape, page_boxes, word_preds, orientation, language in zip(
|
|
400
|
-
range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
|
|
413
|
+
for page, _idx, shape, page_boxes, word_preds, orientation, language in zip(
|
|
414
|
+
pages, range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
|
|
401
415
|
)
|
|
402
416
|
]
|
|
403
417
|
|
|
@@ -411,13 +425,14 @@ class KIEDocumentBuilder(DocumentBuilder):
|
|
|
411
425
|
"""Gather independent words in structured blocks
|
|
412
426
|
|
|
413
427
|
Args:
|
|
428
|
+
----
|
|
414
429
|
boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
|
|
415
430
|
word_preds: list of all detected words of the page, of shape N
|
|
416
431
|
|
|
417
432
|
Returns:
|
|
433
|
+
-------
|
|
418
434
|
list of block elements
|
|
419
435
|
"""
|
|
420
|
-
|
|
421
436
|
if boxes.shape[0] != len(word_preds):
|
|
422
437
|
raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
|
|
423
438
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -36,6 +36,7 @@ class MAGC(nn.Module):
|
|
|
36
36
|
<https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
+
----
|
|
39
40
|
inplanes: input channels
|
|
40
41
|
headers: number of headers to split channels
|
|
41
42
|
attn_scale: if True, re-scale attention to counteract the variance distributions
|
|
@@ -153,12 +154,14 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
153
154
|
>>> out = model(input_tensor)
|
|
154
155
|
|
|
155
156
|
Args:
|
|
157
|
+
----
|
|
156
158
|
pretrained: boolean, True if model is pretrained
|
|
159
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
157
160
|
|
|
158
161
|
Returns:
|
|
162
|
+
-------
|
|
159
163
|
A feature extractor model
|
|
160
164
|
"""
|
|
161
|
-
|
|
162
165
|
return _magc_resnet(
|
|
163
166
|
"magc_resnet31",
|
|
164
167
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -36,6 +36,7 @@ class MAGC(layers.Layer):
|
|
|
36
36
|
<https://arxiv.org/pdf/1910.02562.pdf>`_.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
+
----
|
|
39
40
|
inplanes: input channels
|
|
40
41
|
headers: number of headers to split channels
|
|
41
42
|
attn_scale: if True, re-scale attention to counteract the variance distributions
|
|
@@ -169,12 +170,14 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
169
170
|
>>> out = model(input_tensor)
|
|
170
171
|
|
|
171
172
|
Args:
|
|
173
|
+
----
|
|
172
174
|
pretrained: boolean, True if model is pretrained
|
|
175
|
+
**kwargs: keyword arguments of the ResNet architecture
|
|
173
176
|
|
|
174
177
|
Returns:
|
|
178
|
+
-------
|
|
175
179
|
A feature extractor model
|
|
176
180
|
"""
|
|
177
|
-
|
|
178
181
|
return _magc_resnet(
|
|
179
182
|
"magc_resnet31",
|
|
180
183
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -113,12 +113,14 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
|
|
|
113
113
|
>>> out = model(input_tensor)
|
|
114
114
|
|
|
115
115
|
Args:
|
|
116
|
+
----
|
|
116
117
|
pretrained: boolean, True if model is pretrained
|
|
118
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
117
119
|
|
|
118
120
|
Returns:
|
|
121
|
+
-------
|
|
119
122
|
a torch.nn.Module
|
|
120
123
|
"""
|
|
121
|
-
|
|
122
124
|
return _mobilenet_v3(
|
|
123
125
|
"mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs
|
|
124
126
|
)
|
|
@@ -136,12 +138,14 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
|
|
|
136
138
|
>>> out = model(input_tensor)
|
|
137
139
|
|
|
138
140
|
Args:
|
|
141
|
+
----
|
|
139
142
|
pretrained: boolean, True if model is pretrained
|
|
143
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
140
144
|
|
|
141
145
|
Returns:
|
|
146
|
+
-------
|
|
142
147
|
a torch.nn.Module
|
|
143
148
|
"""
|
|
144
|
-
|
|
145
149
|
return _mobilenet_v3(
|
|
146
150
|
"mobilenet_v3_small_r",
|
|
147
151
|
pretrained,
|
|
@@ -163,9 +167,12 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
|
|
|
163
167
|
>>> out = model(input_tensor)
|
|
164
168
|
|
|
165
169
|
Args:
|
|
170
|
+
----
|
|
166
171
|
pretrained: boolean, True if model is pretrained
|
|
172
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
167
173
|
|
|
168
174
|
Returns:
|
|
175
|
+
-------
|
|
169
176
|
a torch.nn.Module
|
|
170
177
|
"""
|
|
171
178
|
return _mobilenet_v3(
|
|
@@ -188,9 +195,12 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
|
|
|
188
195
|
>>> out = model(input_tensor)
|
|
189
196
|
|
|
190
197
|
Args:
|
|
198
|
+
----
|
|
191
199
|
pretrained: boolean, True if model is pretrained
|
|
200
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
192
201
|
|
|
193
202
|
Returns:
|
|
203
|
+
-------
|
|
194
204
|
a torch.nn.Module
|
|
195
205
|
"""
|
|
196
206
|
return _mobilenet_v3(
|
|
@@ -214,12 +224,14 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
|
|
|
214
224
|
>>> out = model(input_tensor)
|
|
215
225
|
|
|
216
226
|
Args:
|
|
227
|
+
----
|
|
217
228
|
pretrained: boolean, True if model is pretrained
|
|
229
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
218
230
|
|
|
219
231
|
Returns:
|
|
232
|
+
-------
|
|
220
233
|
a torch.nn.Module
|
|
221
234
|
"""
|
|
222
|
-
|
|
223
235
|
return _mobilenet_v3(
|
|
224
236
|
"mobilenet_v3_small_orientation",
|
|
225
237
|
pretrained,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -82,14 +82,12 @@ class SqueezeExcitation(Sequential):
|
|
|
82
82
|
"""Squeeze and Excitation."""
|
|
83
83
|
|
|
84
84
|
def __init__(self, chan: int, squeeze_factor: int = 4) -> None:
|
|
85
|
-
super().__init__(
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
]
|
|
92
|
-
)
|
|
85
|
+
super().__init__([
|
|
86
|
+
layers.GlobalAveragePooling2D(),
|
|
87
|
+
layers.Dense(chan // squeeze_factor, activation="relu"),
|
|
88
|
+
layers.Dense(chan, activation="hard_sigmoid"),
|
|
89
|
+
layers.Reshape((1, 1, chan)),
|
|
90
|
+
])
|
|
93
91
|
|
|
94
92
|
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
95
93
|
x = super().call(inputs, **kwargs)
|
|
@@ -126,6 +124,7 @@ class InvertedResidual(layers.Layer):
|
|
|
126
124
|
"""InvertedResidual for mobilenet
|
|
127
125
|
|
|
128
126
|
Args:
|
|
127
|
+
----
|
|
129
128
|
conf: configuration object for inverted residual
|
|
130
129
|
"""
|
|
131
130
|
|
|
@@ -220,14 +219,12 @@ class MobileNetV3(Sequential):
|
|
|
220
219
|
)
|
|
221
220
|
|
|
222
221
|
if include_top:
|
|
223
|
-
_layers.extend(
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
]
|
|
230
|
-
)
|
|
222
|
+
_layers.extend([
|
|
223
|
+
layers.GlobalAveragePooling2D(),
|
|
224
|
+
layers.Dense(head_chans, activation=hard_swish),
|
|
225
|
+
layers.Dropout(0.2),
|
|
226
|
+
layers.Dense(num_classes),
|
|
227
|
+
])
|
|
231
228
|
|
|
232
229
|
super().__init__(_layers)
|
|
233
230
|
self.cfg = cfg
|
|
@@ -309,12 +306,14 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
|
|
|
309
306
|
>>> out = model(input_tensor)
|
|
310
307
|
|
|
311
308
|
Args:
|
|
309
|
+
----
|
|
312
310
|
pretrained: boolean, True if model is pretrained
|
|
311
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
313
312
|
|
|
314
313
|
Returns:
|
|
314
|
+
-------
|
|
315
315
|
a keras.Model
|
|
316
316
|
"""
|
|
317
|
-
|
|
318
317
|
return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
|
|
319
318
|
|
|
320
319
|
|
|
@@ -330,12 +329,14 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
|
|
|
330
329
|
>>> out = model(input_tensor)
|
|
331
330
|
|
|
332
331
|
Args:
|
|
332
|
+
----
|
|
333
333
|
pretrained: boolean, True if model is pretrained
|
|
334
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
334
335
|
|
|
335
336
|
Returns:
|
|
337
|
+
-------
|
|
336
338
|
a keras.Model
|
|
337
339
|
"""
|
|
338
|
-
|
|
339
340
|
return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
|
|
340
341
|
|
|
341
342
|
|
|
@@ -351,9 +352,12 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
|
|
|
351
352
|
>>> out = model(input_tensor)
|
|
352
353
|
|
|
353
354
|
Args:
|
|
355
|
+
----
|
|
354
356
|
pretrained: boolean, True if model is pretrained
|
|
357
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
355
358
|
|
|
356
359
|
Returns:
|
|
360
|
+
-------
|
|
357
361
|
a keras.Model
|
|
358
362
|
"""
|
|
359
363
|
return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
|
|
@@ -371,9 +375,12 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
|
|
|
371
375
|
>>> out = model(input_tensor)
|
|
372
376
|
|
|
373
377
|
Args:
|
|
378
|
+
----
|
|
374
379
|
pretrained: boolean, True if model is pretrained
|
|
380
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
375
381
|
|
|
376
382
|
Returns:
|
|
383
|
+
-------
|
|
377
384
|
a keras.Model
|
|
378
385
|
"""
|
|
379
386
|
return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
|
|
@@ -391,10 +398,12 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
|
|
|
391
398
|
>>> out = model(input_tensor)
|
|
392
399
|
|
|
393
400
|
Args:
|
|
401
|
+
----
|
|
394
402
|
pretrained: boolean, True if model is pretrained
|
|
403
|
+
**kwargs: keyword arguments of the MobileNetV3 architecture
|
|
395
404
|
|
|
396
405
|
Returns:
|
|
406
|
+
-------
|
|
397
407
|
a keras.Model
|
|
398
408
|
"""
|
|
399
|
-
|
|
400
409
|
return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -20,6 +20,7 @@ class CropOrientationPredictor(nn.Module):
|
|
|
20
20
|
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
+
----
|
|
23
24
|
pre_processor: transform inputs for easier batched model inference
|
|
24
25
|
model: core classification architecture (backbone + classification head)
|
|
25
26
|
"""
|
|
@@ -33,7 +34,7 @@ class CropOrientationPredictor(nn.Module):
|
|
|
33
34
|
self.pre_processor = pre_processor
|
|
34
35
|
self.model = model.eval()
|
|
35
36
|
|
|
36
|
-
@torch.no_grad()
|
|
37
|
+
@torch.inference_mode()
|
|
37
38
|
def forward(
|
|
38
39
|
self,
|
|
39
40
|
crops: List[Union[np.ndarray, torch.Tensor]],
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-2023, Mindee.
|
|
1
|
+
# Copyright (C) 2021-2024, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -20,6 +20,7 @@ class CropOrientationPredictor(NestedObject):
|
|
|
20
20
|
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
+
----
|
|
23
24
|
pre_processor: transform inputs for easier batched model inference
|
|
24
25
|
model: core classification architecture (backbone + classification head)
|
|
25
26
|
"""
|