python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/loader.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1102 -54
- doctr/file_utils.py +9 -0
- doctr/io/elements.py +37 -3
- doctr/models/_utils.py +1 -1
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +1 -2
- doctr/models/classification/magc_resnet/tensorflow.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/mobilenet/tensorflow.py +11 -2
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/resnet/tensorflow.py +25 -4
- doctr/models/classification/textnet/pytorch.py +10 -1
- doctr/models/classification/textnet/tensorflow.py +11 -2
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vgg/tensorflow.py +11 -2
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/pytorch.py +10 -1
- doctr/models/classification/vit/tensorflow.py +9 -0
- doctr/models/classification/zoo.py +4 -0
- doctr/models/detection/differentiable_binarization/base.py +3 -4
- doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
- doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
- doctr/models/detection/fast/base.py +2 -3
- doctr/models/detection/fast/pytorch.py +13 -4
- doctr/models/detection/fast/tensorflow.py +10 -2
- doctr/models/detection/linknet/base.py +2 -3
- doctr/models/detection/linknet/pytorch.py +10 -1
- doctr/models/detection/linknet/tensorflow.py +10 -2
- doctr/models/factory/hub.py +3 -3
- doctr/models/kie_predictor/pytorch.py +1 -1
- doctr/models/kie_predictor/tensorflow.py +1 -1
- doctr/models/modules/layers/pytorch.py +49 -1
- doctr/models/predictor/pytorch.py +1 -1
- doctr/models/predictor/tensorflow.py +1 -1
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/pytorch.py +10 -1
- doctr/models/recognition/crnn/tensorflow.py +10 -1
- doctr/models/recognition/master/pytorch.py +10 -1
- doctr/models/recognition/master/tensorflow.py +10 -3
- doctr/models/recognition/parseq/pytorch.py +23 -5
- doctr/models/recognition/parseq/tensorflow.py +13 -5
- doctr/models/recognition/predictor/_utils.py +107 -45
- doctr/models/recognition/predictor/pytorch.py +3 -3
- doctr/models/recognition/predictor/tensorflow.py +3 -3
- doctr/models/recognition/sar/pytorch.py +10 -1
- doctr/models/recognition/sar/tensorflow.py +10 -3
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/pytorch.py +10 -1
- doctr/models/recognition/vitstr/tensorflow.py +10 -3
- doctr/models/recognition/zoo.py +5 -0
- doctr/models/utils/pytorch.py +28 -18
- doctr/models/utils/tensorflow.py +15 -8
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/file_utils.py
CHANGED
|
@@ -80,6 +80,15 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
|
|
|
80
80
|
logging.info(f"TensorFlow version {_tf_version} available.")
|
|
81
81
|
ensure_keras_v2()
|
|
82
82
|
|
|
83
|
+
import warnings
|
|
84
|
+
|
|
85
|
+
warnings.simplefilter("always", DeprecationWarning)
|
|
86
|
+
warnings.warn(
|
|
87
|
+
"Support for TensorFlow in DocTR is deprecated and will be removed in the next major release (v1.0.0). "
|
|
88
|
+
"Please switch to the PyTorch backend.",
|
|
89
|
+
DeprecationWarning,
|
|
90
|
+
)
|
|
91
|
+
|
|
83
92
|
else: # pragma: no cover
|
|
84
93
|
logging.info("Disabling Tensorflow because USE_TORCH is set")
|
|
85
94
|
_tf_available = False
|
doctr/io/elements.py
CHANGED
|
@@ -347,7 +347,7 @@ class Page(Element):
|
|
|
347
347
|
)
|
|
348
348
|
# Create the body
|
|
349
349
|
body = SubElement(page_hocr, "body")
|
|
350
|
-
SubElement(
|
|
350
|
+
page_div = SubElement(
|
|
351
351
|
body,
|
|
352
352
|
"div",
|
|
353
353
|
attrib={
|
|
@@ -362,7 +362,7 @@ class Page(Element):
|
|
|
362
362
|
raise TypeError("XML export is only available for straight bounding boxes for now.")
|
|
363
363
|
(xmin, ymin), (xmax, ymax) = block.geometry
|
|
364
364
|
block_div = SubElement(
|
|
365
|
-
|
|
365
|
+
page_div,
|
|
366
366
|
"div",
|
|
367
367
|
attrib={
|
|
368
368
|
"class": "ocr_carea",
|
|
@@ -550,7 +550,41 @@ class KIEPage(Element):
|
|
|
550
550
|
{int(round(xmax * width))} {int(round(ymax * height))}",
|
|
551
551
|
},
|
|
552
552
|
)
|
|
553
|
-
|
|
553
|
+
# NOTE: ocr_par, ocr_line and ocrx_word are the same because the KIE predictions contain only words
|
|
554
|
+
# This is a workaround to make it PDF/A compatible
|
|
555
|
+
par_div = SubElement(
|
|
556
|
+
prediction_div,
|
|
557
|
+
"p",
|
|
558
|
+
attrib={
|
|
559
|
+
"class": "ocr_par",
|
|
560
|
+
"id": f"{class_name}_par_{prediction_count}",
|
|
561
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
562
|
+
{int(round(xmax * width))} {int(round(ymax * height))}",
|
|
563
|
+
},
|
|
564
|
+
)
|
|
565
|
+
line_span = SubElement(
|
|
566
|
+
par_div,
|
|
567
|
+
"span",
|
|
568
|
+
attrib={
|
|
569
|
+
"class": "ocr_line",
|
|
570
|
+
"id": f"{class_name}_line_{prediction_count}",
|
|
571
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
572
|
+
{int(round(xmax * width))} {int(round(ymax * height))}; \
|
|
573
|
+
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
|
|
574
|
+
},
|
|
575
|
+
)
|
|
576
|
+
word_div = SubElement(
|
|
577
|
+
line_span,
|
|
578
|
+
"span",
|
|
579
|
+
attrib={
|
|
580
|
+
"class": "ocrx_word",
|
|
581
|
+
"id": f"{class_name}_word_{prediction_count}",
|
|
582
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
583
|
+
{int(round(xmax * width))} {int(round(ymax * height))}; \
|
|
584
|
+
x_wconf {int(round(prediction.confidence * 100))}",
|
|
585
|
+
},
|
|
586
|
+
)
|
|
587
|
+
word_div.text = prediction.value
|
|
554
588
|
prediction_count += 1
|
|
555
589
|
|
|
556
590
|
return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)
|
doctr/models/_utils.py
CHANGED
|
@@ -87,7 +87,7 @@ def estimate_orientation(
|
|
|
87
87
|
|
|
88
88
|
angles = []
|
|
89
89
|
for contour in contours[:n_ct]:
|
|
90
|
-
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
90
|
+
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
91
91
|
if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
|
|
92
92
|
angles.append(angle)
|
|
93
93
|
elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree
|
|
@@ -14,7 +14,6 @@ from torch import nn
|
|
|
14
14
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
|
-
from ...utils.pytorch import load_pretrained_params
|
|
18
17
|
from ..resnet.pytorch import ResNet
|
|
19
18
|
|
|
20
19
|
__all__ = ["magc_resnet31"]
|
|
@@ -136,7 +135,7 @@ def _magc_resnet(
|
|
|
136
135
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
137
136
|
# remove the last layer weights
|
|
138
137
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
139
|
-
|
|
138
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
140
139
|
|
|
141
140
|
return model
|
|
142
141
|
|
|
@@ -14,7 +14,7 @@ from tensorflow.keras.models import Sequential
|
|
|
14
14
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
|
-
from ...utils import _build_model
|
|
17
|
+
from ...utils import _build_model
|
|
18
18
|
from ..resnet.tensorflow import ResNet
|
|
19
19
|
|
|
20
20
|
__all__ = ["magc_resnet31"]
|
|
@@ -157,8 +157,8 @@ def _magc_resnet(
|
|
|
157
157
|
if pretrained:
|
|
158
158
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
159
159
|
# skip the mismatching layers for fine tuning
|
|
160
|
-
|
|
161
|
-
|
|
160
|
+
model.from_pretrained(
|
|
161
|
+
default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
162
162
|
)
|
|
163
163
|
|
|
164
164
|
return model
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
|
|
7
7
|
|
|
8
|
+
import types
|
|
8
9
|
from copy import deepcopy
|
|
9
10
|
from typing import Any
|
|
10
11
|
|
|
@@ -99,12 +100,25 @@ def _mobilenet_v3(
|
|
|
99
100
|
m = getattr(m, child)
|
|
100
101
|
m.stride = (2, 1)
|
|
101
102
|
|
|
103
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
104
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
105
|
+
"""Load pretrained parameters onto the model
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
109
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
110
|
+
"""
|
|
111
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
112
|
+
|
|
113
|
+
# Bind method to the instance
|
|
114
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
115
|
+
|
|
102
116
|
# Load pretrained parameters
|
|
103
117
|
if pretrained:
|
|
104
118
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
105
119
|
# remove the last layer weights
|
|
106
120
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
107
|
-
|
|
121
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
108
122
|
|
|
109
123
|
model.cfg = _cfg
|
|
110
124
|
|
|
@@ -236,6 +236,15 @@ class MobileNetV3(Sequential):
|
|
|
236
236
|
super().__init__(_layers)
|
|
237
237
|
self.cfg = cfg
|
|
238
238
|
|
|
239
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
240
|
+
"""Load pretrained parameters onto the model
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
244
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
245
|
+
"""
|
|
246
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
247
|
+
|
|
239
248
|
|
|
240
249
|
def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3:
|
|
241
250
|
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
|
|
@@ -300,8 +309,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
|
|
|
300
309
|
if pretrained:
|
|
301
310
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
302
311
|
# skip the mismatching layers for fine tuning
|
|
303
|
-
|
|
304
|
-
|
|
312
|
+
model.from_pretrained(
|
|
313
|
+
default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
305
314
|
)
|
|
306
315
|
|
|
307
316
|
return model
|
|
@@ -50,7 +50,7 @@ class OrientationPredictor(nn.Module):
|
|
|
50
50
|
self.model, processed_batches = set_device_and_dtype(
|
|
51
51
|
self.model, processed_batches, _params.device, _params.dtype
|
|
52
52
|
)
|
|
53
|
-
predicted_batches = [self.model(batch) for batch in processed_batches]
|
|
53
|
+
predicted_batches = [self.model(batch) for batch in processed_batches]
|
|
54
54
|
# confidence
|
|
55
55
|
probs = [
|
|
56
56
|
torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
+
import types
|
|
7
8
|
from collections.abc import Callable
|
|
8
9
|
from copy import deepcopy
|
|
9
10
|
from typing import Any
|
|
@@ -152,6 +153,15 @@ class ResNet(nn.Sequential):
|
|
|
152
153
|
nn.init.constant_(m.weight, 1)
|
|
153
154
|
nn.init.constant_(m.bias, 0)
|
|
154
155
|
|
|
156
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
157
|
+
"""Load pretrained parameters onto the model
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
161
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
162
|
+
"""
|
|
163
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
164
|
+
|
|
155
165
|
|
|
156
166
|
def _resnet(
|
|
157
167
|
arch: str,
|
|
@@ -179,7 +189,7 @@ def _resnet(
|
|
|
179
189
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
180
190
|
# remove the last layer weights
|
|
181
191
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
182
|
-
|
|
192
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
183
193
|
|
|
184
194
|
return model
|
|
185
195
|
|
|
@@ -201,12 +211,25 @@ def _tv_resnet(
|
|
|
201
211
|
|
|
202
212
|
# Build the model
|
|
203
213
|
model = arch_fn(**kwargs, weights=None)
|
|
204
|
-
|
|
214
|
+
|
|
215
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
216
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
217
|
+
"""Load pretrained parameters onto the model
|
|
218
|
+
|
|
219
|
+
Args:
|
|
220
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
221
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
222
|
+
"""
|
|
223
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
224
|
+
|
|
225
|
+
# Bind method to the instance
|
|
226
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
227
|
+
|
|
205
228
|
if pretrained:
|
|
206
229
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
207
230
|
# remove the last layer weights
|
|
208
231
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
209
|
-
|
|
232
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
210
233
|
|
|
211
234
|
model.cfg = _cfg
|
|
212
235
|
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
import types
|
|
6
7
|
from collections.abc import Callable
|
|
7
8
|
from copy import deepcopy
|
|
8
9
|
from typing import Any
|
|
@@ -183,6 +184,15 @@ class ResNet(Sequential):
|
|
|
183
184
|
super().__init__(_layers)
|
|
184
185
|
self.cfg = cfg
|
|
185
186
|
|
|
187
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
188
|
+
"""Load pretrained parameters onto the model
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
192
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
193
|
+
"""
|
|
194
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
195
|
+
|
|
186
196
|
|
|
187
197
|
def _resnet(
|
|
188
198
|
arch: str,
|
|
@@ -215,8 +225,8 @@ def _resnet(
|
|
|
215
225
|
if pretrained:
|
|
216
226
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
217
227
|
# skip the mismatching layers for fine tuning
|
|
218
|
-
|
|
219
|
-
|
|
228
|
+
model.from_pretrained(
|
|
229
|
+
default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
220
230
|
)
|
|
221
231
|
|
|
222
232
|
return model
|
|
@@ -350,6 +360,18 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
350
360
|
classifier_activation=None,
|
|
351
361
|
)
|
|
352
362
|
|
|
363
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
364
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
365
|
+
"""Load pretrained parameters onto the model
|
|
366
|
+
|
|
367
|
+
Args:
|
|
368
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
369
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
370
|
+
"""
|
|
371
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
372
|
+
|
|
373
|
+
model.from_pretrained = types.MethodType(from_pretrained, model) # Bind method to the instance
|
|
374
|
+
|
|
353
375
|
model.cfg = _cfg
|
|
354
376
|
_build_model(model)
|
|
355
377
|
|
|
@@ -357,8 +379,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
|
|
|
357
379
|
if pretrained:
|
|
358
380
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
359
381
|
# skip the mismatching layers for fine tuning
|
|
360
|
-
|
|
361
|
-
model,
|
|
382
|
+
model.from_pretrained(
|
|
362
383
|
default_cfgs["resnet50"]["url"],
|
|
363
384
|
skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
|
|
364
385
|
)
|
|
@@ -93,6 +93,15 @@ class TextNet(nn.Sequential):
|
|
|
93
93
|
nn.init.constant_(m.weight, 1)
|
|
94
94
|
nn.init.constant_(m.bias, 0)
|
|
95
95
|
|
|
96
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
97
|
+
"""Load pretrained parameters onto the model
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
101
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
102
|
+
"""
|
|
103
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
104
|
+
|
|
96
105
|
|
|
97
106
|
def _textnet(
|
|
98
107
|
arch: str,
|
|
@@ -115,7 +124,7 @@ def _textnet(
|
|
|
115
124
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
116
125
|
# remove the last layer weights
|
|
117
126
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
118
|
-
|
|
127
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
119
128
|
|
|
120
129
|
model.cfg = _cfg
|
|
121
130
|
|
|
@@ -92,6 +92,15 @@ class TextNet(Sequential):
|
|
|
92
92
|
super().__init__(_layers)
|
|
93
93
|
self.cfg = cfg
|
|
94
94
|
|
|
95
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
96
|
+
"""Load pretrained parameters onto the model
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
100
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
101
|
+
"""
|
|
102
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
103
|
+
|
|
95
104
|
|
|
96
105
|
def _textnet(
|
|
97
106
|
arch: str,
|
|
@@ -116,8 +125,8 @@ def _textnet(
|
|
|
116
125
|
if pretrained:
|
|
117
126
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
118
127
|
# skip the mismatching layers for fine tuning
|
|
119
|
-
|
|
120
|
-
|
|
128
|
+
model.from_pretrained(
|
|
129
|
+
default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
121
130
|
)
|
|
122
131
|
|
|
123
132
|
return model
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
import types
|
|
6
7
|
from copy import deepcopy
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
@@ -53,12 +54,26 @@ def _vgg(
|
|
|
53
54
|
# Patch average pool & classification head
|
|
54
55
|
model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
55
56
|
model.classifier = nn.Linear(512, kwargs["num_classes"])
|
|
57
|
+
|
|
58
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
59
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
60
|
+
"""Load pretrained parameters onto the model
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
64
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
65
|
+
"""
|
|
66
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
67
|
+
|
|
68
|
+
# Bind method to the instance
|
|
69
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
70
|
+
|
|
56
71
|
# Load pretrained parameters
|
|
57
72
|
if pretrained:
|
|
58
73
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
59
74
|
# remove the last layer weights
|
|
60
75
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
61
|
-
|
|
76
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
62
77
|
|
|
63
78
|
model.cfg = _cfg
|
|
64
79
|
|
|
@@ -64,6 +64,15 @@ class VGG(Sequential):
|
|
|
64
64
|
super().__init__(_layers)
|
|
65
65
|
self.cfg = cfg
|
|
66
66
|
|
|
67
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
68
|
+
"""Load pretrained parameters onto the model
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
72
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
73
|
+
"""
|
|
74
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
75
|
+
|
|
67
76
|
|
|
68
77
|
def _vgg(
|
|
69
78
|
arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
|
|
@@ -86,8 +95,8 @@ def _vgg(
|
|
|
86
95
|
if pretrained:
|
|
87
96
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
88
97
|
# skip the mismatching layers for fine tuning
|
|
89
|
-
|
|
90
|
-
|
|
98
|
+
model.from_pretrained(
|
|
99
|
+
default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
|
|
91
100
|
)
|
|
92
101
|
|
|
93
102
|
return model
|