python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170):
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -11,7 +11,7 @@ from torch import nn
11
11
 
12
12
  from doctr.datasets import VOCABS
13
13
 
14
- from ...modules.layers.pytorch import FASTConvLayer
14
+ from ...modules.layers import FASTConvLayer
15
15
  from ...utils import conv_sequence_pt, load_pretrained_params
16
16
 
17
17
  __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,4 +1 @@
1
- from doctr.file_utils import is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
1
+ from .pytorch import *
@@ -1,4 +1 @@
1
- from doctr.file_utils import is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -433,7 +433,7 @@ class LePEAttention(nn.Module):
433
433
  Returns:
434
434
  A float tensor of shape (b, h, w, c).
435
435
  """
436
- b_merged = int(img_splits_hw.shape[0] / (h * w / h_sp / w_sp))
436
+ b_merged = img_splits_hw.shape[0] // ((h * w) // (h_sp * w_sp))
437
437
  img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
438
438
  # contiguous() required to ensure the tensor has a contiguous memory layout
439
439
  # after permute, allowing the subsequent view operation to work correctly.
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -11,9 +11,9 @@ from torch import nn
11
11
 
12
12
  from doctr.datasets import VOCABS
13
13
  from doctr.models.modules.transformer import EncoderBlock
14
- from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding
14
+ from doctr.models.modules.vision_transformer import PatchEmbedding
15
15
 
16
- from ...utils.pytorch import load_pretrained_params
16
+ from ...utils import load_pretrained_params
17
17
 
18
18
  __all__ = ["vit_s", "vit_b"]
19
19
 
@@ -1,11 +1,11 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from typing import Any
7
7
 
8
- from doctr.file_utils import is_tf_available, is_torch_available
8
+ from doctr.models.utils import _CompiledModule
9
9
 
10
10
  from .. import classification
11
11
  from ..preprocessor import PreProcessor
@@ -30,11 +30,10 @@ ARCHS: list[str] = [
30
30
  "vgg16_bn_r",
31
31
  "vit_s",
32
32
  "vit_b",
33
+ "vip_tiny",
34
+ "vip_base",
33
35
  ]
34
36
 
35
- if is_torch_available():
36
- ARCHS.extend(["vip_tiny", "vip_base"])
37
-
38
37
  ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
39
38
 
40
39
 
@@ -52,12 +51,8 @@ def _orientation_predictor(
52
51
  # Load directly classifier from backbone
53
52
  _model = classification.__dict__[arch](pretrained=pretrained)
54
53
  else:
55
- allowed_archs = [classification.MobileNetV3]
56
- if is_torch_available():
57
- # Adding the type for torch compiled models to the allowed architectures
58
- from doctr.models.utils import _CompiledModule
59
-
60
- allowed_archs.append(_CompiledModule)
54
+ # Adding the type for torch compiled models to the allowed architectures
55
+ allowed_archs = [classification.MobileNetV3, _CompiledModule]
61
56
 
62
57
  if not isinstance(arch, tuple(allowed_archs)):
63
58
  raise ValueError(f"unknown architecture: {type(arch)}")
@@ -66,7 +61,7 @@ def _orientation_predictor(
66
61
  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
67
62
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
68
63
  kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
69
- input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
64
+ input_shape = _model.cfg["input_shape"][1:]
70
65
  predictor = OrientationPredictor(
71
66
  PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
72
67
  )
doctr/models/core.py (changed)
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,7 +1,2 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
1
  from .base import *
3
-
4
- if is_torch_available():
5
- from .pytorch import *
6
- elif is_tf_available():
7
- from .tensorflow import *
2
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):
53
53
 
54
54
  else:
55
55
  mask: np.ndarray = np.zeros((h, w), np.int32)
56
- cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload]
56
+ cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
57
57
  product = pred * mask
58
58
  return np.sum(product) / np.count_nonzero(product)
59
59
 
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -224,7 +224,7 @@ class _DBNet:
224
224
  padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
225
225
 
226
226
  # Fill the mask with 1 on the new padded polygon
227
- cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload]
227
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
228
228
 
229
229
  # Get min/max to recover polygon after distance computation
230
230
  xmin = padded_polygon[:, 0].min()
@@ -269,7 +269,6 @@ class _DBNet:
269
269
  self,
270
270
  target: list[dict[str, np.ndarray]],
271
271
  output_shape: tuple[int, int, int],
272
- channels_last: bool = True,
273
272
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
274
273
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
275
274
  raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -280,10 +279,8 @@ class _DBNet:
280
279
 
281
280
  h: int
282
281
  w: int
283
- if channels_last:
284
- h, w, num_classes = output_shape
285
- else:
286
- num_classes, h, w = output_shape
282
+
283
+ num_classes, h, w = output_shape
287
284
  target_shape = (len(target), num_classes, h, w)
288
285
 
289
286
  seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -343,17 +340,12 @@ class _DBNet:
343
340
  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
344
341
  seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
345
342
  continue
346
- cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
343
+ cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
347
344
 
348
345
  # Draw on both thresh map and thresh mask
349
346
  poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
350
347
  poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
351
348
  )
352
- if channels_last:
353
- seg_target = seg_target.transpose((0, 2, 3, 1))
354
- seg_mask = seg_mask.transpose((0, 2, 3, 1))
355
- thresh_target = thresh_target.transpose((0, 2, 3, 1))
356
- thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
357
349
 
358
350
  thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
359
351
 
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -215,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
215
215
 
216
216
  if target is None or return_preds:
217
217
  # Disable for torch.compile compatibility
218
- @torch.compiler.disable # type: ignore[attr-defined]
218
+ @torch.compiler.disable
219
219
  def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
220
220
  return [
221
221
  dict(zip(self.class_names, preds))
@@ -261,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
261
261
  prob_map = torch.sigmoid(out_map)
262
262
  thresh_map = torch.sigmoid(thresh_map)
263
263
 
264
- targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
264
+ targets = self.build_target(target, out_map.shape[1:]) # type: ignore[arg-type]
265
265
 
266
266
  seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
267
267
  seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -285,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
285
285
  dice_map = torch.softmax(out_map, dim=1)
286
286
  else:
287
287
  # compute binary map instead
288
- dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment]
288
+ dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
289
289
  # Class reduced
290
290
  inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
291
291
  cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -153,14 +153,12 @@ class _FAST(BaseModel):
153
153
  self,
154
154
  target: list[dict[str, np.ndarray]],
155
155
  output_shape: tuple[int, int, int],
156
- channels_last: bool = True,
157
156
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
158
157
  """Build the target, and it's mask to be used from loss computation.
