python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/loader.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1102 -54
- doctr/file_utils.py +9 -0
- doctr/io/elements.py +37 -3
- doctr/models/_utils.py +1 -1
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +1 -2
- doctr/models/classification/magc_resnet/tensorflow.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/mobilenet/tensorflow.py +11 -2
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/resnet/tensorflow.py +25 -4
- doctr/models/classification/textnet/pytorch.py +10 -1
- doctr/models/classification/textnet/tensorflow.py +11 -2
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vgg/tensorflow.py +11 -2
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/pytorch.py +10 -1
- doctr/models/classification/vit/tensorflow.py +9 -0
- doctr/models/classification/zoo.py +4 -0
- doctr/models/detection/differentiable_binarization/base.py +3 -4
- doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
- doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
- doctr/models/detection/fast/base.py +2 -3
- doctr/models/detection/fast/pytorch.py +13 -4
- doctr/models/detection/fast/tensorflow.py +10 -2
- doctr/models/detection/linknet/base.py +2 -3
- doctr/models/detection/linknet/pytorch.py +10 -1
- doctr/models/detection/linknet/tensorflow.py +10 -2
- doctr/models/factory/hub.py +3 -3
- doctr/models/kie_predictor/pytorch.py +1 -1
- doctr/models/kie_predictor/tensorflow.py +1 -1
- doctr/models/modules/layers/pytorch.py +49 -1
- doctr/models/predictor/pytorch.py +1 -1
- doctr/models/predictor/tensorflow.py +1 -1
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/pytorch.py +10 -1
- doctr/models/recognition/crnn/tensorflow.py +10 -1
- doctr/models/recognition/master/pytorch.py +10 -1
- doctr/models/recognition/master/tensorflow.py +10 -3
- doctr/models/recognition/parseq/pytorch.py +23 -5
- doctr/models/recognition/parseq/tensorflow.py +13 -5
- doctr/models/recognition/predictor/_utils.py +107 -45
- doctr/models/recognition/predictor/pytorch.py +3 -3
- doctr/models/recognition/predictor/tensorflow.py +3 -3
- doctr/models/recognition/sar/pytorch.py +10 -1
- doctr/models/recognition/sar/tensorflow.py +10 -3
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/pytorch.py +10 -1
- doctr/models/recognition/vitstr/tensorflow.py +10 -3
- doctr/models/recognition/zoo.py +5 -0
- doctr/models/utils/pytorch.py +28 -18
- doctr/models/utils/tensorflow.py +15 -8
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/datasets/__init__.py
CHANGED
doctr/datasets/coco_text.py
ADDED
@@ -0,0 +1,139 @@
+# Copyright (C) 2021-2025, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["COCOTEXT"]
+
+
+class COCOTEXT(AbstractDataset):
+    """
+    COCO-Text dataset from `"COCO-Text: Dataset and Benchmark for Text Detection and Recognition in Natural Images"
+    <https://arxiv.org/pdf/1601.07140v2>`_ |
+    `"homepage" <https://bgshih.github.io/cocotext/>`_.
+
+    >>> # NOTE: You need to download the dataset first.
+    >>> from doctr.datasets import COCOTEXT
+    >>> train_set = COCOTEXT(train=True, img_folder="/path/to/coco_text/train2014/",
+    >>>                      label_path="/path/to/coco_text/cocotext.v2.json")
+    >>> img, target = train_set[0]
+    >>> test_set = COCOTEXT(train=False, img_folder="/path/to/coco_text/train2014/",
+    >>>                     label_path="/path/to/coco_text/cocotext.v2.json")
+    >>> img, target = test_set[0]
+
+    Args:
+        img_folder: folder with all the images of the dataset
+        label_path: path to the annotations file of the dataset
+        train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+        recognition_task: whether the dataset should be used for recognition task
+        detection_task: whether the dataset should be used for detection task
+        **kwargs: keyword arguments from `AbstractDataset`.
+    """
+
+    def __init__(
+        self,
+        img_folder: str,
+        label_path: str,
+        train: bool = True,
+        use_polygons: bool = False,
+        recognition_task: bool = False,
+        detection_task: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+        )
+        # Task check
+        if recognition_task and detection_task:
+            raise ValueError(
+                " 'recognition' and 'detection task' cannot be set to True simultaneously. "
+                + " To get the whole dataset with boxes and labels leave both parameters to False "
+            )
+
+        # File existence check
+        if not os.path.exists(label_path) or not os.path.exists(img_folder):
+            raise FileNotFoundError(f"unable to find {label_path if not os.path.exists(label_path) else img_folder}")
+
+        tmp_root = img_folder
+        self.train = train
+        np_dtype = np.float32
+        self.data: list[tuple[str | Path | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
+
+        with open(label_path, "r") as file:
+            data = json.load(file)
+
+        # Filter images based on the set
+        img_items = [img for img in data["imgs"].items() if (img[1]["set"] == "train") == train]
+        box: list[float] | np.ndarray
+
+        for img_id, img_info in tqdm(img_items, desc="Preparing and Loading COCOTEXT", total=len(img_items)):
+            img_path = os.path.join(img_folder, img_info["file_name"])
+
+            # File existence check
+            if not os.path.exists(img_path):  # pragma: no cover
+                raise FileNotFoundError(f"Unable to locate {img_path}")
+
+            # Get annotations for the current image (only legible text)
+            annotations = [
+                ann
+                for ann in data["anns"].values()
+                if ann["image_id"] == int(img_id) and ann["legibility"] == "legible"
+            ]
+
+            # Some images have no annotations with readable text
+            if not annotations:  # pragma: no cover
+                continue
+
+            _targets = []
+
+            for annotation in annotations:
+                x, y, w, h = annotation["bbox"]
+                if use_polygons:
+                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+                    box = np.array(
+                        [
+                            [x, y],
+                            [x + w, y],
+                            [x + w, y + h],
+                            [x, y + h],
+                        ],
+                        dtype=np_dtype,
+                    )
+                else:
+                    # (xmin, ymin, xmax, ymax) coordinates
+                    box = [x, y, x + w, y + h]
+                _targets.append((annotation["utf8_string"], box))
+            text_targets, box_targets = zip(*_targets)
+
+            if recognition_task:
+                crops = crop_bboxes_from_image(
+                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
+                )
+                for crop, label in zip(crops, list(text_targets)):
+                    if label and " " not in label:
+                        self.data.append((crop, label))
+
+            elif detection_task:
+                self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
+            else:
+                self.data.append((
+                    img_path,
+                    dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
+                ))
+
+        self.root = tmp_root
+
+    def extra_repr(self) -> str:
+        return f"train={self.train}"
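To make the new dataset's three loading modes concrete, here is a minimal usage sketch based only on the constructor above; the paths are placeholders and the behavior notes are inferred from the code in this diff:

from doctr.datasets import COCOTEXT

# Default mode: each sample is (image path, {"boxes": int array of shape (N, 4), "labels": list of N strings})
train_set = COCOTEXT(
    img_folder="/path/to/coco_text/train2014/",
    label_path="/path/to/coco_text/cocotext.v2.json",
    train=True,
)
img, target = train_set[0]

# Recognition mode: samples are (word-crop array, label); labels that are empty or contain a space are skipped
reco_set = COCOTEXT(
    img_folder="/path/to/coco_text/train2014/",
    label_path="/path/to/coco_text/cocotext.v2.json",
    recognition_task=True,
)

# Passing recognition_task=True and detection_task=True together raises a ValueError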
doctr/datasets/cord.py
CHANGED
@@ -116,7 +116,8 @@ class CORD(VisionDataset):
                 img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
             )
             for crop, label in zip(crops, list(text_targets)):
-                self.data.append((crop, label))
+                if " " not in label:
+                    self.data.append((crop, label))
         elif detection_task:
             self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0)))
         else:
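The one-line guard above is part of a pattern repeated across the dataset loaders in this release (FUNSD, IC03, IC13, IIIT5K, IMGUR5K, SVHN, SVT, SynthText below): recognition crops whose label contains a space are dropped. A minimal sketch of the effect, with hypothetical label values:

# Labels produced by the box/text extraction step (hypothetical values)
labels = ["TOTAL", "12 000", "Invoice"]

# The 0.12.0 recognition filter keeps only space-free labels
kept = [label for label in labels if " " not in label]
assert kept == ["TOTAL", "Invoice"]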
doctr/datasets/funsd.py
CHANGED
@@ -107,8 +107,8 @@ class FUNSD(VisionDataset):
             )
             for crop, label in zip(crops, list(text_targets)):
                 # filter labels with unknown characters
-                if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
-                    self.data.append((crop, label))
+                if not any(char in label for char in ["☑", "☐", "\u03bf", "\uf703", "\uf702", " "]):
+                    self.data.append((crop, label.replace("–", "-")))
         elif detection_task:
             self.data.append((img_path, np.asarray(box_targets, dtype=np_dtype)))
         else:
doctr/datasets/ic03.py
CHANGED
@@ -122,7 +122,7 @@ class IC03(VisionDataset):
             if recognition_task:
                 crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                 for crop, label in zip(crops, labels):
-                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
                         self.data.append((crop, label))
             elif detection_task:
                 self.data.append((name.text, boxes))
doctr/datasets/ic13.py
CHANGED
@@ -100,7 +100,8 @@ class IC13(AbstractDataset):
         if recognition_task:
             crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
             for crop, label in zip(crops, labels):
-                self.data.append((crop, label))
+                if " " not in label:
+                    self.data.append((crop, label))
         elif detection_task:
             self.data.append((img_path, box_targets))
         else:
doctr/datasets/iiit5k.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Any
 
 import numpy as np
 import scipy.io as sio
+from PIL import Image
 from tqdm import tqdm
 
 from .datasets import VisionDataset
@@ -98,7 +99,9 @@ class IIIT5K(VisionDataset):
             box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
 
             if recognition_task:
-                self.data.append((_raw_path, _raw_label))
+                if " " not in _raw_label:
+                    with Image.open(os.path.join(tmp_root, _raw_path)) as pil_img:
+                        self.data.append((np.array(pil_img.convert("RGB")), _raw_label))
             elif detection_task:
                 self.data.append((_raw_path, np.asarray(box_targets, dtype=np_dtype)))
             else:
doctr/datasets/imgur5k.py
CHANGED
@@ -133,7 +133,13 @@ class IMGUR5K(AbstractDataset):
                 img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype)
             )
             for crop, label in zip(crops, labels):
-                if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                if (
+                    crop.shape[0] > 0
+                    and crop.shape[1] > 0
+                    and len(label) > 0
+                    and len(label) < 30
+                    and " " not in label
+                ):
                     # write data to disk
                     with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
                         f.write(label)
@@ -152,6 +158,7 @@ class IMGUR5K(AbstractDataset):
         return f"train={self.train}"
 
     def _read_from_folder(self, path: str) -> None:
-        for img_path in glob.glob(os.path.join(path, "*.png")):
+        img_paths = glob.glob(os.path.join(path, "*.png"))
+        for img_path in tqdm(iterable=img_paths, desc="Preparing and Loading IMGUR5K", total=len(img_paths)):
             with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
                 self.data.append((img_path, f.read()))
doctr/datasets/loader.py
CHANGED
doctr/datasets/ocr.py
CHANGED
@@ -40,7 +40,7 @@ class OCRDataset(AbstractDataset):
         super().__init__(img_folder, **kwargs)
 
         # List images
-        self.data: list[tuple[
+        self.data: list[tuple[Path, dict[str, Any]]] = []
         np_dtype = np.float32
         with open(label_file, "rb") as f:
             data = json.load(f)
doctr/datasets/recognition.py
CHANGED
@@ -23,7 +23,7 @@ class RecognitionDataset(AbstractDataset):
 
     Args:
         img_folder: path to the images folder
-        labels_path:
+        labels_path: path to the json file containing all labels (character sequences)
         **kwargs: keyword arguments from `AbstractDataset`.
     """
 
doctr/datasets/svhn.py
CHANGED
@@ -129,7 +129,7 @@ class SVHN(VisionDataset):
             if recognition_task:
                 crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
                 for crop, label in zip(crops, label_targets):
-                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
                         self.data.append((crop, label))
             elif detection_task:
                 self.data.append((img_name, box_targets))
doctr/datasets/svt.py
CHANGED
@@ -35,7 +35,7 @@ class SVT(VisionDataset):
         **kwargs: keyword arguments from `VisionDataset`.
     """
 
-    URL = "http://
+    URL = "http://www.iapr-tc11.org/dataset/SVT/svt.zip"
     SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"
 
     def __init__(
@@ -113,7 +113,7 @@ class SVT(VisionDataset):
             if recognition_task:
                 crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                 for crop, label in zip(crops, labels):
-                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
                         self.data.append((crop, label))
             elif detection_task:
                 self.data.append((name.text, boxes))
doctr/datasets/synthtext.py
CHANGED
@@ -41,6 +41,12 @@ class SynthText(VisionDataset):
     URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
     SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"
 
+    # filter corrupted or missing images
+    BLACKLIST = (
+        "67/fruits_129_",
+        "194/window_19_",
+    )
+
     def __init__(
         self,
         train: bool = True,
@@ -111,7 +117,13 @@ class SynthText(VisionDataset):
             if recognition_task:
                 crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
                 for crop, label in zip(crops, labels):
-                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+                    if (
+                        crop.shape[0] > 0
+                        and crop.shape[1] > 0
+                        and len(label) > 0
+                        and len(label) < 30
+                        and " " not in label
+                    ):
                         # write data to disk
                         with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
                             f.write(label)
@@ -132,6 +144,7 @@ class SynthText(VisionDataset):
         return f"train={self.train}"
 
     def _read_from_folder(self, path: str) -> None:
-        for img_path in glob.glob(os.path.join(path, "*.png")):
+        img_paths = glob.glob(os.path.join(path, "*.png"))
+        for img_path in tqdm(iterable=img_paths, desc="Preparing and Loading SynthText", total=len(img_paths)):
             with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
                 self.data.append((img_path, f.read()))
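The diff adds the BLACKLIST tuple but does not show where it is consumed; presumably image paths beginning with these prefixes are skipped during loading. A hedged sketch of that assumption (the img_path value and the skip site are illustrative, not taken from the diff):

BLACKLIST = (
    "67/fruits_129_",
    "194/window_19_",
)

# str.startswith accepts a tuple, so a single check can match any blacklisted prefix
img_path = "67/fruits_129_42.jpg"  # hypothetical SynthText-relative path
if img_path.startswith(BLACKLIST):
    pass  # skip corrupted or missing image (assumed usage)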
doctr/datasets/utils.py
CHANGED
@@ -48,7 +48,7 @@ def translate(
         A string translated in a given vocab
     """
     if VOCABS.get(vocab_name) is None:
-        raise KeyError("output vocabulary must be in vocabs
+        raise KeyError("output vocabulary must be in vocabs dictionary")
 
     translated = ""
     for char in input_string:
@@ -81,11 +81,12 @@ def encode_string(
     """
     try:
        return list(map(vocab.index, input_string))
-    except ValueError:
+    except ValueError as e:
+        missing_chars = [char for char in input_string if char not in vocab]
         raise ValueError(
-            f"
-
-        )
+            f"Some characters cannot be found in 'vocab': {set(missing_chars)}.\n"
+            f"Please check the input string `{input_string}` and the vocabulary `{vocab}`"
+        ) from e
 
 
 def decode_sequence(
@@ -199,7 +200,7 @@ def crop_bboxes_from_image(img_path: str | Path, geoms: np.ndarray) -> list[np.ndarray]:
         a list of cropped images
     """
     with Image.open(img_path) as pil_img:
-        img: np.ndarray = np.
+        img: np.ndarray = np.asarray(pil_img.convert("RGB"))
     # Polygon
     if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
         return extract_rcrops(img, geoms.astype(dtype=int))
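For illustration, a small sketch of how the reworked encode_string error path behaves, based solely on the hunk above (the vocab string is a toy example):

from doctr.datasets.utils import encode_string

vocab = "abc"

print(encode_string("cab", vocab))  # [2, 0, 1] — each character mapped to its index in vocab

try:
    encode_string("cab!", vocab)
except ValueError as e:
    # The 0.12.0 message now lists the offending characters, e.g. {'!'}
    print(e)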