python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. doctr/datasets/__init__.py +1 -0
  2. doctr/datasets/coco_text.py +139 -0
  3. doctr/datasets/cord.py +2 -1
  4. doctr/datasets/funsd.py +2 -2
  5. doctr/datasets/ic03.py +1 -1
  6. doctr/datasets/ic13.py +2 -1
  7. doctr/datasets/iiit5k.py +4 -1
  8. doctr/datasets/imgur5k.py +9 -2
  9. doctr/datasets/loader.py +1 -1
  10. doctr/datasets/ocr.py +1 -1
  11. doctr/datasets/recognition.py +1 -1
  12. doctr/datasets/svhn.py +1 -1
  13. doctr/datasets/svt.py +2 -2
  14. doctr/datasets/synthtext.py +15 -2
  15. doctr/datasets/utils.py +7 -6
  16. doctr/datasets/vocabs.py +1102 -54
  17. doctr/file_utils.py +9 -0
  18. doctr/io/elements.py +37 -3
  19. doctr/models/_utils.py +1 -1
  20. doctr/models/classification/__init__.py +1 -0
  21. doctr/models/classification/magc_resnet/pytorch.py +1 -2
  22. doctr/models/classification/magc_resnet/tensorflow.py +3 -3
  23. doctr/models/classification/mobilenet/pytorch.py +15 -1
  24. doctr/models/classification/mobilenet/tensorflow.py +11 -2
  25. doctr/models/classification/predictor/pytorch.py +1 -1
  26. doctr/models/classification/resnet/pytorch.py +26 -3
  27. doctr/models/classification/resnet/tensorflow.py +25 -4
  28. doctr/models/classification/textnet/pytorch.py +10 -1
  29. doctr/models/classification/textnet/tensorflow.py +11 -2
  30. doctr/models/classification/vgg/pytorch.py +16 -1
  31. doctr/models/classification/vgg/tensorflow.py +11 -2
  32. doctr/models/classification/vip/__init__.py +4 -0
  33. doctr/models/classification/vip/layers/__init__.py +4 -0
  34. doctr/models/classification/vip/layers/pytorch.py +615 -0
  35. doctr/models/classification/vip/pytorch.py +505 -0
  36. doctr/models/classification/vit/pytorch.py +10 -1
  37. doctr/models/classification/vit/tensorflow.py +9 -0
  38. doctr/models/classification/zoo.py +4 -0
  39. doctr/models/detection/differentiable_binarization/base.py +3 -4
  40. doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
  41. doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
  42. doctr/models/detection/fast/base.py +2 -3
  43. doctr/models/detection/fast/pytorch.py +13 -4
  44. doctr/models/detection/fast/tensorflow.py +10 -2
  45. doctr/models/detection/linknet/base.py +2 -3
  46. doctr/models/detection/linknet/pytorch.py +10 -1
  47. doctr/models/detection/linknet/tensorflow.py +10 -2
  48. doctr/models/factory/hub.py +3 -3
  49. doctr/models/kie_predictor/pytorch.py +1 -1
  50. doctr/models/kie_predictor/tensorflow.py +1 -1
  51. doctr/models/modules/layers/pytorch.py +49 -1
  52. doctr/models/predictor/pytorch.py +1 -1
  53. doctr/models/predictor/tensorflow.py +1 -1
  54. doctr/models/recognition/__init__.py +1 -0
  55. doctr/models/recognition/crnn/pytorch.py +10 -1
  56. doctr/models/recognition/crnn/tensorflow.py +10 -1
  57. doctr/models/recognition/master/pytorch.py +10 -1
  58. doctr/models/recognition/master/tensorflow.py +10 -3
  59. doctr/models/recognition/parseq/pytorch.py +23 -5
  60. doctr/models/recognition/parseq/tensorflow.py +13 -5
  61. doctr/models/recognition/predictor/_utils.py +107 -45
  62. doctr/models/recognition/predictor/pytorch.py +3 -3
  63. doctr/models/recognition/predictor/tensorflow.py +3 -3
  64. doctr/models/recognition/sar/pytorch.py +10 -1
  65. doctr/models/recognition/sar/tensorflow.py +10 -3
  66. doctr/models/recognition/utils.py +56 -47
  67. doctr/models/recognition/viptr/__init__.py +4 -0
  68. doctr/models/recognition/viptr/pytorch.py +277 -0
  69. doctr/models/recognition/vitstr/pytorch.py +10 -1
  70. doctr/models/recognition/vitstr/tensorflow.py +10 -3
  71. doctr/models/recognition/zoo.py +5 -0
  72. doctr/models/utils/pytorch.py +28 -18
  73. doctr/models/utils/tensorflow.py +15 -8
  74. doctr/utils/data.py +1 -1
  75. doctr/utils/geometry.py +1 -1
  76. doctr/version.py +1 -1
  77. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
  78. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
  79. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  80. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  81. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/file_utils.py CHANGED
