python-doctr 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/loader.py +1 -6
- doctr/datasets/utils.py +2 -1
- doctr/datasets/vocabs.py +9 -2
- doctr/file_utils.py +26 -12
- doctr/io/elements.py +40 -6
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +71 -13
- doctr/models/classification/mobilenet/pytorch.py +45 -9
- doctr/models/classification/mobilenet/tensorflow.py +38 -7
- doctr/models/classification/predictor/pytorch.py +18 -11
- doctr/models/classification/predictor/tensorflow.py +16 -10
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +3 -3
- doctr/models/classification/zoo.py +39 -15
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +14 -18
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +257 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +15 -1
- doctr/models/detection/zoo.py +21 -4
- doctr/models/factory/hub.py +3 -12
- doctr/models/kie_predictor/base.py +9 -3
- doctr/models/kie_predictor/pytorch.py +41 -20
- doctr/models/kie_predictor/tensorflow.py +36 -16
- doctr/models/modules/layers/pytorch.py +89 -10
- doctr/models/modules/layers/tensorflow.py +88 -10
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/predictor/base.py +77 -50
- doctr/models/predictor/pytorch.py +31 -20
- doctr/models/predictor/tensorflow.py +27 -17
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +4 -3
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +3 -9
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +66 -8
- doctr/transforms/modules/tensorflow.py +63 -7
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +35 -12
- doctr/utils/metrics.py +33 -174
- doctr/utils/reconstitution.py +126 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/METADATA +96 -91
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/RECORD +79 -75
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.0.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
doctr/__init__.py
CHANGED
File without changes

doctr/contrib/artefacts.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+from doctr.file_utils import requires_package
+
+from .base import _BasePredictor
+
+__all__ = ["ArtefactDetector"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "yolov8_artefact": {
+        "input_shape": (3, 1024, 1024),
+        "labels": ["bar_code", "qr_code", "logo", "photo"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0",
+    },
+}
+
+
+class ArtefactDetector(_BasePredictor):
+    """
+    A class to detect artefacts in images
+
+    >>> from doctr.io import DocumentFile
+    >>> from doctr.contrib.artefacts import ArtefactDetector
+    >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
+    >>> detector = ArtefactDetector()
+    >>> results = detector(doc)
+
+    Args:
+    ----
+        arch: the architecture to use
+        batch_size: the batch size to use
+        model_path: the path to the model to use
+        labels: the labels to use
+        input_shape: the input shape to use
+        mask_labels: the mask labels to use
+        conf_threshold: the confidence threshold to use
+        iou_threshold: the intersection over union threshold to use
+        **kwargs: additional arguments to be passed to `download_from_url`
+    """
+
+    def __init__(
+        self,
+        arch: str = "yolov8_artefact",
+        batch_size: int = 2,
+        model_path: Optional[str] = None,
+        labels: Optional[List[str]] = None,
+        input_shape: Optional[Tuple[int, int, int]] = None,
+        conf_threshold: float = 0.5,
+        iou_threshold: float = 0.5,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs)
+        self.labels = labels or default_cfgs[arch]["labels"]
+        self.input_shape = input_shape or default_cfgs[arch]["input_shape"]
+        self.conf_threshold = conf_threshold
+        self.iou_threshold = iou_threshold
+
+    def preprocess(self, img: np.ndarray) -> np.ndarray:
+        return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0)
+
+    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
+        results = []
+
+        for batch in zip(output, input_images):
+            for out, img in zip(batch[0], batch[1]):
+                org_height, org_width = img.shape[:2]
+                width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1]
+                for res in out:
+                    sample_results = []
+                    for row in np.transpose(np.squeeze(res)):
+                        classes_scores = row[4:]
+                        max_score = np.amax(classes_scores)
+                        if max_score >= self.conf_threshold:
+                            class_id = np.argmax(classes_scores)
+                            x, y, w, h = row[0], row[1], row[2], row[3]
+                            # to rescaled xmin, ymin, xmax, ymax
+                            xmin = int((x - w / 2) * width_scale)
+                            ymin = int((y - h / 2) * height_scale)
+                            xmax = int((x + w / 2) * width_scale)
+                            ymax = int((y + h / 2) * height_scale)
+
+                            sample_results.append({
+                                "label": self.labels[class_id],
+                                "confidence": float(max_score),
+                                "box": [xmin, ymin, xmax, ymax],
+                            })
+
+                    # Filter out overlapping boxes
+                    boxes = [res["box"] for res in sample_results]
+                    scores = [res["confidence"] for res in sample_results]
+                    keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold)  # type: ignore[arg-type]
+                    sample_results = [sample_results[i] for i in keep_indices]
+
+                    results.append(sample_results)
+
+        self._results = results
+        return results
+
+    def show(self, **kwargs: Any) -> None:
+        """
+        Display the results
+
+        Args:
+        ----
+            **kwargs: additional keyword arguments to be passed to `plt.show`
+        """
+        requires_package("matplotlib", "`.show()` requires matplotlib installed")
+        import matplotlib.pyplot as plt
+        from matplotlib.patches import Rectangle
+
+        # visualize the results with matplotlib
+        if self._results and self._inputs:
+            for img, res in zip(self._inputs, self._results):
+                plt.figure(figsize=(10, 10))
+                plt.imshow(img)
+                for obj in res:
+                    xmin, ymin, xmax, ymax = obj["box"]
+                    label = obj["label"]
+                    plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
+                    plt.gca().add_patch(
+                        Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
+                    )
+                plt.show(**kwargs)
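
For readers unfamiliar with YOLO-style outputs, the `postprocess` above decodes rows of `(x_center, y_center, width, height, *class_scores)` expressed in model-input pixels, rescales the corners to the original image size, and then prunes overlaps with OpenCV's NMS. A standalone sketch of just the decoding step, with made-up numbers for illustration:

# Decode one YOLO row into an absolute (xmin, ymin, xmax, ymax) box
import numpy as np

row = np.array([512.0, 256.0, 100.0, 40.0, 0.1, 0.9, 0.0, 0.0])  # hypothetical row
labels = ["bar_code", "qr_code", "logo", "photo"]
input_w, input_h = 1024, 1024  # model input size (matches default_cfgs)
org_w, org_h = 2048, 1536      # hypothetical original image size

width_scale, height_scale = org_w / input_w, org_h / input_h
scores = row[4:]
class_id = int(np.argmax(scores))
x, y, w, h = row[:4]
box = [
    int((x - w / 2) * width_scale),   # xmin
    int((y - h / 2) * height_scale),  # ymin
    int((x + w / 2) * width_scale),   # xmax
    int((y + h / 2) * height_scale),  # ymax
]
print(labels[class_id], float(scores[class_id]), box)  # qr_code 0.9 [924, 354, 1124, 414]
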
doctr/contrib/base.py
ADDED
@@ -0,0 +1,105 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List, Optional
+
+import numpy as np
+
+from doctr.file_utils import requires_package
+from doctr.utils.data import download_from_url
+
+
+class _BasePredictor:
+    """
+    Base class for all predictors
+
+    Args:
+    ----
+        batch_size: the batch size to use
+        url: the url to use to download a model if needed
+        model_path: the path to the model to use
+        **kwargs: additional arguments to be passed to `download_from_url`
+    """
+
+    def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None:
+        self.batch_size = batch_size
+        self.session = self._init_model(url, model_path, **kwargs)
+
+        self._inputs: List[np.ndarray] = []
+        self._results: List[Any] = []
+
+    def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any:
+        """
+        Download the model from the given url if needed
+
+        Args:
+        ----
+            url: the url to use
+            model_path: the path to the model to use
+            **kwargs: additional arguments to be passed to `download_from_url`
+
+        Returns:
+        -------
+            Any: the ONNX loaded model
+        """
+        requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
+        import onnxruntime as ort
+
+        if not url and not model_path:
+            raise ValueError("You must provide either a url or a model_path")
+        onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs))  # type: ignore[arg-type]
+        return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+
+    def preprocess(self, img: np.ndarray) -> np.ndarray:
+        """
+        Preprocess the input image
+
+        Args:
+        ----
+            img: the input image to preprocess
+
+        Returns:
+        -------
+            np.ndarray: the preprocessed image
+        """
+        raise NotImplementedError
+
+    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
+        """
+        Postprocess the model output
+
+        Args:
+        ----
+            output: the model output to postprocess
+            input_images: the input images used to generate the output
+
+        Returns:
+        -------
+            Any: the postprocessed output
+        """
+        raise NotImplementedError
+
+    def __call__(self, inputs: List[np.ndarray]) -> Any:
+        """
+        Call the model on the given inputs
+
+        Args:
+        ----
+            inputs: the inputs to use
+
+        Returns:
+        -------
+            Any: the postprocessed output
+        """
+        self._inputs = inputs
+        model_inputs = self.session.get_inputs()
+
+        batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
+        processed_batches = [
+            np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
+        ]
+
+        outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
+        return self.postprocess(outputs, batched_inputs)
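
`_BasePredictor` wraps an ONNX Runtime session and takes care of batching, leaving `preprocess` and `postprocess` to subclasses. A minimal sketch of a custom subclass, assuming `onnxruntime` is installed and that `my_model.onnx` is a placeholder for a real model file:

# Sketch of a custom ONNX predictor built on the new base class
from typing import Any, List

import numpy as np

from doctr.contrib.base import _BasePredictor


class RawOnnxPredictor(_BasePredictor):
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # HWC uint8 -> CHW float32 in [0, 1]
        return np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any:
        # return the raw per-batch outputs untouched
        return output


predictor = RawOnnxPredictor(batch_size=2, model_path="my_model.onnx")  # placeholder path
# results = predictor([np.zeros((64, 64, 3), dtype=np.uint8)])
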
doctr/datasets/datasets/pytorch.py
CHANGED
@@ -50,9 +50,9 @@ class AbstractDataset(_AbstractDataset):
     @staticmethod
     def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
         images, targets = zip(*samples)
-        images = torch.stack(images, dim=0)
+        images = torch.stack(images, dim=0)  # type: ignore[assignment]
 
-        return images, list(targets)
+        return images, list(targets)  # type: ignore[return-value]
 
 
 class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
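
The `collate_fn` touched above stacks same-sized image tensors into one batch tensor and keeps the targets as a plain Python list. A tiny self-contained illustration:

import torch

samples = [(torch.rand(3, 32, 128), "hello"), (torch.rand(3, 32, 128), "world")]
images, targets = zip(*samples)
images = torch.stack(images, dim=0)  # stack along a new batch dimension

print(images.shape)   # torch.Size([2, 3, 32, 128])
print(list(targets))  # ['hello', 'world']
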
doctr/datasets/generator/base.py
CHANGED
@@ -20,7 +20,7 @@ def synthesize_text_img(
     font_family: Optional[str] = None,
     background_color: Optional[Tuple[int, int, int]] = None,
     text_color: Optional[Tuple[int, int, int]] = None,
-) -> Image:
+) -> Image.Image:
     """Generate a synthetic text image
 
     Args:
@@ -81,7 +81,7 @@ class _CharacterGenerator(AbstractDataset):
         self._data: List[Image.Image] = []
         if cache_samples:
             self._data = [
-                (synthesize_text_img(char, font_family=font), idx)
+                (synthesize_text_img(char, font_family=font), idx)  # type: ignore[misc]
                 for idx, char in enumerate(self.vocab)
                 for font in self.font_family
             ]
@@ -93,7 +93,7 @@ class _CharacterGenerator(AbstractDataset):
         # Samples are already cached
         if len(self._data) > 0:
             idx = index % len(self._data)
-            pil_img, target = self._data[idx]
+            pil_img, target = self._data[idx]  # type: ignore[misc]
         else:
             target = index % len(self.vocab)
             pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
@@ -132,7 +132,8 @@ class _WordGenerator(AbstractDataset):
         if cache_samples:
             _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
             self._data = [
-                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)
+                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)  # type: ignore[misc]
+                for text in _words
             ]
 
     def _generate_string(self, min_chars: int, max_chars: int) -> str:
@@ -145,7 +146,7 @@ class _WordGenerator(AbstractDataset):
     def _read_sample(self, index: int) -> Tuple[Any, str]:
         # Samples are already cached
         if len(self._data) > 0:
-            pil_img, target = self._data[index]
+            pil_img, target = self._data[index]  # type: ignore[misc]
         else:
             target = self._generate_string(*self.wordlen_range)
             pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
doctr/datasets/imgur5k.py
CHANGED
@@ -112,7 +112,7 @@ class IMGUR5K(AbstractDataset):
                 if ann["word"] != "."
             ]
             # (x, y) coordinates of top left, top right, bottom right, bottom left corners
-            box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]
+            box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]
 
             if not use_polygons:
                 # xmin, ymin, xmax, ymax
doctr/datasets/loader.py
CHANGED
@@ -9,8 +9,6 @@ from typing import Callable, Optional
 import numpy as np
 import tensorflow as tf
 
-from doctr.utils.multithreading import multithread_exec
-
 __all__ = ["DataLoader"]
 
 
@@ -47,7 +45,6 @@ class DataLoader:
         shuffle: whether the samples should be shuffled before passing it to the iterator
         batch_size: number of elements in each batch
         drop_last: if `True`, drops the last batch if it isn't full
-        num_workers: number of workers to use for data loading
         collate_fn: function to merge samples into a batch
     """
 
@@ -57,7 +54,6 @@ class DataLoader:
         shuffle: bool = True,
         batch_size: int = 1,
         drop_last: bool = False,
-        num_workers: Optional[int] = None,
         collate_fn: Optional[Callable] = None,
     ) -> None:
         self.dataset = dataset
@@ -69,7 +65,6 @@ class DataLoader:
             self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
         else:
             self.collate_fn = collate_fn
-        self.num_workers = num_workers
         self.reset()
 
     def __len__(self) -> int:
@@ -92,7 +87,7 @@ class DataLoader:
         idx = self._num_yielded * self.batch_size
         indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
 
-        samples = list(
+        samples = list(map(self.dataset.__getitem__, indices))
 
         batch_data = self.collate_fn(samples)
 
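
With `multithread_exec` and `num_workers` gone, the TensorFlow `DataLoader` now fetches each batch with a plain `map` over the selected indices. The batching pattern in isolation, with a stand-in dataset:

# Stand-in for a dataset object that supports __getitem__ and __len__
dataset = [(f"img_{i}", f"target_{i}") for i in range(10)]
batch_size = 4

for start in range(0, len(dataset), batch_size):
    indices = range(start, min(len(dataset), start + batch_size))
    samples = list(map(dataset.__getitem__, indices))
    print(len(samples), samples[0])
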
doctr/datasets/utils.py
CHANGED
@@ -186,7 +186,8 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
     -------
         a list of cropped images
     """
-
+    with Image.open(img_path) as pil_img:
+        img: np.ndarray = np.array(pil_img.convert("RGB"))
     # Polygon
     if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
         return extract_rcrops(img, geoms.astype(dtype=int))
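
The rewritten body opens the image inside a context manager so the file handle is released as soon as the pixels are copied into a NumPy array. The same pattern in isolation (the path is a placeholder):

import numpy as np
from PIL import Image

with Image.open("path/to/image.jpg") as pil_img:  # hypothetical path
    img: np.ndarray = np.array(pil_img.convert("RGB"))

print(img.shape)  # (H, W, 3)
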
doctr/datasets/vocabs.py
CHANGED
@@ -17,9 +17,14 @@ VOCABS: Dict[str, str] = {
     "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
     "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
     "persian_letters": "پچڢڤگ",
-    "
+    "arabic_digits": "٠١٢٣٤٥٦٧٨٩",
     "arabic_diacritics": "ًٌٍَُِّْ",
     "arabic_punctuation": "؟؛«»—",
+    "hindi_letters": "अआइईउऊऋॠऌॡएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह",
+    "hindi_digits": "०१२३४५६७८९",
+    "hindi_punctuation": "।,?!:्ॐ॰॥॰",
+    "bangla_letters": "অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃেৈোৌ্ৎংঃঁ",
+    "bangla_digits": "০১২৩৪৫৬৭৮৯",
 }
 
 VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
@@ -32,7 +37,7 @@ VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙ
 VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
 VOCABS["arabic"] = (
     VOCABS["digits"]
-    + VOCABS["
+    + VOCABS["arabic_digits"]
     + VOCABS["arabic_letters"]
     + VOCABS["persian_letters"]
     + VOCABS["arabic_diacritics"]
@@ -52,6 +57,8 @@ VOCABS["vietnamese"] = (
     + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
 )
 VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
+VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]
+VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"]
 VOCABS["multilingual"] = "".join(
     dict.fromkeys(
         VOCABS["french"]
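
The new Hindi and Bangla vocabularies are composed the same way as the existing ones: plain string concatenation, with `dict.fromkeys` used to drop duplicate characters when several vocabs are merged. A shortened sketch (stand-in strings, not the full entries):

VOCABS = {
    "hindi_letters": "अआइईउ",  # shortened stand-in
    "hindi_digits": "०१२३४५६७८९",
    "hindi_punctuation": "।,?!",
}
VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]

# merging vocabs while removing duplicate characters, as done for "multilingual"
merged = "".join(dict.fromkeys(VOCABS["hindi"] + VOCABS["hindi_digits"]))
print(len(VOCABS["hindi"]), len(merged))
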
doctr/file_utils.py
CHANGED
@@ -5,21 +5,16 @@
 
 # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
 
+import importlib.metadata
 import importlib.util
 import logging
 import os
-import
+from typing import Optional
 
 CLASS_NAME: str = "words"
 
 
-
-    import importlib_metadata
-else:
-    import importlib.metadata as importlib_metadata
-
-
-__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
+__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
 
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -32,9 +27,9 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
     _torch_available = importlib.util.find_spec("torch") is not None
     if _torch_available:
         try:
-            _torch_version =
+            _torch_version = importlib.metadata.version("torch")
             logging.info(f"PyTorch version {_torch_version} available.")
-        except
+        except importlib.metadata.PackageNotFoundError:  # pragma: no cover
             _torch_available = False
 else:  # pragma: no cover
     logging.info("Disabling PyTorch because USE_TF is set")
@@ -59,9 +54,9 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
     # For the metadata, we have to look for both tensorflow and tensorflow-cpu
     for pkg in candidates:
         try:
-            _tf_version =
+            _tf_version = importlib.metadata.version(pkg)
             break
-        except
+        except importlib.metadata.PackageNotFoundError:
             pass
     _tf_available = _tf_version is not None
     if _tf_available:
@@ -82,6 +77,25 @@ if not _torch_available and not _tf_available:  # pragma: no cover
     )
 
 
+def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
+    """
+    package requirement helper
+
+    Args:
+    ----
+        name: name of the package
+        extra_message: additional message to display if the package is not found
+    """
+    try:
+        _pkg_version = importlib.metadata.version(name)
+        logging.info(f"{name} version {_pkg_version} available.")
+    except importlib.metadata.PackageNotFoundError:
+        raise ImportError(
+            f"\n\n{extra_message if extra_message is not None else ''} "
+            f"\nPlease install it with the following command: pip install {name}\n"
+        )
+
+
 def is_torch_available():
     """Whether PyTorch is installed."""
     return _torch_available
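
The new `requires_package` helper probes the installed distribution through `importlib.metadata` and raises an actionable `ImportError` before any import is attempted. Typical usage, mirroring how the rest of the codebase calls it:

from doctr.file_utils import requires_package


def show_something() -> None:
    requires_package("matplotlib", "`show_something()` requires matplotlib installed")
    import matplotlib.pyplot as plt  # safe: the package is known to be present

    plt.plot([0, 1], [0, 1])
    plt.show()
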
doctr/io/elements.py
CHANGED
@@ -12,14 +12,19 @@ from xml.etree import ElementTree as ET
 from xml.etree.ElementTree import Element as ETElement
 from xml.etree.ElementTree import SubElement
 
-import matplotlib.pyplot as plt
 import numpy as np
 
 import doctr
+from doctr.file_utils import requires_package
 from doctr.utils.common_types import BoundingBox
 from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
+from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page
 from doctr.utils.repr import NestedObject
-
+
+try:  # optional dependency for visualization
+    from doctr.utils.visualization import visualize_kie_page, visualize_page
+except ModuleNotFoundError:
+    pass
 
 __all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
 
@@ -67,16 +72,27 @@ class Word(Element):
         confidence: the confidence associated with the text prediction
         geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
             the page's size
+        objectness_score: the objectness score of the detection
+        crop_orientation: the general orientation of the crop in degrees and its confidence
     """
 
-    _exported_keys: List[str] = ["value", "confidence", "geometry"]
+    _exported_keys: List[str] = ["value", "confidence", "geometry", "objectness_score", "crop_orientation"]
     _children_names: List[str] = []
 
-    def __init__(
+    def __init__(
+        self,
+        value: str,
+        confidence: float,
+        geometry: Union[BoundingBox, np.ndarray],
+        objectness_score: float,
+        crop_orientation: Dict[str, Any],
+    ) -> None:
         super().__init__()
         self.value = value
         self.confidence = confidence
         self.geometry = geometry
+        self.objectness_score = objectness_score
+        self.crop_orientation = crop_orientation
 
     def render(self) -> str:
         """Renders the full text of the element"""
@@ -135,7 +151,7 @@ class Line(Element):
     all words in it.
     """
 
-    _exported_keys: List[str] = ["geometry"]
+    _exported_keys: List[str] = ["geometry", "objectness_score"]
     _children_names: List[str] = ["words"]
     words: List[Word] = []
 
@@ -143,7 +159,11 @@ class Line(Element):
         self,
         words: List[Word],
         geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+        objectness_score: Optional[float] = None,
     ) -> None:
+        # Compute the objectness score of the line
+        if objectness_score is None:
+            objectness_score = float(np.mean([w.objectness_score for w in words]))
         # Resolve the geometry using the smallest enclosing bounding box
         if geometry is None:
             # Check whether this is a rotated or straight box
@@ -152,6 +172,7 @@ class Line(Element):
 
         super().__init__(words=words)
         self.geometry = geometry
+        self.objectness_score = objectness_score
 
     def render(self) -> str:
         """Renders the full text of the element"""
@@ -189,7 +210,7 @@ class Block(Element):
     all lines and artefacts in it.
     """
 
-    _exported_keys: List[str] = ["geometry"]
+    _exported_keys: List[str] = ["geometry", "objectness_score"]
     _children_names: List[str] = ["lines", "artefacts"]
     lines: List[Line] = []
     artefacts: List[Artefact] = []
@@ -199,7 +220,11 @@ class Block(Element):
         lines: List[Line] = [],
         artefacts: List[Artefact] = [],
         geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+        objectness_score: Optional[float] = None,
     ) -> None:
+        # Compute the objectness score of the line
+        if objectness_score is None:
+            objectness_score = float(np.mean([w.objectness_score for line in lines for w in line.words]))
         # Resolve the geometry using the smallest enclosing bounding box
         if geometry is None:
             line_boxes = [word.geometry for line in lines for word in line.words]
@@ -211,6 +236,7 @@ class Block(Element):
 
         super().__init__(lines=lines, artefacts=artefacts)
         self.geometry = geometry
+        self.objectness_score = objectness_score
 
     def render(self, line_break: str = "\n") -> str:
         """Renders the full text of the element"""
@@ -274,6 +300,10 @@ class Page(Element):
             preserve_aspect_ratio: pass True if you passed True to the predictor
             **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
         """
+        requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+        requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+        import matplotlib.pyplot as plt
+
         visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
         plt.show(**kwargs)
 
@@ -449,6 +479,10 @@ class KIEPage(Element):
             preserve_aspect_ratio: pass True if you passed True to the predictor
             **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
         """
+        requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed")
+        requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed")
+        import matplotlib.pyplot as plt
+
         visualize_kie_page(
             self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
         )
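
Since `Word` now also carries `objectness_score` and `crop_orientation`, constructing one by hand takes two extra arguments. A small sketch with relative coordinates (the exact keys of the `crop_orientation` dict shown here are an assumption):

from doctr.io.elements import Word

word = Word(
    value="hello",
    confidence=0.99,
    geometry=((0.10, 0.10), (0.25, 0.15)),
    objectness_score=0.95,
    crop_orientation={"value": 0, "confidence": None},  # assumed key layout
)
print(word.export()["objectness_score"])
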
doctr/io/html.py
CHANGED
@@ -5,8 +5,6 @@
 
 from typing import Any
 
-from weasyprint import HTML
-
 __all__ = ["read_html"]
 
 
@@ -25,4 +23,6 @@ def read_html(url: str, **kwargs: Any) -> bytes:
     -------
         decoded PDF file as a bytes stream
     """
+    from weasyprint import HTML
+
     return HTML(url, **kwargs).write_pdf()
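
Because the `weasyprint` import moves into the function body, importing `doctr.io` no longer requires it; the dependency is only needed when `read_html` is actually called. Usage itself is unchanged (placeholder URL, `weasyprint` must be installed at call time):

from doctr.io import DocumentFile
from doctr.io.html import read_html

pdf_bytes = read_html("https://example.com")  # returns the rendered page as PDF bytes
doc = DocumentFile.from_pdf(pdf_bytes)
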
doctr/io/image/pytorch.py
CHANGED
@@ -16,7 +16,7 @@ from doctr.utils.common_types import AbstractPath
 __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
-def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     """Convert a PIL Image to a PyTorch tensor
 
     Args:
@@ -51,9 +51,8 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float3
     if dtype not in (torch.uint8, torch.float16, torch.float32):
         raise ValueError("insupported value for dtype")
 
-
-
-    return tensor_from_pil(pil_img, dtype)
+    with Image.open(img_path, mode="r") as pil_img:
+        return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
 def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -71,9 +70,8 @@ def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32)
     if dtype not in (torch.uint8, torch.float16, torch.float32):
         raise ValueError("insupported value for dtype")
 
-
-
-    return tensor_from_pil(pil_img, dtype)
+    with Image.open(BytesIO(img_content), mode="r") as pil_img:
+        return tensor_from_pil(pil_img.convert("RGB"), dtype)
 
 
 def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -106,4 +104,4 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -
 
 def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
     """Get the shape of an image"""
-    return img.shape[-2:]
+    return img.shape[-2:]  # type: ignore[return-value]
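
The reading helpers now open the file inside a context manager, force RGB, and delegate to `tensor_from_pil`. A rough manual equivalent of that conversion for a float32 tensor (placeholder path; assumes the usual scaling to [0, 1]):

import numpy as np
import torch
from PIL import Image

with Image.open("path/to/image.jpg", mode="r") as pil_img:  # hypothetical path
    rgb = np.array(pil_img.convert("RGB"))  # (H, W, 3) uint8

img_t = torch.from_numpy(rgb).permute(2, 0, 1).to(torch.float32) / 255.0  # (3, H, W)
print(img_t.shape, img_t.dtype)
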
doctr/io/image/tensorflow.py
CHANGED
@@ -15,7 +15,7 @@ from doctr.utils.common_types import AbstractPath
 __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
 
 
-def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
     """Convert a PIL Image to a TensorFlow tensor
 
     Args: