python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +0 -5
  3. doctr/datasets/datasets/__init__.py +1 -6
  4. doctr/datasets/datasets/pytorch.py +2 -2
  5. doctr/datasets/generator/__init__.py +1 -6
  6. doctr/datasets/vocabs.py +0 -2
  7. doctr/file_utils.py +2 -101
  8. doctr/io/image/__init__.py +1 -7
  9. doctr/io/image/pytorch.py +1 -1
  10. doctr/models/_utils.py +3 -3
  11. doctr/models/classification/magc_resnet/__init__.py +1 -6
  12. doctr/models/classification/magc_resnet/pytorch.py +2 -2
  13. doctr/models/classification/mobilenet/__init__.py +1 -6
  14. doctr/models/classification/predictor/__init__.py +1 -6
  15. doctr/models/classification/predictor/pytorch.py +1 -1
  16. doctr/models/classification/resnet/__init__.py +1 -6
  17. doctr/models/classification/textnet/__init__.py +1 -6
  18. doctr/models/classification/textnet/pytorch.py +1 -1
  19. doctr/models/classification/vgg/__init__.py +1 -6
  20. doctr/models/classification/vip/__init__.py +1 -4
  21. doctr/models/classification/vip/layers/__init__.py +1 -4
  22. doctr/models/classification/vip/layers/pytorch.py +1 -1
  23. doctr/models/classification/vit/__init__.py +1 -6
  24. doctr/models/classification/vit/pytorch.py +2 -2
  25. doctr/models/classification/zoo.py +6 -11
  26. doctr/models/detection/_utils/__init__.py +1 -6
  27. doctr/models/detection/core.py +1 -1
  28. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  29. doctr/models/detection/differentiable_binarization/base.py +4 -12
  30. doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
  31. doctr/models/detection/fast/__init__.py +1 -6
  32. doctr/models/detection/fast/base.py +4 -14
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/linknet/__init__.py +1 -6
  35. doctr/models/detection/linknet/base.py +3 -12
  36. doctr/models/detection/linknet/pytorch.py +2 -2
  37. doctr/models/detection/predictor/__init__.py +1 -6
  38. doctr/models/detection/predictor/pytorch.py +1 -1
  39. doctr/models/detection/zoo.py +15 -32
  40. doctr/models/factory/hub.py +8 -21
  41. doctr/models/kie_predictor/__init__.py +1 -6
  42. doctr/models/kie_predictor/pytorch.py +2 -6
  43. doctr/models/modules/layers/__init__.py +1 -6
  44. doctr/models/modules/layers/pytorch.py +3 -3
  45. doctr/models/modules/transformer/__init__.py +1 -6
  46. doctr/models/modules/transformer/pytorch.py +2 -2
  47. doctr/models/modules/vision_transformer/__init__.py +1 -6
  48. doctr/models/predictor/__init__.py +1 -6
  49. doctr/models/predictor/base.py +3 -8
  50. doctr/models/predictor/pytorch.py +2 -5
  51. doctr/models/preprocessor/__init__.py +1 -6
  52. doctr/models/preprocessor/pytorch.py +27 -32
  53. doctr/models/recognition/crnn/__init__.py +1 -6
  54. doctr/models/recognition/crnn/pytorch.py +6 -6
  55. doctr/models/recognition/master/__init__.py +1 -6
  56. doctr/models/recognition/master/pytorch.py +5 -5
  57. doctr/models/recognition/parseq/__init__.py +1 -6
  58. doctr/models/recognition/parseq/pytorch.py +5 -5
  59. doctr/models/recognition/predictor/__init__.py +1 -6
  60. doctr/models/recognition/predictor/_utils.py +7 -16
  61. doctr/models/recognition/predictor/pytorch.py +1 -2
  62. doctr/models/recognition/sar/__init__.py +1 -6
  63. doctr/models/recognition/sar/pytorch.py +3 -3
  64. doctr/models/recognition/viptr/__init__.py +1 -4
  65. doctr/models/recognition/viptr/pytorch.py +3 -3
  66. doctr/models/recognition/vitstr/__init__.py +1 -6
  67. doctr/models/recognition/vitstr/pytorch.py +3 -3
  68. doctr/models/recognition/zoo.py +13 -13
  69. doctr/models/utils/__init__.py +1 -6
  70. doctr/models/utils/pytorch.py +1 -1
  71. doctr/transforms/functional/__init__.py +1 -6
  72. doctr/transforms/functional/pytorch.py +4 -4
  73. doctr/transforms/modules/__init__.py +1 -7
  74. doctr/transforms/modules/base.py +26 -92
  75. doctr/transforms/modules/pytorch.py +28 -26
  76. doctr/utils/geometry.py +6 -10
  77. doctr/utils/visualization.py +1 -1
  78. doctr/version.py +1 -1
  79. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
  80. python_doctr-1.0.0.dist-info/RECORD +149 -0
  81. doctr/datasets/datasets/tensorflow.py +0 -59
  82. doctr/datasets/generator/tensorflow.py +0 -58
  83. doctr/datasets/loader.py +0 -94
  84. doctr/io/image/tensorflow.py +0 -101
  85. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  86. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  87. doctr/models/classification/predictor/tensorflow.py +0 -60
  88. doctr/models/classification/resnet/tensorflow.py +0 -418
  89. doctr/models/classification/textnet/tensorflow.py +0 -275
  90. doctr/models/classification/vgg/tensorflow.py +0 -125
  91. doctr/models/classification/vit/tensorflow.py +0 -201
  92. doctr/models/detection/_utils/tensorflow.py +0 -34
  93. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  94. doctr/models/detection/fast/tensorflow.py +0 -427
  95. doctr/models/detection/linknet/tensorflow.py +0 -377
  96. doctr/models/detection/predictor/tensorflow.py +0 -70
  97. doctr/models/kie_predictor/tensorflow.py +0 -187
  98. doctr/models/modules/layers/tensorflow.py +0 -171
  99. doctr/models/modules/transformer/tensorflow.py +0 -235
  100. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  101. doctr/models/predictor/tensorflow.py +0 -155
  102. doctr/models/preprocessor/tensorflow.py +0 -122
  103. doctr/models/recognition/crnn/tensorflow.py +0 -317
  104. doctr/models/recognition/master/tensorflow.py +0 -320
  105. doctr/models/recognition/parseq/tensorflow.py +0 -516
  106. doctr/models/recognition/predictor/tensorflow.py +0 -79
  107. doctr/models/recognition/sar/tensorflow.py +0 -423
  108. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  109. doctr/models/utils/tensorflow.py +0 -189
  110. doctr/transforms/functional/tensorflow.py +0 -254
  111. doctr/transforms/modules/tensorflow.py +0 -562
  112. python_doctr-0.12.0.dist-info/RECORD +0 -180
  113. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
  114. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
  115. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  116. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/__init__.py CHANGED