@@ -80,6 +80,15 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
80
80
  logging.info(f"TensorFlow version {_tf_version} available.")
81
81
  ensure_keras_v2()
82
82
 
83
+ import warnings
84
+
85
+ warnings.simplefilter("always", DeprecationWarning)
86
+ warnings.warn(
87
+ "Support for TensorFlow in DocTR is deprecated and will be removed in the next major release (v1.0.0). "
88
+ "Please switch to the PyTorch backend.",
89
+ DeprecationWarning,
90
+ )
91
+
83
92
  else: # pragma: no cover
84
93
  logging.info("Disabling Tensorflow because USE_TORCH is set")
85
94
  _tf_available = False
doctr/io/elements.py CHANGED
@@ -347,7 +347,7 @@ class Page(Element):
347
347
  )
348
348
  # Create the body
349
349
  body = SubElement(page_hocr, "body")
350
- SubElement(
350
+ page_div = SubElement(
351
351
  body,
352
352
  "div",
353
353
  attrib={
@@ -362,7 +362,7 @@ class Page(Element):
362
362
  raise TypeError("XML export is only available for straight bounding boxes for now.")
363
363
  (xmin, ymin), (xmax, ymax) = block.geometry
364
364
  block_div = SubElement(
365
- body,
365
+ page_div,
366
366
  "div",
367
367
  attrib={
368
368
  "class": "ocr_carea",
@@ -550,7 +550,41 @@ class KIEPage(Element):
550
550
  {int(round(xmax * width))} {int(round(ymax * height))}",
551
551
  },
552
552
  )
553
- prediction_div.text = prediction.value
553
+ # NOTE: ocr_par, ocr_line and ocrx_word are the same because the KIE predictions contain only words
554
+ # This is a workaround to make it PDF/A compatible
555
+ par_div = SubElement(
556
+ prediction_div,
557
+ "p",
558
+ attrib={
559
+ "class": "ocr_par",
560
+ "id": f"{class_name}_par_{prediction_count}",
561
+ "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
562
+ {int(round(xmax * width))} {int(round(ymax * height))}",
563
+ },
564
+ )
565
+ line_span = SubElement(
566
+ par_div,
567
+ "span",
568
+ attrib={
569
+ "class": "ocr_line",
570
+ "id": f"{class_name}_line_{prediction_count}",
571
+ "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
572
+ {int(round(xmax * width))} {int(round(ymax * height))}; \
573
+ baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
574
+ },
575
+ )
576
+ word_div = SubElement(
577
+ line_span,
578
+ "span",
579
+ attrib={
580
+ "class": "ocrx_word",
581
+ "id": f"{class_name}_word_{prediction_count}",
582
+ "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
583
+ {int(round(xmax * width))} {int(round(ymax * height))}; \
584
+ x_wconf {int(round(prediction.confidence * 100))}",
585
+ },
586
+ )
587
+ word_div.text = prediction.value
554
588
  prediction_count += 1
555
589
 
556
590
  return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)
doctr/models/_utils.py CHANGED
@@ -87,7 +87,7 @@ def estimate_orientation(
87
87
 
88
88
  angles = []
89
89
  for contour in contours[:n_ct]:
90
- _, (w, h), angle = cv2.minAreaRect(contour) # type: ignore[assignment]
90
+ _, (w, h), angle = cv2.minAreaRect(contour)
91
91
  if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
92
92
  angles.append(angle)
93
93
  elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, subtract 90 degrees
@@ -4,4 +4,5 @@ from .vgg import *
4
4
  from .magc_resnet import *
5
5
  from .vit import *
6
6
  from .textnet import *
7
+ from .vip import *
7
8
  from .zoo import *
@@ -14,7 +14,6 @@ from torch import nn
14
14
 
15
15
  from doctr.datasets import VOCABS
16
16
 
17
- from ...utils.pytorch import load_pretrained_params
18
17
  from ..resnet.pytorch import ResNet
19
18
 
20
19
  __all__ = ["magc_resnet31"]
@@ -136,7 +135,7 @@ def _magc_resnet(
136
135
  # The number of classes is not the same as the number of classes in the pretrained model =>
137
136
  # remove the last layer weights
138
137
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
139
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
138
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
140
139
 
141
140
  return model
142
141
 
@@ -14,7 +14,7 @@ from tensorflow.keras.models import Sequential
14
14
 
15
15
  from doctr.datasets import VOCABS
16
16
 
17
- from ...utils import _build_model, load_pretrained_params
17
+ from ...utils import _build_model
18
18
  from ..resnet.tensorflow import ResNet
19
19
 
20
20
  __all__ = ["magc_resnet31"]
@@ -157,8 +157,8 @@ def _magc_resnet(
157
157
  if pretrained:
158
158
  # The number of classes is not the same as the number of classes in the pretrained model =>
159
159
  # skip the mismatching layers for fine tuning
160
- load_pretrained_params(
161
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
160
+ model.from_pretrained(
161
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
162
162
  )
163
163
 
164
164
  return model
@@ -5,6 +5,7 @@
5
5
 
6
6
  # Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
7
7
 
8
+ import types
8
9
  from copy import deepcopy
9
10
  from typing import Any
10
11
 
@@ -99,12 +100,25 @@ def _mobilenet_v3(
99
100
  m = getattr(m, child)
100
101
  m.stride = (2, 1)
101
102
 
103
+ # monkeypatch the model to allow for loading pretrained parameters
104
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
105
+ """Load pretrained parameters onto the model
106
+
107
+ Args:
108
+ path_or_url: the path or URL to the model parameters (checkpoint)
109
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
110
+ """
111
+ load_pretrained_params(self, path_or_url, **kwargs)
112
+
113
+ # Bind method to the instance
114
+ model.from_pretrained = types.MethodType(from_pretrained, model)
115
+
102
116
  # Load pretrained parameters
103
117
  if pretrained:
104
118
  # The number of classes is not the same as the number of classes in the pretrained model =>
105
119
  # remove the last layer weights
106
120
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
107
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
121
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
108
122
 
109
123
  model.cfg = _cfg
110
124
 
@@ -236,6 +236,15 @@ class MobileNetV3(Sequential):
236
236
  super().__init__(_layers)
237
237
  self.cfg = cfg
238
238
 
239
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
240
+ """Load pretrained parameters onto the model
241
+
242
+ Args:
243
+ path_or_url: the path or URL to the model parameters (checkpoint)
244
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
245
+ """
246
+ load_pretrained_params(self, path_or_url, **kwargs)
247
+
239
248
 
240
249
  def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3:
241
250
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -300,8 +309,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
300
309
  if pretrained:
301
310
  # The number of classes is not the same as the number of classes in the pretrained model =>
302
311
  # skip the mismatching layers for fine tuning
303
- load_pretrained_params(
304
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
312
+ model.from_pretrained(
313
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
305
314
  )
306
315
 
307
316
  return model
@@ -50,7 +50,7 @@ class OrientationPredictor(nn.Module):
50
50
  self.model, processed_batches = set_device_and_dtype(
51
51
  self.model, processed_batches, _params.device, _params.dtype
52
52
  )
53
- predicted_batches = [self.model(batch) for batch in processed_batches] # type: ignore[misc]
53
+ predicted_batches = [self.model(batch) for batch in processed_batches]
54
54
  # confidence
55
55
  probs = [
56
56
  torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
@@ -4,6 +4,7 @@
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
 
7
+ import types
7
8
  from collections.abc import Callable
8
9
  from copy import deepcopy
9
10
  from typing import Any
@@ -152,6 +153,15 @@ class ResNet(nn.Sequential):
152
153
  nn.init.constant_(m.weight, 1)
153
154
  nn.init.constant_(m.bias, 0)
154
155
 
156
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
157
+ """Load pretrained parameters onto the model
158
+
159
+ Args:
160
+ path_or_url: the path or URL to the model parameters (checkpoint)
161
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
162
+ """
163
+ load_pretrained_params(self, path_or_url, **kwargs)
164
+
155
165
 
156
166
  def _resnet(
157
167
  arch: str,
@@ -179,7 +189,7 @@ def _resnet(
179
189
  # The number of classes is not the same as the number of classes in the pretrained model =>
180
190
  # remove the last layer weights
181
191
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
182
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
192
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
183
193
 
184
194
  return model
185
195
 
@@ -201,12 +211,25 @@ def _tv_resnet(
201
211
 
202
212
  # Build the model
203
213
  model = arch_fn(**kwargs, weights=None)
204
- # Load pretrained parameters
214
+
215
+ # monkeypatch the model to allow for loading pretrained parameters
216
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
217
+ """Load pretrained parameters onto the model
218
+
219
+ Args:
220
+ path_or_url: the path or URL to the model parameters (checkpoint)
221
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
222
+ """
223
+ load_pretrained_params(self, path_or_url, **kwargs)
224
+
225
+ # Bind method to the instance
226
+ model.from_pretrained = types.MethodType(from_pretrained, model)
227
+
205
228
  if pretrained:
206
229
  # The number of classes is not the same as the number of classes in the pretrained model =>
207
230
  # remove the last layer weights
208
231
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
209
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
232
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
210
233
 
211
234
  model.cfg = _cfg
212
235
 
@@ -3,6 +3,7 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ import types
6
7
  from collections.abc import Callable
7
8
  from copy import deepcopy
8
9
  from typing import Any
@@ -183,6 +184,15 @@ class ResNet(Sequential):
183
184
  super().__init__(_layers)
184
185
  self.cfg = cfg
185
186
 
187
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
188
+ """Load pretrained parameters onto the model
189
+
190
+ Args:
191
+ path_or_url: the path or URL to the model parameters (checkpoint)
192
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
193
+ """
194
+ load_pretrained_params(self, path_or_url, **kwargs)
195
+
186
196
 
187
197
  def _resnet(
188
198
  arch: str,
@@ -215,8 +225,8 @@ def _resnet(
215
225
  if pretrained:
216
226
  # The number of classes is not the same as the number of classes in the pretrained model =>
217
227
  # skip the mismatching layers for fine tuning
218
- load_pretrained_params(
219
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
228
+ model.from_pretrained(
229
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
220
230
  )
221
231
 
222
232
  return model
@@ -350,6 +360,18 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
350
360
  classifier_activation=None,
351
361
  )
352
362
 
363
+ # monkeypatch the model to allow for loading pretrained parameters
364
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
365
+ """Load pretrained parameters onto the model
366
+
367
+ Args:
368
+ path_or_url: the path or URL to the model parameters (checkpoint)
369
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
370
+ """
371
+ load_pretrained_params(self, path_or_url, **kwargs)
372
+
373
+ model.from_pretrained = types.MethodType(from_pretrained, model) # Bind method to the instance
374
+
353
375
  model.cfg = _cfg
354
376
  _build_model(model)
355
377
 
@@ -357,8 +379,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
357
379
  if pretrained:
358
380
  # The number of classes is not the same as the number of classes in the pretrained model =>
359
381
  # skip the mismatching layers for fine tuning
360
- load_pretrained_params(
361
- model,
382
+ model.from_pretrained(
362
383
  default_cfgs["resnet50"]["url"],
363
384
  skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
364
385
  )
@@ -93,6 +93,15 @@ class TextNet(nn.Sequential):
93
93
  nn.init.constant_(m.weight, 1)
94
94
  nn.init.constant_(m.bias, 0)
95
95
 
96
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
97
+ """Load pretrained parameters onto the model
98
+
99
+ Args:
100
+ path_or_url: the path or URL to the model parameters (checkpoint)
101
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
102
+ """
103
+ load_pretrained_params(self, path_or_url, **kwargs)
104
+
96
105
 
97
106
  def _textnet(
98
107
  arch: str,
@@ -115,7 +124,7 @@ def _textnet(
115
124
  # The number of classes is not the same as the number of classes in the pretrained model =>
116
125
  # remove the last layer weights
117
126
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
118
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
127
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
119
128
 
120
129
  model.cfg = _cfg
121
130
 
@@ -92,6 +92,15 @@ class TextNet(Sequential):
92
92
  super().__init__(_layers)
93
93
  self.cfg = cfg
94
94
 
95
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
96
+ """Load pretrained parameters onto the model
97
+
98
+ Args:
99
+ path_or_url: the path or URL to the model parameters (checkpoint)
100
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
101
+ """
102
+ load_pretrained_params(self, path_or_url, **kwargs)
103
+
95
104
 
96
105
  def _textnet(
97
106
  arch: str,
@@ -116,8 +125,8 @@ def _textnet(
116
125
  if pretrained:
117
126
  # The number of classes is not the same as the number of classes in the pretrained model =>
118
127
  # skip the mismatching layers for fine tuning
119
- load_pretrained_params(
120
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
128
+ model.from_pretrained(
129
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
121
130
  )
122
131
 
123
132
  return model
@@ -3,6 +3,7 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ import types
6
7
  from copy import deepcopy
7
8
  from typing import Any
8
9
 
@@ -53,12 +54,26 @@ def _vgg(
53
54
  # Patch average pool & classification head
54
55
  model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
55
56
  model.classifier = nn.Linear(512, kwargs["num_classes"])
57
+
58
+ # monkeypatch the model to allow for loading pretrained parameters
59
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
60
+ """Load pretrained parameters onto the model
61
+
62
+ Args:
63
+ path_or_url: the path or URL to the model parameters (checkpoint)
64
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
65
+ """
66
+ load_pretrained_params(self, path_or_url, **kwargs)
67
+
68
+ # Bind method to the instance
69
+ model.from_pretrained = types.MethodType(from_pretrained, model)
70
+
56
71
  # Load pretrained parameters
57
72
  if pretrained:
58
73
  # The number of classes is not the same as the number of classes in the pretrained model =>
59
74
  # remove the last layer weights
60
75
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
61
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
76
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
62
77
 
63
78
  model.cfg = _cfg
64
79
 
@@ -64,6 +64,15 @@ class VGG(Sequential):
64
64
  super().__init__(_layers)
65
65
  self.cfg = cfg
66
66
 
67
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
68
+ """Load pretrained parameters onto the model
69
+
70
+ Args:
71
+ path_or_url: the path or URL to the model parameters (checkpoint)
72
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
73
+ """
74
+ load_pretrained_params(self, path_or_url, **kwargs)
75
+
67
76
 
68
77
  def _vgg(
69
78
  arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
@@ -86,8 +95,8 @@ def _vgg(
86
95
  if pretrained:
87
96
  # The number of classes is not the same as the number of classes in the pretrained model =>
88
97
  # skip the mismatching layers for fine tuning
89
- load_pretrained_params(
90
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
98
+ model.from_pretrained(
99
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
91
100
  )
92
101
 
93
102
  return model
@@ -0,0 +1,4 @@
1
+ from doctr.file_utils import is_torch_available
2
+
3
+ if is_torch_available():
4
+ from .pytorch import *
@@ -0,0 +1,4 @@
1
+ from doctr.file_utils import is_torch_available
2
+
3
+ if is_torch_available():
4
+ from .pytorch import *