python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/file_utils.py
CHANGED
|
@@ -3,93 +3,13 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
|
|
7
|
-
|
|
8
6
|
import importlib.metadata
|
|
9
|
-
import importlib.util
|
|
10
7
|
import logging
|
|
11
|
-
import os
|
|
12
|
-
|
|
13
|
-
CLASS_NAME: str = "words"
|
|
14
8
|
|
|
9
|
+
__all__ = ["requires_package", "CLASS_NAME"]
|
|
15
10
|
|
|
16
|
-
|
|
17
|
-
|
|
11
|
+
CLASS_NAME: str = "words"
|
|
18
12
|
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
|
19
|
-
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
|
20
|
-
|
|
21
|
-
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
|
22
|
-
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
|
26
|
-
_torch_available = importlib.util.find_spec("torch") is not None
|
|
27
|
-
if _torch_available:
|
|
28
|
-
try:
|
|
29
|
-
_torch_version = importlib.metadata.version("torch")
|
|
30
|
-
logging.info(f"PyTorch version {_torch_version} available.")
|
|
31
|
-
except importlib.metadata.PackageNotFoundError: # pragma: no cover
|
|
32
|
-
_torch_available = False
|
|
33
|
-
else: # pragma: no cover
|
|
34
|
-
logging.info("Disabling PyTorch because USE_TF is set")
|
|
35
|
-
_torch_available = False
|
|
36
|
-
|
|
37
|
-
# Compatibility fix to make sure tensorflow.keras stays at Keras 2
|
|
38
|
-
if "TF_USE_LEGACY_KERAS" not in os.environ:
|
|
39
|
-
os.environ["TF_USE_LEGACY_KERAS"] = "1"
|
|
40
|
-
|
|
41
|
-
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
|
|
42
|
-
raise ValueError(
|
|
43
|
-
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
|
|
44
|
-
)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def ensure_keras_v2() -> None: # pragma: no cover
|
|
48
|
-
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
|
|
49
|
-
os.environ["TF_USE_LEGACY_KERAS"] = "1"
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
|
53
|
-
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
|
54
|
-
if _tf_available:
|
|
55
|
-
candidates = (
|
|
56
|
-
"tensorflow",
|
|
57
|
-
"tensorflow-cpu",
|
|
58
|
-
"tensorflow-gpu",
|
|
59
|
-
"tf-nightly",
|
|
60
|
-
"tf-nightly-cpu",
|
|
61
|
-
"tf-nightly-gpu",
|
|
62
|
-
"intel-tensorflow",
|
|
63
|
-
"tensorflow-rocm",
|
|
64
|
-
"tensorflow-macos",
|
|
65
|
-
)
|
|
66
|
-
_tf_version = None
|
|
67
|
-
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
|
68
|
-
for pkg in candidates:
|
|
69
|
-
try:
|
|
70
|
-
_tf_version = importlib.metadata.version(pkg)
|
|
71
|
-
break
|
|
72
|
-
except importlib.metadata.PackageNotFoundError:
|
|
73
|
-
pass
|
|
74
|
-
_tf_available = _tf_version is not None
|
|
75
|
-
if _tf_available:
|
|
76
|
-
if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover
|
|
77
|
-
logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
|
|
78
|
-
_tf_available = False
|
|
79
|
-
else:
|
|
80
|
-
logging.info(f"TensorFlow version {_tf_version} available.")
|
|
81
|
-
ensure_keras_v2()
|
|
82
|
-
|
|
83
|
-
else: # pragma: no cover
|
|
84
|
-
logging.info("Disabling Tensorflow because USE_TORCH is set")
|
|
85
|
-
_tf_available = False
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if not _torch_available and not _tf_available: # pragma: no cover
|
|
89
|
-
raise ModuleNotFoundError(
|
|
90
|
-
"DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
|
|
91
|
-
" is installed and that either USE_TF or USE_TORCH is enabled."
|
|
92
|
-
)
|
|
93
13
|
|
|
94
14
|
|
|
95
15
|
def requires_package(name: str, extra_message: str | None = None) -> None: # pragma: no cover
|
|
@@ -108,13 +28,3 @@ def requires_package(name: str, extra_message: str | None = None) -> None: # pr
|
|
|
108
28
|
f"\n\n{extra_message if extra_message is not None else ''} "
|
|
109
29
|
f"\nPlease install it with the following command: pip install {name}\n"
|
|
110
30
|
)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def is_torch_available():
|
|
114
|
-
"""Whether PyTorch is installed."""
|
|
115
|
-
return _torch_available
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def is_tf_available():
|
|
119
|
-
"""Whether TensorFlow is installed."""
|
|
120
|
-
return _tf_available
|
doctr/io/elements.py
CHANGED
|
@@ -347,7 +347,7 @@ class Page(Element):
|
|
|
347
347
|
)
|
|
348
348
|
# Create the body
|
|
349
349
|
body = SubElement(page_hocr, "body")
|
|
350
|
-
SubElement(
|
|
350
|
+
page_div = SubElement(
|
|
351
351
|
body,
|
|
352
352
|
"div",
|
|
353
353
|
attrib={
|
|
@@ -362,7 +362,7 @@ class Page(Element):
|
|
|
362
362
|
raise TypeError("XML export is only available for straight bounding boxes for now.")
|
|
363
363
|
(xmin, ymin), (xmax, ymax) = block.geometry
|
|
364
364
|
block_div = SubElement(
|
|
365
|
-
|
|
365
|
+
page_div,
|
|
366
366
|
"div",
|
|
367
367
|
attrib={
|
|
368
368
|
"class": "ocr_carea",
|
|
@@ -550,7 +550,41 @@ class KIEPage(Element):
|
|
|
550
550
|
{int(round(xmax * width))} {int(round(ymax * height))}",
|
|
551
551
|
},
|
|
552
552
|
)
|
|
553
|
-
|
|
553
|
+
# NOTE: ocr_par, ocr_line and ocrx_word are the same because the KIE predictions contain only words
|
|
554
|
+
# This is a workaround to make it PDF/A compatible
|
|
555
|
+
par_div = SubElement(
|
|
556
|
+
prediction_div,
|
|
557
|
+
"p",
|
|
558
|
+
attrib={
|
|
559
|
+
"class": "ocr_par",
|
|
560
|
+
"id": f"{class_name}_par_{prediction_count}",
|
|
561
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
562
|
+
{int(round(xmax * width))} {int(round(ymax * height))}",
|
|
563
|
+
},
|
|
564
|
+
)
|
|
565
|
+
line_span = SubElement(
|
|
566
|
+
par_div,
|
|
567
|
+
"span",
|
|
568
|
+
attrib={
|
|
569
|
+
"class": "ocr_line",
|
|
570
|
+
"id": f"{class_name}_line_{prediction_count}",
|
|
571
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
572
|
+
{int(round(xmax * width))} {int(round(ymax * height))}; \
|
|
573
|
+
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
|
|
574
|
+
},
|
|
575
|
+
)
|
|
576
|
+
word_div = SubElement(
|
|
577
|
+
line_span,
|
|
578
|
+
"span",
|
|
579
|
+
attrib={
|
|
580
|
+
"class": "ocrx_word",
|
|
581
|
+
"id": f"{class_name}_word_{prediction_count}",
|
|
582
|
+
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
|
|
583
|
+
{int(round(xmax * width))} {int(round(ymax * height))}; \
|
|
584
|
+
x_wconf {int(round(prediction.confidence * 100))}",
|
|
585
|
+
},
|
|
586
|
+
)
|
|
587
|
+
word_div.text = prediction.value
|
|
554
588
|
prediction_count += 1
|
|
555
589
|
|
|
556
590
|
return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)
|
doctr/io/image/__init__.py
CHANGED
doctr/io/image/pytorch.py
CHANGED
doctr/models/_utils.py
CHANGED
|
@@ -63,7 +63,7 @@ def estimate_orientation(
|
|
|
63
63
|
thresh = img.astype(np.uint8)
|
|
64
64
|
|
|
65
65
|
page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
|
|
66
|
-
if page_orientation and orientation_confidence >= min_confidence:
|
|
66
|
+
if page_orientation is not None and orientation_confidence >= min_confidence:
|
|
67
67
|
# We rotate the image to the general orientation which improves the detection
|
|
68
68
|
# No expand needed bitmap is already padded
|
|
69
69
|
thresh = rotate_image(thresh, -page_orientation)
|
|
@@ -87,7 +87,7 @@ def estimate_orientation(
|
|
|
87
87
|
|
|
88
88
|
angles = []
|
|
89
89
|
for contour in contours[:n_ct]:
|
|
90
|
-
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
90
|
+
_, (w, h), angle = cv2.minAreaRect(contour)
|
|
91
91
|
if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
|
|
92
92
|
angles.append(angle)
|
|
93
93
|
elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree
|
|
@@ -100,7 +100,7 @@ def estimate_orientation(
|
|
|
100
100
|
estimated_angle = -round(median) if abs(median) != 0 else 0
|
|
101
101
|
|
|
102
102
|
# combine with the general orientation and the estimated angle
|
|
103
|
-
if page_orientation and orientation_confidence >= min_confidence:
|
|
103
|
+
if page_orientation is not None and orientation_confidence >= min_confidence:
|
|
104
104
|
# special case where the estimated angle is mostly wrong:
|
|
105
105
|
# case 1: - and + swapped
|
|
106
106
|
# case 2: estimated angle is completely wrong
|
|
@@ -184,7 +184,7 @@ def invert_data_structure(
|
|
|
184
184
|
dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
|
|
185
185
|
"""
|
|
186
186
|
if isinstance(x, dict):
|
|
187
|
-
assert len({len(v) for v in x.values()}) == 1, "All the lists in the
|
|
187
|
+
assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionary should have the same length."
|
|
188
188
|
return [dict(zip(x, t)) for t in zip(*x.values())]
|
|
189
189
|
elif isinstance(x, list):
|
|
190
190
|
return {k: [dic[k] for dic in x] for k in x[0]}
|
|
@@ -14,8 +14,7 @@ from torch import nn
|
|
|
14
14
|
|
|
15
15
|
from doctr.datasets import VOCABS
|
|
16
16
|
|
|
17
|
-
from
|
|
18
|
-
from ..resnet.pytorch import ResNet
|
|
17
|
+
from ..resnet import ResNet
|
|
19
18
|
|
|
20
19
|
__all__ = ["magc_resnet31"]
|
|
21
20
|
|
|
@@ -73,7 +72,7 @@ class MAGC(nn.Module):
|
|
|
73
72
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
74
73
|
batch, _, height, width = inputs.size()
|
|
75
74
|
# (N * headers, C / headers, H , W)
|
|
76
|
-
x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width)
|
|
75
|
+
x = inputs.contiguous().view(batch * self.headers, self.single_header_inplanes, height, width)
|
|
77
76
|
shortcut = x
|
|
78
77
|
# (N * headers, C / headers, H * W)
|
|
79
78
|
shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width)
|
|
@@ -136,7 +135,7 @@ def _magc_resnet(
|
|
|
136
135
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
137
136
|
# remove the last layer weights
|
|
138
137
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
139
|
-
|
|
138
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
140
139
|
|
|
141
140
|
return model
|
|
142
141
|
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
|
|
7
7
|
|
|
8
|
+
import types
|
|
8
9
|
from copy import deepcopy
|
|
9
10
|
from typing import Any
|
|
10
11
|
|
|
@@ -99,12 +100,25 @@ def _mobilenet_v3(
|
|
|
99
100
|
m = getattr(m, child)
|
|
100
101
|
m.stride = (2, 1)
|
|
101
102
|
|
|
103
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
104
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
105
|
+
"""Load pretrained parameters onto the model
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
109
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
110
|
+
"""
|
|
111
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
112
|
+
|
|
113
|
+
# Bind method to the instance
|
|
114
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
115
|
+
|
|
102
116
|
# Load pretrained parameters
|
|
103
117
|
if pretrained:
|
|
104
118
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
105
119
|
# remove the last layer weights
|
|
106
120
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
107
|
-
|
|
121
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
108
122
|
|
|
109
123
|
model.cfg = _cfg
|
|
110
124
|
|
|
@@ -35,7 +35,7 @@ class OrientationPredictor(nn.Module):
|
|
|
35
35
|
@torch.inference_mode()
|
|
36
36
|
def forward(
|
|
37
37
|
self,
|
|
38
|
-
inputs: list[np.ndarray
|
|
38
|
+
inputs: list[np.ndarray],
|
|
39
39
|
) -> list[list[int] | list[float]]:
|
|
40
40
|
# Dimension check
|
|
41
41
|
if any(input.ndim != 3 for input in inputs):
|
|
@@ -50,7 +50,7 @@ class OrientationPredictor(nn.Module):
|
|
|
50
50
|
self.model, processed_batches = set_device_and_dtype(
|
|
51
51
|
self.model, processed_batches, _params.device, _params.dtype
|
|
52
52
|
)
|
|
53
|
-
predicted_batches = [self.model(batch) for batch in processed_batches]
|
|
53
|
+
predicted_batches = [self.model(batch) for batch in processed_batches]
|
|
54
54
|
# confidence
|
|
55
55
|
probs = [
|
|
56
56
|
torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
|
|
7
|
+
import types
|
|
7
8
|
from collections.abc import Callable
|
|
8
9
|
from copy import deepcopy
|
|
9
10
|
from typing import Any
|
|
@@ -152,6 +153,15 @@ class ResNet(nn.Sequential):
|
|
|
152
153
|
nn.init.constant_(m.weight, 1)
|
|
153
154
|
nn.init.constant_(m.bias, 0)
|
|
154
155
|
|
|
156
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
157
|
+
"""Load pretrained parameters onto the model
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
161
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
162
|
+
"""
|
|
163
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
164
|
+
|
|
155
165
|
|
|
156
166
|
def _resnet(
|
|
157
167
|
arch: str,
|
|
@@ -179,7 +189,7 @@ def _resnet(
|
|
|
179
189
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
180
190
|
# remove the last layer weights
|
|
181
191
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
182
|
-
|
|
192
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
183
193
|
|
|
184
194
|
return model
|
|
185
195
|
|
|
@@ -201,12 +211,25 @@ def _tv_resnet(
|
|
|
201
211
|
|
|
202
212
|
# Build the model
|
|
203
213
|
model = arch_fn(**kwargs, weights=None)
|
|
204
|
-
|
|
214
|
+
|
|
215
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
216
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
217
|
+
"""Load pretrained parameters onto the model
|
|
218
|
+
|
|
219
|
+
Args:
|
|
220
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
221
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
222
|
+
"""
|
|
223
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
224
|
+
|
|
225
|
+
# Bind method to the instance
|
|
226
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
227
|
+
|
|
205
228
|
if pretrained:
|
|
206
229
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
207
230
|
# remove the last layer weights
|
|
208
231
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
209
|
-
|
|
232
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
210
233
|
|
|
211
234
|
model.cfg = _cfg
|
|
212
235
|
|
|
@@ -11,7 +11,7 @@ from torch import nn
|
|
|
11
11
|
|
|
12
12
|
from doctr.datasets import VOCABS
|
|
13
13
|
|
|
14
|
-
from ...modules.layers
|
|
14
|
+
from ...modules.layers import FASTConvLayer
|
|
15
15
|
from ...utils import conv_sequence_pt, load_pretrained_params
|
|
16
16
|
|
|
17
17
|
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
|
|
@@ -93,6 +93,15 @@ class TextNet(nn.Sequential):
|
|
|
93
93
|
nn.init.constant_(m.weight, 1)
|
|
94
94
|
nn.init.constant_(m.bias, 0)
|
|
95
95
|
|
|
96
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
|
|
97
|
+
"""Load pretrained parameters onto the model
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
101
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
102
|
+
"""
|
|
103
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
104
|
+
|
|
96
105
|
|
|
97
106
|
def _textnet(
|
|
98
107
|
arch: str,
|
|
@@ -115,7 +124,7 @@ def _textnet(
|
|
|
115
124
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
116
125
|
# remove the last layer weights
|
|
117
126
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
118
|
-
|
|
127
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
119
128
|
|
|
120
129
|
model.cfg = _cfg
|
|
121
130
|
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
+
import types
|
|
6
7
|
from copy import deepcopy
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
@@ -53,12 +54,26 @@ def _vgg(
|
|
|
53
54
|
# Patch average pool & classification head
|
|
54
55
|
model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
55
56
|
model.classifier = nn.Linear(512, kwargs["num_classes"])
|
|
57
|
+
|
|
58
|
+
# monkeypatch the model to allow for loading pretrained parameters
|
|
59
|
+
def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
|
|
60
|
+
"""Load pretrained parameters onto the model
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
path_or_url: the path or URL to the model parameters (checkpoint)
|
|
64
|
+
**kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
|
|
65
|
+
"""
|
|
66
|
+
load_pretrained_params(self, path_or_url, **kwargs)
|
|
67
|
+
|
|
68
|
+
# Bind method to the instance
|
|
69
|
+
model.from_pretrained = types.MethodType(from_pretrained, model)
|
|
70
|
+
|
|
56
71
|
# Load pretrained parameters
|
|
57
72
|
if pretrained:
|
|
58
73
|
# The number of classes is not the same as the number of classes in the pretrained model =>
|
|
59
74
|
# remove the last layer weights
|
|
60
75
|
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
|
|
61
|
-
|
|
76
|
+
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
|
|
62
77
|
|
|
63
78
|
model.cfg = _cfg
|
|
64
79
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .pytorch import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .pytorch import *
|