python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/vocabs.py +0 -2
- doctr/file_utils.py +2 -101
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +3 -3
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +2 -2
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +1 -1
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +2 -2
- doctr/models/classification/zoo.py +6 -11
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +4 -12
- doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +4 -14
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +3 -12
- doctr/models/detection/linknet/pytorch.py +2 -2
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +8 -21
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +2 -6
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +3 -3
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +2 -5
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +6 -6
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +5 -5
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +5 -5
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +7 -16
- doctr/models/recognition/predictor/pytorch.py +1 -2
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +3 -3
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +3 -3
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +3 -3
- doctr/models/recognition/zoo.py +13 -13
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/geometry.py +6 -10
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/__init__.py
CHANGED
doctr/datasets/__init__.py
CHANGED
@@ -1,5 +1,3 @@
-from doctr.file_utils import is_tf_available
-
 from .generator import *
 from .coco_text import *
 from .cord import *
@@ -22,6 +20,3 @@ from .synthtext import *
 from .utils import *
 from .vocabs import *
 from .wildreceipt import *
-
-if is_tf_available():
-    from .loader import *
doctr/datasets/datasets/pytorch.py
CHANGED
@@ -50,9 +50,9 @@ class AbstractDataset(_AbstractDataset):
     @staticmethod
     def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
         images, targets = zip(*samples)
-        images = torch.stack(images, dim=0)
+        images = torch.stack(images, dim=0)  # type: ignore[assignment]
 
-        return images, list(targets)
+        return images, list(targets)  # type: ignore[return-value]
 
 
 class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
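Note: the retyped collate_fn above stacks images into a single (N, C, H, W) tensor while keeping the variable-structure targets as a plain list. A minimal usage sketch (the dataset variable and batch size are assumptions, not part of the diff):

    from torch.utils.data import DataLoader

    # collate_fn pairs each stacked image batch with a list of raw targets,
    # which detection/recognition targets of varying length require.
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
    images, targets = next(iter(loader))  # images: Tensor, targets: list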
doctr/datasets/vocabs.py
CHANGED
@@ -264,8 +264,6 @@ VOCABS["estonian"] = VOCABS["english"] + "šžõäöüŠŽÕÄÖÜ"
 VOCABS["esperanto"] = re.sub(r"[QqWwXxYy]", "", VOCABS["english"]) + "ĉĝĥĵŝŭĈĜĤĴŜŬ" + "₷"
 
 VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ"
-# NOTE: legacy french is outdated, but kept for compatibility
-VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + _BASE_VOCABS["currency"]
 
 VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
 
doctr/file_utils.py
CHANGED
@@ -3,102 +3,13 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
-
 import importlib.metadata
-import importlib.util
 import logging
-import os
-
-CLASS_NAME: str = "words"
-
 
-__all__ = ["
+__all__ = ["requires_package", "CLASS_NAME"]
 
+CLASS_NAME: str = "words"
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
-
-USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
-
-if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
-    _torch_available = importlib.util.find_spec("torch") is not None
-    if _torch_available:
-        try:
-            _torch_version = importlib.metadata.version("torch")
-            logging.info(f"PyTorch version {_torch_version} available.")
-        except importlib.metadata.PackageNotFoundError:  # pragma: no cover
-            _torch_available = False
-else:  # pragma: no cover
-    logging.info("Disabling PyTorch because USE_TF is set")
-    _torch_available = False
-
-# Compatibility fix to make sure tensorflow.keras stays at Keras 2
-if "TF_USE_LEGACY_KERAS" not in os.environ:
-    os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
-    raise ValueError(
-        "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
-    )
-
-
-def ensure_keras_v2() -> None:  # pragma: no cover
-    if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-
-if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
-    _tf_available = importlib.util.find_spec("tensorflow") is not None
-    if _tf_available:
-        candidates = (
-            "tensorflow",
-            "tensorflow-cpu",
-            "tensorflow-gpu",
-            "tf-nightly",
-            "tf-nightly-cpu",
-            "tf-nightly-gpu",
-            "intel-tensorflow",
-            "tensorflow-rocm",
-            "tensorflow-macos",
-        )
-        _tf_version = None
-        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
-        for pkg in candidates:
-            try:
-                _tf_version = importlib.metadata.version(pkg)
-                break
-            except importlib.metadata.PackageNotFoundError:
-                pass
-        _tf_available = _tf_version is not None
-    if _tf_available:
-        if int(_tf_version.split(".")[0]) < 2:  # type: ignore[union-attr]  # pragma: no cover
-            logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
-            _tf_available = False
-        else:
-            logging.info(f"TensorFlow version {_tf_version} available.")
-            ensure_keras_v2()
-
-            import warnings
-
-            warnings.simplefilter("always", DeprecationWarning)
-            warnings.warn(
-                "Support for TensorFlow in DocTR is deprecated and will be removed in the next major release (v1.0.0). "
-                "Please switch to the PyTorch backend.",
-                DeprecationWarning,
-            )
-
-else:  # pragma: no cover
-    logging.info("Disabling Tensorflow because USE_TORCH is set")
-    _tf_available = False
-
-
-if not _torch_available and not _tf_available:  # pragma: no cover
-    raise ModuleNotFoundError(
-        "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
-        " is installed and that either USE_TF or USE_TORCH is enabled."
-    )
 
 
 def requires_package(name: str, extra_message: str | None = None) -> None:  # pragma: no cover
@@ -117,13 +28,3 @@ def requires_package(name: str, extra_message: str | None = None) -> None:  # pr
         f"\n\n{extra_message if extra_message is not None else ''} "
         f"\nPlease install it with the following command: pip install {name}\n"
     )
-
-
-def is_torch_available():
-    """Whether PyTorch is installed."""
-    return _torch_available
-
-
-def is_tf_available():
-    """Whether TensorFlow is installed."""
-    return _tf_available
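Note: with the backend probe gone, is_tf_available() and is_torch_available() no longer exist, and the USE_TF/USE_TORCH environment variables are ignored. A hedged migration sketch for downstream code that feature-gated on those helpers (the try/except shim is a suggestion, not an official docTR recipe):

    # 0.12.0 exposed backend helpers; 1.0.0 is PyTorch-only.
    try:
        from doctr.file_utils import is_torch_available  # docTR <= 0.12.0
    except ImportError:
        def is_torch_available() -> bool:
            return True  # docTR >= 1.0.0: PyTorch is the only backend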
doctr/io/image/__init__.py
CHANGED
doctr/io/image/pytorch.py
CHANGED
doctr/models/_utils.py
CHANGED
@@ -63,7 +63,7 @@ def estimate_orientation(
     thresh = img.astype(np.uint8)
 
     page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
-    if page_orientation and orientation_confidence >= min_confidence:
+    if page_orientation is not None and orientation_confidence >= min_confidence:
         # We rotate the image to the general orientation which improves the detection
         # No expand needed bitmap is already padded
         thresh = rotate_image(thresh, -page_orientation)
@@ -100,7 +100,7 @@ def estimate_orientation(
     estimated_angle = -round(median) if abs(median) != 0 else 0
 
     # combine with the general orientation and the estimated angle
-    if page_orientation and orientation_confidence >= min_confidence:
+    if page_orientation is not None and orientation_confidence >= min_confidence:
         # special case where the estimated angle is mostly wrong:
         # case 1: - and + swapped
         # case 2: estimated angle is completely wrong
@@ -184,7 +184,7 @@ def invert_data_structure(
         dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
     """
     if isinstance(x, dict):
-        assert len({len(v) for v in x.values()}) == 1, "All the lists in the
+        assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionary should have the same length."
         return [dict(zip(x, t)) for t in zip(*x.values())]
     elif isinstance(x, list):
         return {k: [dic[k] for dic in x] for k in x[0]}
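Note: the switch from a truthiness test to an explicit None check matters because 0 is a valid page orientation. A minimal illustration:

    page_orientation = 0  # a real prediction meaning "no rotation"

    if page_orientation:              # False: 0 is falsy, so 0-degree pages were skipped
        ...
    if page_orientation is not None:  # True: distinguishes 0 from "no prediction"
        ...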
doctr/models/classification/magc_resnet/pytorch.py
CHANGED
@@ -14,7 +14,7 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 
-from ..resnet
+from ..resnet import ResNet
 
 __all__ = ["magc_resnet31"]
 
@@ -72,7 +72,7 @@ class MAGC(nn.Module):
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         batch, _, height, width = inputs.size()
         # (N * headers, C / headers, H , W)
-        x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width)
+        x = inputs.contiguous().view(batch * self.headers, self.single_header_inplanes, height, width)
         shortcut = x
         # (N * headers, C / headers, H * W)
         shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width)
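Note: the added .contiguous() guards the subsequent .view(), which requires a contiguous memory layout that permute/transpose-style ops break. A standalone illustration (shapes are arbitrary):

    import torch

    x = torch.randn(2, 4, 3, 5).transpose(1, 2)  # (2, 3, 4, 5), non-contiguous
    # x.view(6, 4, 5) would raise a RuntimeError here
    y = x.contiguous().view(6, 4, 5)  # copy into contiguous memory, then view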
doctr/models/classification/predictor/pytorch.py
CHANGED
@@ -35,7 +35,7 @@ class OrientationPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        inputs: list[np.ndarray
+        inputs: list[np.ndarray],
     ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
doctr/models/classification/textnet/pytorch.py
CHANGED
@@ -11,7 +11,7 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 
-from ...modules.layers
+from ...modules.layers import FASTConvLayer
 from ...utils import conv_sequence_pt, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
doctr/models/classification/vip/layers/pytorch.py
CHANGED
@@ -433,7 +433,7 @@ class LePEAttention(nn.Module):
         Returns:
             A float tensor of shape (b, h, w, c).
         """
-        b_merged =
+        b_merged = img_splits_hw.shape[0] // ((h * w) // (h_sp * w_sp))
         img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
         # contiguous() required to ensure the tensor has a contiguous memory layout
         # after permute, allowing the subsequent view operation to work correctly.
doctr/models/classification/vit/pytorch.py
CHANGED
@@ -11,9 +11,9 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import EncoderBlock
-from doctr.models.modules.vision_transformer
+from doctr.models.modules.vision_transformer import PatchEmbedding
 
-from ...utils
+from ...utils import load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
doctr/models/classification/zoo.py
CHANGED
@@ -5,7 +5,7 @@
 
 from typing import Any
 
-from doctr.file_utils import is_torch_available
+from doctr.models.utils import _CompiledModule
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -30,11 +30,10 @@ ARCHS: list[str] = [
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
+    "vip_tiny",
+    "vip_base",
 ]
 
-if is_torch_available():
-    ARCHS.extend(["vip_tiny", "vip_base"])
-
 ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
@@ -52,12 +51,8 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-
-
-        # Adding the type for torch compiled models to the allowed architectures
-        from doctr.models.utils import _CompiledModule
-
-        allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [classification.MobileNetV3, _CompiledModule]
 
         if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -66,7 +61,7 @@ def _orientation_predictor(
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
-    input_shape = _model.cfg["input_shape"][
+    input_shape = _model.cfg["input_shape"][1:]
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
    )
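Note: allowed_archs now whitelists _CompiledModule because torch.compile wraps a module in an optimized wrapper that no longer passes the original isinstance check. A minimal demonstration of the effect (a generic nn.Linear, not docTR's classifier):

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    compiled = torch.compile(model)
    print(isinstance(compiled, nn.Linear))  # False: the wrapper hides the class
    print(isinstance(compiled, nn.Module))  # True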
doctr/models/detection/core.py
CHANGED
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):
 
         else:
             mask: np.ndarray = np.zeros((h, w), np.int32)
-            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
+            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
             product = pred * mask
             return np.sum(product) / np.count_nonzero(product)
 
doctr/models/detection/differentiable_binarization/base.py
CHANGED
@@ -224,7 +224,7 @@ class _DBNet:
         padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
 
         # Fill the mask with 1 on the new padded polygon
-        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
 
         # Get min/max to recover polygon after distance computation
         xmin = padded_polygon[:, 0].min()
@@ -269,7 +269,6 @@ class _DBNet:
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -280,10 +279,8 @@ class _DBNet:
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -343,17 +340,12 @@ class _DBNet:
                 if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                     seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                     continue
-                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
                 # Draw on both thresh map and thresh mask
                 poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
                     poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
                 )
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            thresh_target = thresh_target.transpose((0, 2, 3, 1))
-            thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
 
         thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
 
doctr/models/detection/differentiable_binarization/pytorch.py
CHANGED
@@ -215,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -261,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, out_map.shape[1:]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -285,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
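Note: torch.compiler.disable keeps the NumPy/OpenCV post-processing out of the captured graph when the model runs under torch.compile. A minimal sketch of the pattern (not docTR's actual post-processor):

    import torch

    @torch.compiler.disable
    def _postprocess(prob_map: torch.Tensor):
        # Executed eagerly even inside a compiled forward: graph capture
        # skips this call instead of tracing the .numpy() round-trip.
        return prob_map.cpu().numpy()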
doctr/models/detection/fast/base.py
CHANGED
@@ -153,14 +153,12 @@ class _FAST(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target, mask and shrunken text kernel
@@ -172,10 +170,8 @@ class _FAST(BaseModel):
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -235,14 +231,8 @@ class _FAST(BaseModel):
                 if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                     seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                     continue
-                cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+                cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
                 # draw the original polygon on the segmentation target
-                cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+                cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask, shrunken_kernel
doctr/models/detection/fast/pytorch.py
CHANGED
@@ -206,7 +206,7 @@ class FAST(_FAST, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -238,7 +238,7 @@ class FAST(_FAST, nn.Module):
         Returns:
             A loss tensor
         """
-        targets = self.build_target(target, out_map.shape[1:]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -303,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
     for module in model.modules():
         if hasattr(module, "reparameterize_layer"):
-            module.reparameterize_layer()
+            module.reparameterize_layer()  # type: ignore[operator]
 
     for name, child in model.named_children():
         if isinstance(child, nn.BatchNorm2d):
@@ -315,7 +315,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
             factor = child.weight / torch.sqrt(child.running_var + child.eps)  # type: ignore
             last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
-            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
+            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)  # type: ignore[operator]
            model._modules[last_conv_name] = last_conv  # type: ignore[index]
            model._modules[name] = nn.Identity()
            last_conv = None
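Note: the reparameterize function above folds each BatchNorm2d into the preceding convolution so inference runs a single fused conv. A self-contained sketch of the same algebra (an illustration of the technique, not docTR's exact helper):

    import torch
    from torch import nn

    @torch.no_grad()
    def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
        # BN(conv(x)) = factor * (W x + b - running_mean) + beta, where
        # factor = gamma / sqrt(running_var + eps); both steps fold into conv.
        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight = nn.Parameter(conv.weight * factor.reshape(-1, 1, 1, 1))
        bias = torch.zeros_like(bn.running_mean) if conv.bias is None else conv.bias
        conv.bias = nn.Parameter((bias - bn.running_mean) * factor + bn.bias)
        return conv  # replace the BatchNorm2d with nn.Identity() afterwards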