@@ -1,3 +1,2 @@
1
1
  from . import io, models, datasets, contrib, transforms, utils
2
- from .file_utils import is_tf_available, is_torch_available
3
2
  from .version import __version__ # noqa: F401
@@ -1,5 +1,3 @@
1
- from doctr.file_utils import is_tf_available
2
-
3
1
  from .generator import *
4
2
  from .coco_text import *
5
3
  from .cord import *
@@ -22,6 +20,3 @@ from .synthtext import *
22
20
  from .utils import *
23
21
  from .vocabs import *
24
22
  from .wildreceipt import *
25
-
26
- if is_tf_available():
27
- from .loader import *
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -50,9 +50,9 @@ class AbstractDataset(_AbstractDataset):
50
50
  @staticmethod
51
51
  def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]:
52
52
  images, targets = zip(*samples)
53
- images = torch.stack(images, dim=0)
53
+ images = torch.stack(images, dim=0) # type: ignore[assignment]
54
54
 
55
- return images, list(targets)
55
+ return images, list(targets) # type: ignore[return-value]
56
56
 
57
57
 
58
58
  class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
doctr/datasets/vocabs.py CHANGED
@@ -264,8 +264,6 @@ VOCABS["estonian"] = VOCABS["english"] + "šžõäöüŠŽÕÄÖÜ"
264
264
  VOCABS["esperanto"] = re.sub(r"[QqWwXxYy]", "", VOCABS["english"]) + "ĉĝĥĵŝŭĈĜĤĴŜŬ" + "₷"