159
158
 
160
159
  Args:
161
160
  target: target coming from dataset
162
161
  output_shape: shape of the output of the model without batch_size
163
- channels_last: whether channels are last or not
164
162
 
165
163
  Returns:
166
164
  the new formatted target, mask and shrunken text kernel
@@ -172,10 +170,8 @@ class _FAST(BaseModel):
172
170
 
173
171
  h: int
174
172
  w: int
175
- if channels_last:
176
- h, w, num_classes = output_shape
177
- else:
178
- num_classes, h, w = output_shape
173
+
174
+ num_classes, h, w = output_shape
179
175
  target_shape = (len(target), num_classes, h, w)
180
176
 
181
177
  seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -235,14 +231,8 @@ class _FAST(BaseModel):
235
231
  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
236
232
  seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
237
233
  continue
238
- cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
234
+ cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
239
235
  # draw the original polygon on the segmentation target
240
- cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload]
241
-
242
- # Don't forget to switch back to channel last if Tensorflow is used
243
- if channels_last:
244
- seg_target = seg_target.transpose((0, 2, 3, 1))
245
- seg_mask = seg_mask.transpose((0, 2, 3, 1))
246
- shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
236
+ cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
247
237
 
248
238
  return seg_target, seg_mask, shrunken_kernel
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -206,7 +206,7 @@ class FAST(_FAST, nn.Module):
206
206
 
207
207
  if target is None or return_preds:
208
208
  # Disable for torch.compile compatibility
209
- @torch.compiler.disable # type: ignore[attr-defined]
209
+ @torch.compiler.disable
210
210
  def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
