python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/loader.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1102 -54
- doctr/file_utils.py +9 -0
- doctr/io/elements.py +37 -3
- doctr/models/_utils.py +1 -1
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +1 -2
- doctr/models/classification/magc_resnet/tensorflow.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/mobilenet/tensorflow.py +11 -2
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/resnet/tensorflow.py +25 -4
- doctr/models/classification/textnet/pytorch.py +10 -1
- doctr/models/classification/textnet/tensorflow.py +11 -2
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vgg/tensorflow.py +11 -2
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/pytorch.py +10 -1
- doctr/models/classification/vit/tensorflow.py +9 -0
- doctr/models/classification/zoo.py +4 -0
- doctr/models/detection/differentiable_binarization/base.py +3 -4
- doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
- doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
- doctr/models/detection/fast/base.py +2 -3
- doctr/models/detection/fast/pytorch.py +13 -4
- doctr/models/detection/fast/tensorflow.py +10 -2
- doctr/models/detection/linknet/base.py +2 -3
- doctr/models/detection/linknet/pytorch.py +10 -1
- doctr/models/detection/linknet/tensorflow.py +10 -2
- doctr/models/factory/hub.py +3 -3
- doctr/models/kie_predictor/pytorch.py +1 -1
- doctr/models/kie_predictor/tensorflow.py +1 -1
- doctr/models/modules/layers/pytorch.py +49 -1
- doctr/models/predictor/pytorch.py +1 -1
- doctr/models/predictor/tensorflow.py +1 -1
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/pytorch.py +10 -1
- doctr/models/recognition/crnn/tensorflow.py +10 -1
- doctr/models/recognition/master/pytorch.py +10 -1
- doctr/models/recognition/master/tensorflow.py +10 -3
- doctr/models/recognition/parseq/pytorch.py +23 -5
- doctr/models/recognition/parseq/tensorflow.py +13 -5
- doctr/models/recognition/predictor/_utils.py +107 -45
- doctr/models/recognition/predictor/pytorch.py +3 -3
- doctr/models/recognition/predictor/tensorflow.py +3 -3
- doctr/models/recognition/sar/pytorch.py +10 -1
- doctr/models/recognition/sar/tensorflow.py +10 -3
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/pytorch.py +10 -1
- doctr/models/recognition/vitstr/tensorflow.py +10 -3
- doctr/models/recognition/zoo.py +5 -0
- doctr/models/utils/pytorch.py +28 -18
- doctr/models/utils/tensorflow.py +15 -8
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/detection/fast/tensorflow.py
CHANGED
@@ -153,6 +153,15 @@ class FAST(_FAST, Model, NestedObject):
         # Pooling layer as erosion reversal as described in the paper
         self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
@@ -332,8 +341,7 @@ def _fast(
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
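The detection and recognition models in this release gain the same `from_pretrained` method, so checkpoints are now loaded through the model instance instead of calling the `load_pretrained_params` helper directly. A minimal usage sketch, assuming docTR 0.12 is installed; the checkpoint path below is illustrative, not taken from this diff:

    from doctr.models import detection

    model = detection.fast_base(pretrained=False)
    # Extra keyword arguments are forwarded to doctr.models.utils.load_pretrained_params
    model.from_pretrained("path/to/fast_base_checkpoint.pt")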
doctr/models/detection/linknet/base.py
CHANGED
@@ -56,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
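The post-processor now measures contour area and perimeter directly with OpenCV (the removed lines relied on a `poly` object). A quick sanity check of the two calls used above, on a simple 100x40 rectangle; the values in the comments are what OpenCV returns for this input:

    import cv2
    import numpy as np

    points = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.int32)
    print(cv2.contourArea(points))             # 4000.0
    print(cv2.arcLength(points, closed=True))  # 280.0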
doctr/models/detection/linknet/pytorch.py
CHANGED
@@ -160,6 +160,15 @@ class LinkNet(nn.Module, _LinkNet):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -282,7 +291,7 @@ def _linknet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/detection/linknet/tensorflow.py
CHANGED
@@ -163,6 +163,15 @@ class LinkNet(_LinkNet, Model):
             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
         )
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
@@ -282,8 +291,7 @@ def _linknet(
     # Load pretrained parameters
    if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
             _cfg["url"],
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )
doctr/models/factory/hub.py
CHANGED
@@ -217,10 +217,10 @@ def from_hub(repo_id: str, **kwargs: Any):
 
     # Load checkpoint
     if is_torch_available():
-        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
-        model.load_state_dict(state_dict)
+        weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
     else:  # tf
         weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-        model.load_weights(weights)
+
+    model.from_pretrained(weights)
 
     return model
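`from_hub` now only downloads the checkpoint for the active backend and defers the actual loading to the model's new `from_pretrained` method. A hedged usage sketch; the repository name is a placeholder, not taken from this diff:

    from doctr.models.factory.hub import from_hub

    # Downloads pytorch_model.bin or tf_model.weights.h5 depending on the installed backend,
    # then loads it via model.from_pretrained(weights)
    model = from_hub("my-org/my-doctr-recognition-model")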
doctr/models/kie_predictor/pytorch.py
CHANGED
@@ -173,7 +173,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
doctr/models/kie_predictor/tensorflow.py
CHANGED
@@ -171,7 +171,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
doctr/models/modules/layers/pytorch.py
CHANGED
@@ -8,7 +8,55 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-__all__ = ["FASTConvLayer"]
+__all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
+
+
+class DropPath(nn.Module):
+    """
+    DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
+    """
+
+    # Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with different dimensions
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class AdaptiveAvgPool2d(nn.Module):
+    """
+    Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
+
+    """
+
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def forward(self, x: torch.Tensor):
+        H_out, W_out = self.output_size
+        N, C, H, W = x.shape
+
+        out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
+        for oh in range(H_out):
+            start_h = (oh * H) // H_out
+            end_h = ((oh + 1) * H + H_out - 1) // H_out  # ceil((oh+1)*H / H_out)
+            for ow in range(W_out):
+                start_w = (ow * W) // W_out
+                end_w = ((ow + 1) * W + W_out - 1) // W_out  # ceil((ow+1)*W / W_out)
+                # average over the window
+                out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
+        return out
 
 
 class FASTConvLayer(nn.Module):
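Two utility layers are added to the shared PyTorch layers module. A short behavioural sketch, assuming the PyTorch backend is installed; the match with the built-in pooling follows from the identical floor/ceil window bounds in the loop above, it is not a documented guarantee:

    import torch
    from doctr.models.modules.layers.pytorch import AdaptiveAvgPool2d, DropPath

    x = torch.rand(2, 3, 17, 23)

    # The loop-based pooling should agree with torch.nn.AdaptiveAvgPool2d for the same output size
    pool = AdaptiveAvgPool2d((4, 4))
    print(torch.allclose(pool(x), torch.nn.AdaptiveAvgPool2d((4, 4))(x), atol=1e-6))  # True

    # DropPath only drops residual branches in training mode; at inference it is the identity
    drop = DropPath(drop_prob=0.2).eval()
    print(torch.equal(drop(x), x))  # True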
doctr/models/recognition/crnn/pytorch.py
CHANGED
@@ -155,6 +155,15 @@ class CRNN(RecognitionModel, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: torch.Tensor,
@@ -254,7 +263,7 @@ def _crnn(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/crnn/tensorflow.py
CHANGED
@@ -154,6 +154,15 @@ class CRNN(RecognitionModel, Model):
         self.beam_width = beam_width
         self.top_paths = top_paths
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: tf.Tensor,
@@ -243,7 +252,7 @@ def _crnn(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
+        model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/master/pytorch.py
CHANGED
@@ -151,6 +151,15 @@ class MASTER(_MASTER, nn.Module):
         ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
         return ce_loss.mean()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -301,7 +310,7 @@ def _master(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/master/tensorflow.py
CHANGED
@@ -87,6 +87,15 @@ class MASTER(_MASTER, Model):
         self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @tf.function
     def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
@@ -287,9 +296,7 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/parseq/pytorch.py
CHANGED
@@ -76,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
 
-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -173,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -210,7 +228,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -349,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
             # remove the [EOS] tokens for the succeeding perms
             if i == 1:
                 gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                n = (gt_out != self.vocab_size + 2).sum().item()
+                n = (gt_out != self.vocab_size + 2).sum().item()  # type: ignore[attr-defined]
 
         loss /= loss_numel
 
@@ -448,7 +466,7 @@ def _parseq(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
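Because the two decoder norm layers are removed above, checkpoints trained with earlier docTR releases still contain their parameters; `from_pretrained` therefore always appends those keys to `ignore_keys` before delegating to `load_pretrained_params`. A hedged sketch of resuming from such a checkpoint (the path is illustrative):

    from doctr.models import recognition

    model = recognition.parseq(pretrained=False)
    # Keys for decoder.attention_norm / decoder.cross_attention_norm found in older
    # checkpoints are ignored automatically, so pre-0.12 weights should still load cleanly
    model.from_pretrained("path/to/parseq_checkpoint_trained_with_docTR_0.10.pt")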
doctr/models/recognition/parseq/tensorflow.py
CHANGED
@@ -76,8 +76,6 @@ class PARSeqDecoder(layers.Layer):
             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
         )
 
-        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
@@ -165,6 +163,18 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        kwargs["skip_mismatch"] = True
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -474,9 +484,7 @@ def _parseq(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/predictor/_utils.py
CHANGED
@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import math
+
 import numpy as np
 
 from ..utils import merge_multi_strings
@@ -15,69 +17,129 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    dilation: float,
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
-    """Chunk crops horizontally to match a given aspect ratio
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: list of image crops
-        max_ratio: the maximum aspect ratio that won't trigger the split
-        target_ratio: the aspect ratio the chunks will be resized to
-        dilation: the width dilation of the final chunks (to provide some overlap)
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.
 
     Returns:
-        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    _remap_required = False
-    crop_map: list[int | tuple[int, int]] = []
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
-            num_subcrops = int(aspect_ratio // target_ratio)
-            # Find the new widths, additional dilation factor to overlap crops
-            width = dilation * w / num_subcrops
-            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
-            # Crop the wide image horizontally
-            if channels_last:
-                _crops = [
-                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
-                    for center in centers
-                ]
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                _crops = [
-                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
-                    for center in centers
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, _remap_required
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-    remapped_out = []
-    for _idx in crop_map:
-        # Crop hasn't been split
-        if isinstance(_idx, int):
-            remapped_out.append(preds[_idx])
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            # unzip
-            vals, probs = zip(*preds[_idx[0] : _idx[1]])
-            # Merge the string values
-            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))
-    return remapped_out
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
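The new overlap-based splitting replaces the previous dilation-based chunking. A worked example, assuming the module layout from the file list above; the printed values follow from the arithmetic in `_split_horizontally`, so treat them as expected rather than guaranteed output:

    import numpy as np
    from doctr.models.recognition.predictor._utils import split_crops

    # A 32x512 crop (aspect ratio 16) with the defaults set on the predictors below:
    # critical aspect ratio 8, target ratio 6, 50% overlap between neighbouring splits
    crop = np.zeros((32, 512, 3), dtype=np.uint8)
    new_crops, crop_map, remapped = split_crops([crop], max_ratio=8, target_ratio=6, split_overlap_ratio=0.5)

    # split_width = 32 * 6 = 192 and step = 192 - 96 = 96, so the start columns are
    # 0, 96, 192, 288, plus a final split anchored at 320 to reach the right edge
    print(len(new_crops), remapped)  # 5 True
    print(crop_map)                  # [(0, 5, 0.8333...)]: the last split overlaps ~83% with the previous one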
doctr/models/recognition/predictor/pytorch.py
CHANGED
@@ -38,7 +38,7 @@ class RecognitionPredictor(nn.Module):
         self.model = model.eval()
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     @torch.inference_mode()
@@ -60,7 +60,7 @@ class RecognitionPredictor(nn.Module):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.dil_factor,
+                self.overlap_ratio,
                 isinstance(crops[0], np.ndarray),
             )
             if remapped:
@@ -81,6 +81,6 @@ class RecognitionPredictor(nn.Module):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out
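At the predictor level the splitting parameter is replaced by `overlap_ratio`, which is what gets forwarded to `split_crops` and `remap_preds`. A short sketch of where the new knob surfaces (assumes pretrained recognition weights can be downloaded):

    from doctr.models import recognition_predictor

    predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True)
    print(predictor.critical_ar, predictor.target_ar, predictor.overlap_ratio)  # 8 6 0.5
    # Crops wider than critical_ar times their height are split into overlapping chunks
    # and the per-chunk predictions are merged back by remap_preds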
doctr/models/recognition/predictor/tensorflow.py
CHANGED
@@ -39,7 +39,7 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     def __call__(
@@ -56,7 +56,7 @@ class RecognitionPredictor(NestedObject):
         # Split crops that are too wide
         remapped = False
         if self.split_wide_crops:
-            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
+            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
             if remapped:
                 crops = new_crops
 
@@ -74,6 +74,6 @@ class RecognitionPredictor(NestedObject):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out