265
265
 
266
266
  VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ"
267
- # NOTE: legacy french is outdated, but kept for compatibility
268
- VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + _BASE_VOCABS["currency"]
269
267
 
270
268
  VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
271
269
 
doctr/file_utils.py CHANGED
@@ -3,102 +3,13 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
7
-
8
6
  import importlib.metadata
9
- import importlib.util
10
7
  import logging
11
- import os
12
-
13
- CLASS_NAME: str = "words"
14
-
15
8
 
16
- __all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
9
+ __all__ = ["requires_package", "CLASS_NAME"]
17
10
 
11
+ CLASS_NAME: str = "words"
18
12
  ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
19
- ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
20
-
21
- USE_TF = os.environ.get("USE_TF", "AUTO").upper()
22
- USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
23
-
24
-
25
- if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
26
- _torch_available = importlib.util.find_spec("torch") is not None
27
- if _torch_available:
28
- try:
29
- _torch_version = importlib.metadata.version("torch")
30
- logging.info(f"PyTorch version {_torch_version} available.")
31
- except importlib.metadata.PackageNotFoundError: # pragma: no cover
32
- _torch_available = False
33
- else: # pragma: no cover
34
- logging.info("Disabling PyTorch because USE_TF is set")
35
- _torch_available = False
36
-
37
- # Compatibility fix to make sure tensorflow.keras stays at Keras 2
38
- if "TF_USE_LEGACY_KERAS" not in os.environ:
39
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
40
-
41
- elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
42
- raise ValueError(
43
- "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
44
- )
45
-
46
-
47
- def ensure_keras_v2() -> None: # pragma: no cover
48
- if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
49
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
50
-
51
-
52
- if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
53
- _tf_available = importlib.util.find_spec("tensorflow") is not None
54
- if _tf_available:
55
- candidates = (
56
- "tensorflow",
57
- "tensorflow-cpu",
58
- "tensorflow-gpu",
59
- "tf-nightly",
60
- "tf-nightly-cpu",
61
- "tf-nightly-gpu",
62
- "intel-tensorflow",
63
- "tensorflow-rocm",
64
- "tensorflow-macos",
65
- )
66
- _tf_version = None
67
- # For the metadata, we have to look for both tensorflow and tensorflow-cpu
68
- for pkg in candidates:
69
- try:
70
- _tf_version = importlib.metadata.version(pkg)
71
- break
72
- except importlib.metadata.PackageNotFoundError:
73
- pass
74
- _tf_available = _tf_version is not None
75
- if _tf_available:
76
- if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover
77
- logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
78
- _tf_available = False
79
- else:
80
- logging.info(f"TensorFlow version {_tf_version} available.")
81
- ensure_keras_v2()
82
-
83
- import warnings
84
-
85
- warnings.simplefilter("always", DeprecationWarning)
86
- warnings.warn(
87
- "Support for TensorFlow in DocTR is deprecated and will be removed in the next major release (v1.0.0). "
88
- "Please switch to the PyTorch backend.",
89
- DeprecationWarning,
90
- )
91
-
92
- else: # pragma: no cover
93
- logging.info("Disabling Tensorflow because USE_TORCH is set")
94
- _tf_available = False
95
-
96
-
97
- if not _torch_available and not _tf_available: # pragma: no cover
98
- raise ModuleNotFoundError(
99
- "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
100
- " is installed and that either USE_TF or USE_TORCH is enabled."
101
- )
102
13
 
103
14
 
104
15
  def requires_package(name: str, extra_message: str | None = None) -> None: # pragma: no cover
@@ -117,13 +28,3 @@ def requires_package(name: str, extra_message: str | None = None) -> None: # pr
117
28
  f"\n\n{extra_message if extra_message is not None else ''} "
118
29
  f"\nPlease install it with the following command: pip install {name}\n"
119
30
  )
120
-
121
-
122
- def is_torch_available():
123
- """Whether PyTorch is installed."""
124
- return _torch_available
125
-
126
-
127
- def is_tf_available():
128
- """Whether TensorFlow is installed."""
129
- return _tf_available
@@ -1,8 +1,2 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
1
  from .base import *