211
211
  return [
212
212
  dict(zip(self.class_names, preds))
@@ -238,7 +238,7 @@ class FAST(_FAST, nn.Module):
238
238
  Returns:
239
239
  A loss tensor
240
240
  """
241
- targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
241
+ targets = self.build_target(target, out_map.shape[1:]) # type: ignore[arg-type]
242
242
 
243
243
  seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
244
244
  shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -303,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
303
303
 
304
304
  for module in model.modules():
305
305
  if hasattr(module, "reparameterize_layer"):
306
- module.reparameterize_layer()
306
+ module.reparameterize_layer() # type: ignore[operator]
307
307
 
308
308
  for name, child in model.named_children():
309
309
  if isinstance(child, nn.BatchNorm2d):
@@ -315,7 +315,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
315
315
 
316
316
  factor = child.weight / torch.sqrt(child.running_var + child.eps) # type: ignore
317
317
  last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
318
- last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
318
+ last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias) # type: ignore[operator]
319
319
  model._modules[last_conv_name] = last_conv # type: ignore[index]
320
320
  model._modules[name] = nn.Identity()
321
321
  last_conv = None
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -156,14 +156,12 @@ class _LinkNet(BaseModel):
156
156
  self,
157
157
  target: list[dict[str, np.ndarray]],
158
158
  output_shape: tuple[int, int, int],
159
- channels_last: bool = True,
160
159
  ) -> tuple[np.ndarray, np.ndarray]:
161
160
  """Build the target, and it's mask to be used from loss computation.
162
161
 
163
162
  Args:
164
163
  target: target coming from dataset
165
164
  output_shape: shape of the output of the model without batch_size
166
- channels_last: whether channels are last or not
167
165
 
168
166
  Returns:
169
167
  the new formatted target and the mask
@@ -175,10 +173,8 @@ class _LinkNet(BaseModel):
175
173
 
176
174
  h: int
177
175
  w: int
178
- if channels_last:
179
- h, w, num_classes = output_shape
180
- else:
181
- num_classes, h, w = output_shape
176
+
177
+ num_classes, h, w = output_shape
182
178
  target_shape = (len(target), num_classes, h, w)
183
179
 
184
180
  seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -237,11 +233,6 @@ class _LinkNet(BaseModel):
237
233
  if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
238
234
  seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
239
235
  continue
240
- cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
241
-
242
- # Don't forget to switch back to channel last if Tensorflow is used
243
- if channels_last:
244
- seg_target = seg_target.transpose((0, 2, 3, 1))
245
- seg_mask = seg_mask.transpose((0, 2, 3, 1))
236
+ cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
246
237
 
247
238
  return seg_target, seg_mask
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -193,7 +193,7 @@ class LinkNet(nn.Module, _LinkNet):
193
193
 
194
194
  if target is None or return_preds:
195
195
  # Disable for torch.compile compatibility
196
- @torch.compiler.disable # type: ignore[attr-defined]
196
+ @torch.compiler.disable
197
197
  def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
198
198
  return [
199
199
  dict(zip(self.class_names, preds))
@@ -230,7 +230,7 @@ class LinkNet(nn.Module, _LinkNet):
230
230
  Returns:
231
231
  A loss tensor
232
232
  """
233
- _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
233
+ _target, _mask = self.build_target(target, out_map.shape[1:]) # type: ignore[arg-type]
234
234
 
235
235
  seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
236
236
  seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -36,7 +36,7 @@ class DetectionPredictor(nn.Module):
36
36
  @torch.inference_mode()
37
37
  def forward(
38
38
  self,
39
- pages: list[np.ndarray | torch.Tensor],
39
+ pages: list[np.ndarray],
40
40
  return_maps: bool = False,
41
41
  **kwargs: Any,
42
42
  ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
@@ -1,11 +1,11 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from typing import Any
7
7
 
8
- from doctr.file_utils import is_tf_available, is_torch_available
8
+ from doctr.models.utils import _CompiledModule
9
9
 
10
10
  from .. import detection
11
11
  from ..detection.fast import reparameterize
@@ -16,30 +16,17 @@ __all__ = ["detection_predictor"]
16
16
 
17
17
  ARCHS: list[str]
18
18
 
19
-
20
- if is_tf_available():
21
- ARCHS = [
22
- "db_resnet50",
23
- "db_mobilenet_v3_large",
24
- "linknet_resnet18",
25
- "linknet_resnet34",
26
- "linknet_resnet50",
27
- "fast_tiny",
28
- "fast_small",
29
- "fast_base",
30
- ]
31
- elif is_torch_available():
32
- ARCHS = [
33
- "db_resnet34",
34
- "db_resnet50",
35
- "db_mobilenet_v3_large",
36
- "linknet_resnet18",
37
- "linknet_resnet34",
38
- "linknet_resnet50",
39
- "fast_tiny",
40
- "fast_small",
41
- "fast_base",
42
- ]
19
+ ARCHS = [
20
+ "db_resnet34",
21
+ "db_resnet50",
22
+ "db_mobilenet_v3_large",
23
+ "linknet_resnet18",
24
+ "linknet_resnet34",
25
+ "linknet_resnet50",
26
+ "fast_tiny",
27
+ "fast_small",
28
+ "fast_base",
29
+ ]
43
30
 
44
31
 
45
32
  def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
@@ -56,12 +43,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
56
43
  if isinstance(_model, detection.FAST):
57
44
  _model = reparameterize(_model)
58
45
  else:
59
- allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
60
- if is_torch_available():
61
- # Adding the type for torch compiled models to the allowed architectures
62
- from doctr.models.utils import _CompiledModule
63
-
64
- allowed_archs.append(_CompiledModule)
46
+ # Adding the type for torch compiled models to the allowed architectures
47
+ allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST, _CompiledModule]
65
48
 
66
49
  if not isinstance(arch, tuple(allowed_archs)):
67
50
  raise ValueError(f"unknown architecture: {type(arch)}")
@@ -76,7 +59,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
76
59
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
77
60
  kwargs["batch_size"] = kwargs.get("batch_size", 2)
78
61
  predictor = DetectionPredictor(
79
- PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
62
+ PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
80
63
  _model,
81
64
  )
82
65
  return predictor