python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/file_utils.py CHANGED
@@ -3,93 +3,13 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
-
 import importlib.metadata
-import importlib.util
 import logging
-import os
-
-CLASS_NAME: str = "words"

+__all__ = ["requires_package", "CLASS_NAME"]

-__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
-
+CLASS_NAME: str = "words"
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
-
-USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
-
-if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
-    _torch_available = importlib.util.find_spec("torch") is not None
-    if _torch_available:
-        try:
-            _torch_version = importlib.metadata.version("torch")
-            logging.info(f"PyTorch version {_torch_version} available.")
-        except importlib.metadata.PackageNotFoundError:  # pragma: no cover
-            _torch_available = False
-else:  # pragma: no cover
-    logging.info("Disabling PyTorch because USE_TF is set")
-    _torch_available = False
-
-# Compatibility fix to make sure tensorflow.keras stays at Keras 2
-if "TF_USE_LEGACY_KERAS" not in os.environ:
-    os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
-    raise ValueError(
-        "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
-    )
-
-
-def ensure_keras_v2() -> None:  # pragma: no cover
-    if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-
-if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
-    _tf_available = importlib.util.find_spec("tensorflow") is not None
-    if _tf_available:
-        candidates = (
-            "tensorflow",
-            "tensorflow-cpu",
-            "tensorflow-gpu",
-            "tf-nightly",
-            "tf-nightly-cpu",
-            "tf-nightly-gpu",
-            "intel-tensorflow",
-            "tensorflow-rocm",
-            "tensorflow-macos",
-        )
-        _tf_version = None
-        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
-        for pkg in candidates:
-            try:
-                _tf_version = importlib.metadata.version(pkg)
-                break
-            except importlib.metadata.PackageNotFoundError:
-                pass
-        _tf_available = _tf_version is not None
-    if _tf_available:
-        if int(_tf_version.split(".")[0]) < 2:  # type: ignore[union-attr]  # pragma: no cover
-            logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
-            _tf_available = False
-        else:
-            logging.info(f"TensorFlow version {_tf_version} available.")
-            ensure_keras_v2()
-
-else:  # pragma: no cover
-    logging.info("Disabling Tensorflow because USE_TORCH is set")
-    _tf_available = False
-
-
-if not _torch_available and not _tf_available:  # pragma: no cover
-    raise ModuleNotFoundError(
-        "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
-        " is installed and that either USE_TF or USE_TORCH is enabled."
-    )


 def requires_package(name: str, extra_message: str | None = None) -> None:  # pragma: no cover
@@ -108,13 +28,3 @@ def requires_package(name: str, extra_message: str | None = None) -> None: # pr
         f"\n\n{extra_message if extra_message is not None else ''} "
         f"\nPlease install it with the following command: pip install {name}\n"
     )
-
-
-def is_torch_available():
-    """Whether PyTorch is installed."""
-    return _torch_available
-
-
-def is_tf_available():
-    """Whether TensorFlow is installed."""
-    return _tf_available
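Note: is_tf_available, is_torch_available and the USE_TF/USE_TORCH environment switches are gone in 1.0.0; doctr.file_utils now only exposes requires_package and CLASS_NAME, since PyTorch is the sole backend. A minimal migration sketch for downstream code that still wants such a check (the helper name below is made up; the probe mirrors what the removed helper used internally):

import importlib.util

def torch_is_installed() -> bool:  # hypothetical stand-in for the removed is_torch_available()
    return importlib.util.find_spec("torch") is not None

if not torch_is_installed():
    raise ModuleNotFoundError("python-doctr 1.0.0 only supports the PyTorch backend.")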
doctr/io/elements.py CHANGED
@@ -347,7 +347,7 @@ class Page(Element):
         )
         # Create the body
         body = SubElement(page_hocr, "body")
-        SubElement(
+        page_div = SubElement(
             body,
             "div",
             attrib={
@@ -362,7 +362,7 @@
                 raise TypeError("XML export is only available for straight bounding boxes for now.")
             (xmin, ymin), (xmax, ymax) = block.geometry
             block_div = SubElement(
-                body,
+                page_div,
                 "div",
                 attrib={
                     "class": "ocr_carea",
@@ -550,7 +550,41 @@ class KIEPage(Element):
                        {int(round(xmax * width))} {int(round(ymax * height))}",
                    },
                )
-                prediction_div.text = prediction.value
+                # NOTE: ocr_par, ocr_line and ocrx_word are the same because the KIE predictions contain only words
+                # This is a workaround to make it PDF/A compatible
+                par_div = SubElement(
+                    prediction_div,
+                    "p",
+                    attrib={
+                        "class": "ocr_par",
+                        "id": f"{class_name}_par_{prediction_count}",
+                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
+                        {int(round(xmax * width))} {int(round(ymax * height))}",
+                    },
+                )
+                line_span = SubElement(
+                    par_div,
+                    "span",
+                    attrib={
+                        "class": "ocr_line",
+                        "id": f"{class_name}_line_{prediction_count}",
+                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
+                        {int(round(xmax * width))} {int(round(ymax * height))}; \
+                        baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
+                    },
+                )
+                word_div = SubElement(
+                    line_span,
+                    "span",
+                    attrib={
+                        "class": "ocrx_word",
+                        "id": f"{class_name}_word_{prediction_count}",
+                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
+                        {int(round(xmax * width))} {int(round(ymax * height))}; \
+                        x_wconf {int(round(prediction.confidence * 100))}",
+                    },
+                )
+                word_div.text = prediction.value
                 prediction_count += 1

         return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)
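For reference, each KIE prediction value is now nested in an ocr_par > ocr_line > ocrx_word hierarchy instead of being written directly on the prediction div, which is what makes the hOCR output consumable by PDF/A tooling. A rough sketch of the emitted nesting, with illustrative ids and bbox values that are not taken from the diff:

from xml.etree.ElementTree import Element, SubElement, tostring

# Illustration only: mirror the per-prediction structure added above
prediction_div = Element("div", attrib={"id": "sample_prediction_0"})  # outer div attributes simplified
par_div = SubElement(prediction_div, "p", attrib={"class": "ocr_par", "id": "sample_par_0"})
line_span = SubElement(par_div, "span", attrib={"class": "ocr_line", "id": "sample_line_0"})
word_span = SubElement(
    line_span,
    "span",
    attrib={"class": "ocrx_word", "id": "sample_word_0", "title": "bbox 10 10 120 40; x_wconf 98"},
)
word_span.text = "sample"  # previously the value was set on prediction_div directly
print(tostring(prediction_div, encoding="unicode"))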
@@ -1,8 +1,2 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
 from .base import *
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
doctr/io/image/pytorch.py CHANGED
@@ -95,4 +95,4 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -

 def get_img_shape(img: torch.Tensor) -> tuple[int, int]:
     """Get the shape of an image"""
-    return img.shape[-2:]
+    return img.shape[-2:]  # type: ignore[return-value]
doctr/models/_utils.py CHANGED
@@ -63,7 +63,7 @@ def estimate_orientation(
     thresh = img.astype(np.uint8)

     page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
-    if page_orientation and orientation_confidence >= min_confidence:
+    if page_orientation is not None and orientation_confidence >= min_confidence:
         # We rotate the image to the general orientation which improves the detection
         # No expand needed bitmap is already padded
         thresh = rotate_image(thresh, -page_orientation)
@@ -87,7 +87,7 @@

     angles = []
     for contour in contours[:n_ct]:
-        _, (w, h), angle = cv2.minAreaRect(contour)  # type: ignore[assignment]
+        _, (w, h), angle = cv2.minAreaRect(contour)
         if w / h > ratio_threshold_for_lines:  # select only contours with ratio like lines
             angles.append(angle)
         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, substract 90 degree
@@ -100,7 +100,7 @@
         estimated_angle = -round(median) if abs(median) != 0 else 0

     # combine with the general orientation and the estimated angle
-    if page_orientation and orientation_confidence >= min_confidence:
+    if page_orientation is not None and orientation_confidence >= min_confidence:
         # special case where the estimated angle is mostly wrong:
         # case 1: - and + swapped
         # case 2: estimated angle is completely wrong
@@ -184,7 +184,7 @@ def invert_data_structure(
        dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
    """
    if isinstance(x, dict):
-        assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionnary should have the same length."
+        assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionary should have the same length."
        return [dict(zip(x, t)) for t in zip(*x.values())]
    elif isinstance(x, list):
        return {k: [dic[k] for dic in x] for k in x[0]}
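The two estimate_orientation changes above swap a truthiness test for an explicit is-not-None check: a detected page orientation of 0 degrees is a valid result but evaluates as falsy, so the old condition silently skipped it. A short illustration:

page_orientation = 0  # valid result: the page is upright
print(bool(page_orientation))        # False -> the old "if page_orientation and ..." skipped the branch
print(page_orientation is not None)  # True  -> the new check still enters it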
@@ -4,4 +4,5 @@ from .vgg import *
 from .magc_resnet import *
 from .vit import *
 from .textnet import *
+from .vip import *
 from .zoo import *
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -14,8 +14,7 @@ from torch import nn

 from doctr.datasets import VOCABS

-from ...utils.pytorch import load_pretrained_params
-from ..resnet.pytorch import ResNet
+from ..resnet import ResNet

 __all__ = ["magc_resnet31"]

@@ -73,7 +72,7 @@ class MAGC(nn.Module):
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         batch, _, height, width = inputs.size()
         # (N * headers, C / headers, H , W)
-        x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width)
+        x = inputs.contiguous().view(batch * self.headers, self.single_header_inplanes, height, width)
         shortcut = x
         # (N * headers, C / headers, H * W)
         shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width)
@@ -136,7 +135,7 @@ def _magc_resnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model

@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -5,6 +5,7 @@

 # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py

+import types
 from copy import deepcopy
 from typing import Any

@@ -99,12 +100,25 @@
                 m = getattr(m, child)
             m.stride = (2, 1)

+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    # Bind method to the instance
+    model.from_pretrained = types.MethodType(from_pretrained, model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     model.cfg = _cfg

@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -35,7 +35,7 @@ class OrientationPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        inputs: list[np.ndarray | torch.Tensor],
+        inputs: list[np.ndarray],
     ) -> list[list[int] | list[float]]:
         # Dimension check
         if any(input.ndim != 3 for input in inputs):
@@ -50,7 +50,7 @@ class OrientationPredictor(nn.Module):
         self.model, processed_batches = set_device_and_dtype(
             self.model, processed_batches, _params.device, _params.dtype
         )
-        predicted_batches = [self.model(batch) for batch in processed_batches]  # type: ignore[misc]
+        predicted_batches = [self.model(batch) for batch in processed_batches]
         # confidence
         probs = [
             torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -4,6 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


+import types
 from collections.abc import Callable
 from copy import deepcopy
 from typing import Any
@@ -152,6 +153,15 @@ class ResNet(nn.Sequential):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+


 def _resnet(
@@ -179,7 +189,7 @@ def _resnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     return model

@@ -201,12 +211,25 @@

     # Build the model
     model = arch_fn(**kwargs, weights=None)
-    # Load pretrained parameters
+
+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    # Bind method to the instance
+    model.from_pretrained = types.MethodType(from_pretrained, model)
+
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     model.cfg = _cfg

@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -11,7 +11,7 @@ from torch import nn

 from doctr.datasets import VOCABS

-from ...modules.layers.pytorch import FASTConvLayer
+from ...modules.layers import FASTConvLayer
 from ...utils import conv_sequence_pt, load_pretrained_params

 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
@@ -93,6 +93,15 @@ class TextNet(nn.Sequential):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)

+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+


 def _textnet(
@@ -115,7 +124,7 @@ def _textnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     model.cfg = _cfg

@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -3,6 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

+import types
 from copy import deepcopy
 from typing import Any

@@ -53,12 +54,26 @@ def _vgg(
     # Patch average pool & classification head
     model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
     model.classifier = nn.Linear(512, kwargs["num_classes"])
+
+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    # Bind method to the instance
+    model.from_pretrained = types.MethodType(from_pretrained, model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

     model.cfg = _cfg

@@ -0,0 +1 @@
+from .pytorch import *
@@ -0,0 +1 @@
+from .pytorch import *