4
-
5
- if is_torch_available():
6
- from .pytorch import *
7
- elif is_tf_available():
8
- from .tensorflow import *
2
+ from .pytorch import *
doctr/io/image/pytorch.py CHANGED
@@ -95,4 +95,4 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -
95
95
 
96
96
  def get_img_shape(img: torch.Tensor) -> tuple[int, int]:
97
97
  """Get the shape of an image"""
98
- return img.shape[-2:]
98
+ return img.shape[-2:] # type: ignore[return-value]
doctr/models/_utils.py CHANGED
@@ -63,7 +63,7 @@ def estimate_orientation(
63
63
  thresh = img.astype(np.uint8)
64
64
 
65
65
  page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
66
- if page_orientation and orientation_confidence >= min_confidence:
66
+ if page_orientation is not None and orientation_confidence >= min_confidence:
67
67
  # We rotate the image to the general orientation which improves the detection
68
68
  # No expand needed bitmap is already padded
69
69
  thresh = rotate_image(thresh, -page_orientation)
@@ -100,7 +100,7 @@ def estimate_orientation(
100
100
  estimated_angle = -round(median) if abs(median) != 0 else 0
101
101
 
102
102
  # combine with the general orientation and the estimated angle
103
- if page_orientation and orientation_confidence >= min_confidence:
103
+ if page_orientation is not None and orientation_confidence >= min_confidence:
104
104
  # special case where the estimated angle is mostly wrong:
105
105
  # case 1: - and + swapped
106
106
  # case 2: estimated angle is completely wrong
@@ -184,7 +184,7 @@ def invert_data_structure(
184
184
  dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
185
185
  """
186
186
  if isinstance(x, dict):
187
- assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionnary should have the same length."
187
+ assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionary should have the same length."
188
188
  return [dict(zip(x, t)) for t in zip(*x.values())]
189
189
  elif isinstance(x, list):
190
190
  return {k: [dic[k] for dic in x] for k in x[0]}
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -14,7 +14,7 @@ from torch import nn
14
14
 
15
15
  from doctr.datasets import VOCABS
16
16
 
17
- from ..resnet.pytorch import ResNet
17
+ from ..resnet import ResNet
18
18
 
19
19
  __all__ = ["magc_resnet31"]
20
20
 
@@ -72,7 +72,7 @@ class MAGC(nn.Module):
72
72
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
73
73
  batch, _, height, width = inputs.size()
74
74
  # (N * headers, C / headers, H , W)
75
- x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width)
75
+ x = inputs.contiguous().view(batch * self.headers, self.single_header_inplanes, height, width)
76
76
  shortcut = x
77
77
  # (N * headers, C / headers, H * W)
78
78
  shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width)
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -35,7 +35,7 @@ class OrientationPredictor(nn.Module):
35
35
  @torch.inference_mode()
36
36
  def forward(
37
37
  self,
38
- inputs: list[np.ndarray | torch.Tensor],
38
+ inputs: list[np.ndarray],
39
39
  ) -> list[list[int] | list[float]]:
40
40
  # Dimension check
41
41
  if any(input.ndim != 3 for input in inputs):
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -11,7 +11,7 @@ from torch import nn
11
11
 
12
12
  from doctr.datasets import VOCABS
13
13
 
14
- from ...modules.layers.pytorch import FASTConvLayer
14
+ from ...modules.layers import FASTConvLayer
15
15
  from ...utils import conv_sequence_pt, load_pretrained_params
16
16
 
17
17
  __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1 @@
1
- from doctr.file_utils import is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
1
+ from .pytorch import *
@@ -1,4 +1 @@
1
- from doctr.file_utils import is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
1
+ from .pytorch import *
@@ -433,7 +433,7 @@ class LePEAttention(nn.Module):
433
433
  Returns:
434
434
  A float tensor of shape (b, h, w, c).
435
435
  """
436
- b_merged = int(img_splits_hw.shape[0] / (h * w / h_sp / w_sp))
436
+ b_merged = img_splits_hw.shape[0] // ((h * w) // (h_sp * w_sp))
437
437
  img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
438
438
  # contiguous() required to ensure the tensor has a contiguous memory layout
439
439
  # after permute, allowing the subsequent view operation to work correctly.
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -11,9 +11,9 @@ from torch import nn
11
11
 
12
12
  from doctr.datasets import VOCABS
13
13
  from doctr.models.modules.transformer import EncoderBlock
14
- from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding
14
+ from doctr.models.modules.vision_transformer import PatchEmbedding
15
15
 
16
- from ...utils.pytorch import load_pretrained_params
16
+ from ...utils import load_pretrained_params
17
17
 
18
18
  __all__ = ["vit_s", "vit_b"]
19
19
 
@@ -5,7 +5,7 @@
5
5
 
6
6
  from typing import Any
7
7
 
8
- from doctr.file_utils import is_tf_available, is_torch_available
8
+ from doctr.models.utils import _CompiledModule
9
9
 
10
10
  from .. import classification
11
11
  from ..preprocessor import PreProcessor
@@ -30,11 +30,10 @@ ARCHS: list[str] = [
30
30
  "vgg16_bn_r",
31
31
  "vit_s",
32
32
  "vit_b",
33
+ "vip_tiny",
34
+ "vip_base",
33
35
  ]
34
36
 
35
- if is_torch_available():
36
- ARCHS.extend(["vip_tiny", "vip_base"])
37
-
38
37
  ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
39
38
 
40
39
 
@@ -52,12 +51,8 @@ def _orientation_predictor(
52
51
  # Load directly classifier from backbone
53
52
  _model = classification.__dict__[arch](pretrained=pretrained)
54
53
  else:
55
- allowed_archs = [classification.MobileNetV3]
56
- if is_torch_available():
57
- # Adding the type for torch compiled models to the allowed architectures
58
- from doctr.models.utils import _CompiledModule
59
-
60
- allowed_archs.append(_CompiledModule)
54
+ # Adding the type for torch compiled models to the allowed architectures
55
+ allowed_archs = [classification.MobileNetV3, _CompiledModule]
61
56
 
62
57
  if not isinstance(arch, tuple(allowed_archs)):
63
58
  raise ValueError(f"unknown architecture: {type(arch)}")
@@ -66,7 +61,7 @@ def _orientation_predictor(
66
61
  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
67
62
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
68
63
  kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
69
- input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
64
+ input_shape = _model.cfg["input_shape"][1:]
70
65
  predictor = OrientationPredictor(
71
66
  PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
72
67
  )
@@ -1,7 +1,2 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
1
  from .base import *
3
-
4
- if is_torch_available():
5
- from .pytorch import *
6
- elif is_tf_available():
7
- from .tensorflow import *
2
+ from .pytorch import *
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):
53
53
 
54
54
  else:
55
55
  mask: np.ndarray = np.zeros((h, w), np.int32)
56
- cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload]
56
+ cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
57
57
  product = pred * mask
58
58
  return np.sum(product) / np.count_nonzero(product)
59
59
 
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -224,7 +224,7 @@ class _DBNet:
224
224
  padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
225
225
 
226
226
  # Fill the mask with 1 on the new padded polygon
227
- cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload]
227
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
228
228
 
229
229
  # Get min/max to recover polygon after distance computation
230
230
  xmin = padded_polygon[:, 0].min()
@@ -269,7 +269,6 @@ class _DBNet:
269
269
  self,
270
270
  target: list[dict[str, np.ndarray]],
271
271
  output_shape: tuple[int, int, int],
272
- channels_last: bool = True,
273
272
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
274
273
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
275
274
  raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -280,10 +279,8 @@ class _DBNet:
280
279
 
281
280
  h: int
282
281
  w: int
283
- if channels_last:
284
- h, w, num_classes = output_shape
285
- else:
286
- num_classes, h, w = output_shape
282
+
283
+ num_classes, h, w = output_shape
287
284
  target_shape = (len(target), num_classes, h, w)
288
285
 
289
286
  seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -343,17 +340,12 @@ class _DBNet:
343
340
  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
344
341
  seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
345
342
  continue
346
- cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
343
+ cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
347
344
 
348
345
  # Draw on both thresh map and thresh mask
349
346
  poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
350
347
  poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
351
348
  )
352
- if channels_last:
353
- seg_target = seg_target.transpose((0, 2, 3, 1))
354
- seg_mask = seg_mask.transpose((0, 2, 3, 1))
355
- thresh_target = thresh_target.transpose((0, 2, 3, 1))
356
- thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
357
349
 
358
350
  thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
359
351
 
@@ -215,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
215
215
 
216
216
  if target is None or return_preds:
217
217
  # Disable for torch.compile compatibility
218
- @torch.compiler.disable # type: ignore[attr-defined]
218
+ @torch.compiler.disable
219
219
  def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
220
220
  return [
221
221
  dict(zip(self.class_names, preds))
@@ -261,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
261
261
  prob_map = torch.sigmoid(out_map)
262
262
  thresh_map = torch.sigmoid(thresh_map)
263
263
 
264
- targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
264
+ targets = self.build_target(target, out_map.shape[1:]) # type: ignore[arg-type]
265
265
 
266
266
  seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
267
267
  seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -285,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
285
285
  dice_map = torch.softmax(out_map, dim=1)
286
286
  else:
287
287
  # compute binary map instead
288
- dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment]
288
+ dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
289
289
  # Class reduced
290
290
  inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
291
291
  cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -153,14 +153,12 @@ class _FAST(BaseModel):
153
153
  self,
154
154
  target: list[dict[str, np.ndarray]],
155
155
  output_shape: tuple[int, int, int],
156
- channels_last: bool = True,
157
156
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
158
157
  """Build the target, and it's mask to be used from loss computation.
159
158
 
160
159
  Args:
161
160
  target: target coming from dataset
162
161
  output_shape: shape of the output of the model without batch_size
163
- channels_last: whether channels are last or not
164
162
 
165
163
  Returns:
166
164
  the new formatted target, mask and shrunken text kernel
@@ -172,10 +170,8 @@ class _FAST(BaseModel):
172
170
 
173
171
  h: int
174
172
  w: int
175
- if channels_last:
176
- h, w, num_classes = output_shape
177
- else:
178
- num_classes, h, w = output_shape
173
+
174
+ num_classes, h, w = output_shape
179
175
  target_shape = (len(target), num_classes, h, w)
180
176
 
181
177
  seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -235,14 +231,8 @@ class _FAST(BaseModel):
235
231
  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
236
232
  seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
237
233
  continue
238
- cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
234
+ cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
239
235
  # draw the original polygon on the segmentation target
240
- cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload]
241
-
242
- # Don't forget to switch back to channel last if Tensorflow is used
243
- if channels_last:
244
- seg_target = seg_target.transpose((0, 2, 3, 1))
245
- seg_mask = seg_mask.transpose((0, 2, 3, 1))
246
- shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
236
+ cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
247
237
 
248
238
  return seg_target, seg_mask, shrunken_kernel
@@ -206,7 +206,7 @@ class FAST(_FAST, nn.Module):
206
206
 
207
207
  if target is None or return_preds:
208
208
  # Disable for torch.compile compatibility
209
- @torch.compiler.disable # type: ignore[attr-defined]
209
+ @torch.compiler.disable
210
210
  def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
211
211
  return [
212
212
  dict(zip(self.class_names, preds))
@@ -238,7 +238,7 @@ class FAST(_FAST, nn.Module):
238
238
  Returns:
239
239
  A loss tensor
240
240
  """
241
- targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
241
+ targets = self.build_target(target, out_map.shape[1:]) # type: ignore[arg-type]
242
242
 
243
243
  seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
244
244
  shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -303,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
303
303
 
304
304
  for module in model.modules():
305
305
  if hasattr(module, "reparameterize_layer"):
306
- module.reparameterize_layer()
306
+ module.reparameterize_layer() # type: ignore[operator]
307
307
 
308
308
  for name, child in model.named_children():
309
309
  if isinstance(child, nn.BatchNorm2d):
@@ -315,7 +315,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
315
315
 
316
316
  factor = child.weight / torch.sqrt(child.running_var + child.eps) # type: ignore
317
317
  last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
318
- last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
318
+ last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias) # type: ignore[operator]
319
319
  model._modules[last_conv_name] = last_conv # type: ignore[index]
320
320
  model._modules[name] = nn.Identity()
321
321
  last_conv = None
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *