python-doctr 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/loader.py +1 -6
- doctr/datasets/utils.py +2 -1
- doctr/datasets/vocabs.py +9 -2
- doctr/file_utils.py +26 -12
- doctr/io/elements.py +40 -6
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +71 -13
- doctr/models/classification/mobilenet/pytorch.py +45 -9
- doctr/models/classification/mobilenet/tensorflow.py +38 -7
- doctr/models/classification/predictor/pytorch.py +18 -11
- doctr/models/classification/predictor/tensorflow.py +16 -10
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +3 -3
- doctr/models/classification/zoo.py +39 -15
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +4 -4
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +15 -1
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +3 -12
- doctr/models/kie_predictor/base.py +9 -3
- doctr/models/kie_predictor/pytorch.py +41 -20
- doctr/models/kie_predictor/tensorflow.py +36 -16
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/predictor/base.py +77 -50
- doctr/models/predictor/pytorch.py +31 -20
- doctr/models/predictor/tensorflow.py +27 -17
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +4 -3
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +3 -9
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +66 -8
- doctr/transforms/modules/tensorflow.py +63 -7
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +35 -12
- doctr/utils/metrics.py +33 -174
- doctr/utils/reconstitution.py +126 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/METADATA +84 -80
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/RECORD +76 -76
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.9.0.dist-info}/zip-safe +0 -0
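
A quick sanity check after upgrading is to read the package version at runtime; the string comes from doctr/version.py, one of the files changed above (minimal sketch, assuming a standard install of the wheel):

    # Confirm which release is installed; __version__ is populated from doctr/version.py.
    import doctr

    print(doctr.__version__)  # "0.9.0" once the new wheel is active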
doctr/models/predictor/tensorflow.py
CHANGED

@@ -9,10 +9,10 @@ import numpy as np
 import tensorflow as tf

 from doctr.io.elements import Document
-from doctr.models._utils import
+from doctr.models._utils import get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import
+from doctr.utils.geometry import detach_scores
 from doctr.utils.repr import NestedObject

 from .base import _OCRPredictor

@@ -56,7 +56,13 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         self.det_predictor = det_predictor
         self.reco_predictor = reco_predictor
         _OCRPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language

@@ -81,19 +87,16 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             for out_map in out_maps
         ]
         if self.detect_orientation:
-            ...
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-            ...
-                origin_page_orientations
-                if self.detect_orientation
-                else [estimate_orientation(seq_map) for seq_map in seg_maps]
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
             # forward again to get predictions on straight pages
             loc_preds_dict = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]

@@ -101,9 +104,8 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             len(loc_pred) == 1 for loc_pred in loc_preds_dict
         ), "Detection Model in ocr_predictor should output only one class"
         loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]  # type: ignore[union-attr]
-        ...
-        loc_preds = self._remove_padding(pages, loc_preds)
+        # Detach objectness scores from loc_preds
+        loc_preds, objectness_scores = detach_scores(loc_preds)

         # Apply hooks to loc_preds if any
         for hook in self.hooks:

@@ -113,14 +115,20 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         crops, loc_preds = self._prepare_crops(
             pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
         )
-        # Rectify crop orientation
+        # Rectify crop orientation and get crop orientation predictions
+        crop_orientations: Any = []
         if not self.assume_straight_pages:
-            crops, loc_preds = self._rectify_crops(crops, loc_preds)
+            crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+            crop_orientations = [
+                {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+            ]

         # Identify character sequences
         word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+        if not crop_orientations:
+            crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

-        boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+        boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)

         if self.detect_language:
             languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]

@@ -131,8 +139,10 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         out = self.doc_builder(
             pages,
             boxes,
+            objectness_scores,
             text_preds,
             origin_page_shapes,  # type: ignore[arg-type]
+            crop_orientations,
             orientations,
             languages_dict,
         )
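
The hunks above route page orientations, objectness scores and crop orientations through to the document builder. A minimal, hedged sketch of how a caller surfaces that information; the export() key names used below ("orientation", "language") are assumptions about the Page export schema, not shown in this diff:

    # Hedged sketch: enabling orientation/language detection on the OCR predictor.
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True, detect_orientation=True, detect_language=True)
    doc = DocumentFile.from_images("sample.jpg")
    result = predictor(doc)

    page = result.export()["pages"][0]
    # Assumed keys; each value is a {"value": ..., "confidence": ...} dict per the hunks above
    print(page.get("orientation"), page.get("language"))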
doctr/models/preprocessor/pytorch.py
CHANGED

@@ -79,7 +79,7 @@ class PreProcessor(nn.Module):
         else:
             x = x.to(dtype=torch.float32)  # type: ignore[union-attr]

-        return x
+        return x

     def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
         """Prepare document data for model forwarding

@@ -103,7 +103,7 @@ class PreProcessor(nn.Module):
         elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
             raise TypeError("unsupported data type for torch.Tensor")
         # Resizing
-        if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
+        if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
             x = F.resize(
                 x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
             )

@@ -118,11 +118,11 @@ class PreProcessor(nn.Module):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")

         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))

-        return batches
+        return batches
doctr/models/preprocessor/tensorflow.py
CHANGED

@@ -41,6 +41,7 @@ class PreProcessor(NestedObject):
         self.resize = Resize(output_size, **kwargs)
         # Perform the division by 255 at the same time
         self.normalize = Normalize(mean, std)
+        self._runs_on_cuda = tf.test.is_gpu_available()

     def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
         """Gather samples into batches for inference purposes

@@ -113,13 +114,13 @@ class PreProcessor(NestedObject):

         elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
             # Sample transform (to tensor, resize)
-            samples = list(multithread_exec(self.sample_transforms, x))
+            samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
             # Batching
             batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")

         # Batch transforms (normalize)
-        batches = list(multithread_exec(self.normalize, batches))
+        batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))

         return batches
doctr/models/recognition/master/pytorch.py
CHANGED

@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)

@@ -142,7 +142,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
doctr/models/recognition/parseq/pytorch.py
CHANGED

@@ -212,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):

         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined

@@ -282,7 +282,8 @@ class PARSeq(_PARSeq, nn.Module):
             ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

             # Stop decoding if all sequences have reached the EOS token
-            ...
+            # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                 break

         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)

@@ -297,7 +298,7 @@ class PARSeq(_PARSeq, nn.Module):

                 # Create padding mask for refined target input maskes all behind EOS token as False
                 # (N, 1, 1, max_length)
-                target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+                target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
                 mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
                 logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
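
The `exportable` flag checked above exists so the data-dependent early `break` can be skipped when tracing the graph for export. A minimal sketch of an ONNX export that relies on it; the input shape and opset are assumptions, and plain torch.onnx.export is used rather than any doctr helper:

    # Hedged sketch: ONNX export of the PyTorch PARSeq recognizer.
    import torch
    from doctr.models import parseq

    model = parseq(pretrained=True).eval()
    model.exportable = True  # keep the autoregressive loop free of data-dependent breaks
    dummy_input = torch.rand((1, 3, 32, 128), dtype=torch.float32)  # assumed input shape
    torch.onnx.export(model, dummy_input, "parseq.onnx", opset_version=14)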
doctr/models/recognition/parseq/tensorflow.py
CHANGED

@@ -288,10 +288,11 @@ class PARSeq(_PARSeq, Model):
             )

             # Stop decoding if all sequences have reached the EOS token
-            #
+            # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
             if (
-                ...
-                and
+                not self.exportable
+                and max_len is None
+                and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
             ):
                 break
doctr/models/recognition/sar/pytorch.py
CHANGED

@@ -125,25 +125,26 @@ class SARDecoder(nn.Module):
             if t == 0:
                 # step to init the first states of the LSTMCell
                 hidden_state_init = cell_state_init = torch.zeros(
-                    features.size(0), features.size(1), device=features.device
+                    features.size(0), features.size(1), device=features.device, dtype=features.dtype
                 )
                 hidden_state, cell_state = hidden_state_init, cell_state_init
                 prev_symbol = holistic
             elif t == 1:
                 # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                 # (N, vocab_size + 1) --> (N, embedding_units)
-                prev_symbol = torch.zeros(
+                prev_symbol = torch.zeros(
+                    features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
+                )
                 prev_symbol = self.embed(prev_symbol)
             else:
-                if gt is not None:
+                if gt is not None and self.training:
                     # (N, embedding_units) -2 because of <bos> and <eos> (same)
                     prev_symbol = self.embed(gt_embedding[:, t - 2])
                 else:
                     # -1 to start at timestep where prev_symbol was initialized
                     index = logits_list[t - 1].argmax(-1)
                     # update prev_symbol with ones at the index of the previous logit vector
-                    ...
-                    prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
+                    prev_symbol = self.embed(self.embed_tgt(index))

             # (N, C), (N, C) take the last hidden state and cell state from current timestep
             hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))

@@ -292,7 +293,7 @@ class SAR(nn.Module, RecognitionModel):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
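
The `gt is not None and self.training` guard above restricts teacher forcing to training mode: in eval mode the decoder always embeds its own argmax prediction, even if a target is supplied. A small hedged sketch of eval-time decoding (model name and input shape are standard doctr values, not taken from this diff):

    # Hedged sketch: eval-mode decoding with SAR never teacher-forces on a target.
    import torch
    from doctr.models import sar_resnet31

    model = sar_resnet31(pretrained=True).eval()
    x = torch.rand((1, 3, 32, 128))
    with torch.no_grad():
        out = model(x, return_preds=True)
    print(out["preds"])  # list of (word, confidence) pairs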
doctr/models/recognition/sar/tensorflow.py
CHANGED

@@ -177,23 +177,17 @@ class SARDecoder(layers.Layer, NestedObject):
             elif t == 1:
                 # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                 # (N, vocab_size + 1) --> (N, embedding_units)
-                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1])
+                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
                 prev_symbol = self.embed(prev_symbol, **kwargs)
             else:
-                if gt is not None:
+                if gt is not None and kwargs.get("training", False):
                     # (N, embedding_units) -2 because of <bos> and <eos> (same)
                     prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
                 else:
                     # -1 to start at timestep where prev_symbol was initialized
                     index = tf.argmax(logits_list[t - 1], axis=-1)
                     # update prev_symbol with ones at the index of the previous logit vector
-                    ...
-                    index = tf.ones_like(index)
-                    prev_symbol = tf.scatter_nd(
-                        tf.expand_dims(index, axis=1),
-                        prev_symbol,
-                        tf.constant([features.shape[0], features.shape[-1]], dtype=tf.int64),
-                    )
+                    prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)

             # (N, C), (N, C) take the last hidden state and cell state from current timestep
             _, states = self.lstm_cells(prev_symbol, states, **kwargs)
doctr/models/recognition/vitstr/pytorch.py
CHANGED

@@ -137,7 +137,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
doctr/models/recognition/zoo.py
CHANGED

@@ -45,7 +45,7 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict

     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
-    kwargs["batch_size"] = kwargs.get("batch_size",
+    kwargs["batch_size"] = kwargs.get("batch_size", 128)
     input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)
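
The hunk above fixes the recognition predictor's default batch size at 128; it remains overridable per call. A hedged example (the architecture name is one of doctr's standard recognition models):

    # Hedged sketch: overriding the new default batch size of 128.
    from doctr.models import recognition_predictor

    reco = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=32)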
doctr/models/zoo.py
CHANGED

@@ -61,7 +61,7 @@ def _predictor(


 def ocr_predictor(
-    det_arch: Any = "
+    det_arch: Any = "fast_base",
     reco_arch: Any = "crnn_vgg16_bn",
     pretrained: bool = False,
     pretrained_backbone: bool = True,

@@ -175,7 +175,7 @@ def _kie_predictor(


 def kie_predictor(
-    det_arch: Any = "
+    det_arch: Any = "fast_base",
     reco_arch: Any = "crnn_vgg16_bn",
     pretrained: bool = False,
     pretrained_backbone: bool = True,
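
Both `ocr_predictor` and `kie_predictor` switch their default detection architecture to `fast_base` (the previous default is truncated in the removed lines above). Pinning both architectures explicitly keeps behaviour stable across releases; a minimal sketch:

    # Pin architectures explicitly so behaviour does not shift with release defaults.
    from doctr.models import kie_predictor, ocr_predictor

    ocr = ocr_predictor(det_arch="fast_base", reco_arch="crnn_vgg16_bn", pretrained=True)
    kie = kie_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)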
doctr/py.typed
ADDED

Empty file (PEP 561 marker signalling that the package ships inline type annotations).
doctr/transforms/functional/base.py
CHANGED

@@ -200,4 +200,4 @@ def create_shadow_mask(
     mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8)
     mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0]

-    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
+    return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32)
doctr/transforms/functional/pytorch.py
CHANGED

@@ -35,9 +35,9 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor:
     rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape)
     # Inverse the color
     if out.dtype == torch.uint8:
-        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
+        out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8)
     else:
-        out = out * rgb_shift.to(dtype=out.dtype)
+        out = out * rgb_shift.to(dtype=out.dtype)
     # Inverse the color
     out = 255 - out if out.dtype == torch.uint8 else 1 - out
     return out

@@ -81,7 +81,7 @@ def rotate_sample(
     rotated_geoms: np.ndarray = rotate_abs_geoms(
         _geoms,
         angle,
-        img.shape[1:],
+        img.shape[1:],  # type: ignore[arg-type]
         expand,
     ).astype(np.float32)

@@ -132,7 +132,7 @@ def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwarg
     -------
         shaded image
     """
-    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
+    shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)  # type: ignore[arg-type]

     opacity = np.random.uniform(*opacity_range)
     shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
doctr/transforms/modules/base.py
CHANGED

@@ -5,7 +5,7 @@

 import math
 import random
-from typing import Any, Callable,
+from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np

@@ -168,11 +168,11 @@ class OneOf(NestedObject):
     def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
         self.transforms = transforms

-    def __call__(self, img: Any) -> Any:
+    def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
         # Pick transformation
         transfo = self.transforms[int(random.random() * len(self.transforms))]
         # Apply
-        return transfo(img)
+        return transfo(img) if target is None else transfo(img, target)  # type: ignore[call-arg]


 class RandomApply(NestedObject):
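
With the OneOf change above, the sampled transform can receive an optional detection target. A minimal sketch using toy callables in place of doctr's geometric transforms (the two lambdas are illustrative, not doctr APIs):

    # Hedged sketch: OneOf forwarding an optional target to the picked transform.
    import numpy as np
    from doctr.transforms import OneOf

    flip_w = lambda img, boxes: (img[:, ::-1], boxes)  # toy target-aware transform
    identity = lambda img, boxes: (img, boxes)

    transfo = OneOf([flip_w, identity])
    img = np.random.rand(32, 32, 3).astype(np.float32)
    boxes = np.array([[0.1, 0.2, 0.5, 0.6]], dtype=np.float32)  # relative xyxy
    out_img, out_boxes = transfo(img, boxes)  # target-aware call, new in 0.9.0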
@@ -261,17 +261,39 @@ class RandomCrop(NestedObject):
     def extra_repr(self) -> str:
         return f"scale={self.scale}, ratio={self.ratio}"

-    def __call__(self, img: Any, target:
+    def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]:
         scale = random.uniform(self.scale[0], self.scale[1])
         ratio = random.uniform(self.ratio[0], self.ratio[1])
-        ...
+
+        height, width = img.shape[:2]
+
+        # Calculate crop size
+        crop_area = scale * width * height
+        aspect_ratio = ratio * (width / height)
+        crop_width = int(round(math.sqrt(crop_area * aspect_ratio)))
+        crop_height = int(round(math.sqrt(crop_area / aspect_ratio)))
+
+        # Ensure crop size does not exceed image dimensions
+        crop_width = min(crop_width, width)
+        crop_height = min(crop_height, height)
+
+        # Randomly select crop position
+        x = random.randint(0, width - crop_width)
+        y = random.randint(0, height - crop_height)
+
+        # relative crop box
+        crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height)
+        if target.shape[1:] == (4, 2):
+            min_xy = np.min(target, axis=1)
+            max_xy = np.max(target, axis=1)
+            _target = np.concatenate((min_xy, max_xy), axis=1)
+        else:
+            _target = target
+
+        # Crop image and targets
+        croped_img, crop_boxes = F.crop_detection(img, _target, crop_box)
+        # hard fallback if no box is kept
+        if crop_boxes.shape[0] == 0:
+            return img, target
+        # clip boxes
+        return croped_img, np.clip(crop_boxes, 0, 1)
doctr/transforms/modules/pytorch.py
CHANGED

@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

 import math
-from typing import
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch

@@ -15,7 +15,7 @@ from torchvision.transforms import transforms as T

 from ..functional.pytorch import random_shadow

-__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow"]
+__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]


 class Resize(T.Resize):

@@ -135,9 +135,9 @@ class GaussianNoise(torch.nn.Module):
         # Reshape the distribution
         noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std
         if x.dtype == torch.uint8:
-            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
+            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)
         else:
-            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)
+            return (x + noise.to(dtype=x.dtype)).clamp(0, 1)

     def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"

@@ -159,13 +159,16 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
     """Randomly flip the input image horizontally"""

     def forward(
-        self, img: Union[torch.Tensor, Image], target:
-    ) -> Tuple[Union[torch.Tensor, Image],
+        self, img: Union[torch.Tensor, Image], target: np.ndarray
+    ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]:
         if torch.rand(1) < self.p:
             _img = F.hflip(img)
             _target = target.copy()
             # Changing the relative bbox coordinates
-            ...
+            if target.shape[1:] == (4,):
+                _target[:, ::2] = 1 - target[:, [2, 0]]
+            else:
+                _target[..., 0] = 1 - target[..., 0]
             return _img, _target
         return img, target

@@ -199,7 +202,7 @@ class RandomShadow(torch.nn.Module):
                         self.opacity_range,
                     )
                 )
-                .round()
+                .round()
                 .clip(0, 255)
                 .to(dtype=torch.uint8)
             )

@@ -210,3 +213,58 @@ class RandomShadow(torch.nn.Module):

     def extra_repr(self) -> str:
         return f"opacity_range={self.opacity_range}"
+
+
+class RandomResize(torch.nn.Module):
+    """Randomly resize the input image and align corresponding targets
+
+    >>> import torch
+    >>> from doctr.transforms import RandomResize
+    >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5)
+    >>> out = transfo(torch.rand((3, 64, 64)))
+
+    Args:
+    ----
+        scale_range: range of the resizing factor for width and height (independently)
+        preserve_aspect_ratio: whether to preserve the aspect ratio of the image,
+            given a float value, the aspect ratio will be preserved with this probability
+        symmetric_pad: whether to symmetrically pad the image,
+            given a float value, the symmetric padding will be applied with this probability
+        p: probability to apply the transformation
+    """
+
+    def __init__(
+        self,
+        scale_range: Tuple[float, float] = (0.3, 0.9),
+        preserve_aspect_ratio: Union[bool, float] = False,
+        symmetric_pad: Union[bool, float] = False,
+        p: float = 0.5,
+    ) -> None:
+        super().__init__()
+        self.scale_range = scale_range
+        self.preserve_aspect_ratio = preserve_aspect_ratio
+        self.symmetric_pad = symmetric_pad
+        self.p = p
+        self._resize = Resize
+
+    def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
+        if torch.rand(1) < self.p:
+            scale_h = np.random.uniform(*self.scale_range)
+            scale_w = np.random.uniform(*self.scale_range)
+            new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w))
+
+            _img, _target = self._resize(
+                new_size,
+                preserve_aspect_ratio=self.preserve_aspect_ratio
+                if isinstance(self.preserve_aspect_ratio, bool)
+                else bool(torch.rand(1) <= self.symmetric_pad),
+                symmetric_pad=self.symmetric_pad
+                if isinstance(self.symmetric_pad, bool)
+                else bool(torch.rand(1) <= self.symmetric_pad),
+            )(img, target)
+
+            return _img, _target
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}"  # noqa: E501