python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/contrib/artefacts.py +1 -1
- doctr/contrib/base.py +1 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/coco_text.py +1 -1
- doctr/datasets/cord.py +1 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/base.py +1 -1
- doctr/datasets/datasets/pytorch.py +3 -3
- doctr/datasets/detection.py +1 -1
- doctr/datasets/doc_artefacts.py +1 -1
- doctr/datasets/funsd.py +1 -1
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/generator/base.py +1 -1
- doctr/datasets/generator/pytorch.py +1 -1
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +1 -1
- doctr/datasets/iiit5k.py +1 -1
- doctr/datasets/iiithws.py +1 -1
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/mjsynth.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/orientation.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/sroie.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +1 -1
- doctr/datasets/synthtext.py +1 -1
- doctr/datasets/utils.py +1 -1
- doctr/datasets/vocabs.py +1 -3
- doctr/datasets/wildreceipt.py +1 -1
- doctr/file_utils.py +3 -102
- doctr/io/elements.py +1 -1
- doctr/io/html.py +1 -1
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/base.py +1 -1
- doctr/io/image/pytorch.py +2 -2
- doctr/io/pdf.py +1 -1
- doctr/io/reader.py +1 -1
- doctr/models/_utils.py +56 -18
- doctr/models/builder.py +1 -1
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -3
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +1 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +1 -1
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +2 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +1 -1
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +2 -2
- doctr/models/classification/vip/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +3 -3
- doctr/models/classification/zoo.py +7 -12
- doctr/models/core.py +1 -1
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/_utils/base.py +1 -1
- doctr/models/detection/_utils/pytorch.py +1 -1
- doctr/models/detection/core.py +2 -2
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +5 -13
- doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +5 -15
- doctr/models/detection/fast/pytorch.py +5 -5
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +4 -13
- doctr/models/detection/linknet/pytorch.py +3 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +2 -2
- doctr/models/detection/zoo.py +16 -33
- doctr/models/factory/hub.py +26 -34
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/base.py +1 -1
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +4 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +3 -3
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +4 -9
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +28 -33
- doctr/models/recognition/core.py +1 -1
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +7 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/base.py +1 -1
- doctr/models/recognition/master/pytorch.py +6 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/base.py +1 -1
- doctr/models/recognition/parseq/pytorch.py +6 -6
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +8 -17
- doctr/models/recognition/predictor/pytorch.py +2 -3
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +4 -4
- doctr/models/recognition/utils.py +1 -1
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +4 -4
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/base.py +1 -1
- doctr/models/recognition/vitstr/pytorch.py +4 -4
- doctr/models/recognition/zoo.py +14 -14
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +3 -2
- doctr/models/zoo.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/base.py +3 -2
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +28 -94
- doctr/transforms/modules/pytorch.py +29 -27
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +1 -2
- doctr/utils/fonts.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/metrics.py +1 -1
- doctr/utils/multithreading.py +1 -1
- doctr/utils/reconstitution.py +1 -1
- doctr/utils/repr.py +1 -1
- doctr/utils/visualization.py +2 -2
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
- python_doctr-1.0.1.dist-info/RECORD +149 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
doctr/models/classification/textnet/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -11,7 +11,7 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 
-from ...modules.layers
+from ...modules.layers import FASTConvLayer
 from ...utils import conv_sequence_pt, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
```
doctr/models/classification/vip/layers/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -433,7 +433,7 @@ class LePEAttention(nn.Module):
     Returns:
         A float tensor of shape (b, h, w, c).
     """
-    b_merged =
+    b_merged = img_splits_hw.shape[0] // ((h * w) // (h_sp * w_sp))
     img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
    # contiguous() required to ensure the tensor has a contiguous memory layout
    # after permute, allowing the subsequent view operation to work correctly.
```
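The restored `b_merged` line recovers the batch size from a stack of attention windows: each image contributes `(h * w) // (h_sp * w_sp)` windows, so dividing the stacked batch dimension by that count undoes the split. A standalone sketch of the round trip, with shapes invented for illustration:

```python
import torch

b, h, w, c = 2, 8, 8, 4
h_sp = w_sp = 4  # window size
# Stacked windows, as a CSWin-style img2windows split would produce them
img_splits_hw = torch.randn(b * (h // h_sp) * (w // w_sp), h_sp, w_sp, c)

b_merged = img_splits_hw.shape[0] // ((h * w) // (h_sp * w_sp))  # -> 2
img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(b_merged, h, w, -1)
assert img.shape == (b, h, w, c)  # original (b, h, w, c) layout recovered
```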
doctr/models/classification/vit/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -11,9 +11,9 @@ from torch import nn
 
 from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import EncoderBlock
-from doctr.models.modules.vision_transformer
+from doctr.models.modules.vision_transformer import PatchEmbedding
 
-from ...utils
+from ...utils import load_pretrained_params
 
 __all__ = ["vit_s", "vit_b"]
 
```
doctr/models/classification/zoo.py CHANGED

```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from typing import Any
 
-from doctr.
+from doctr.models.utils import _CompiledModule
 
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -30,11 +30,10 @@ ARCHS: list[str] = [
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
+    "vip_tiny",
+    "vip_base",
 ]
 
-if is_torch_available():
-    ARCHS.extend(["vip_tiny", "vip_base"])
-
 ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
 
 
@@ -52,12 +51,8 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-
-
-        # Adding the type for torch compiled models to the allowed architectures
-        from doctr.models.utils import _CompiledModule
-
-        allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [classification.MobileNetV3, _CompiledModule]
 
     if not isinstance(arch, tuple(allowed_archs)):
         raise ValueError(f"unknown architecture: {type(arch)}")
@@ -66,7 +61,7 @@ def _orientation_predictor(
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
-    input_shape = _model.cfg["input_shape"][
+    input_shape = _model.cfg["input_shape"][1:]
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
```
doctr/models/core.py CHANGED
doctr/models/detection/core.py CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):
 
         else:
             mask: np.ndarray = np.zeros((h, w), np.int32)
-            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
+            cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
             product = pred * mask
             return np.sum(product) / np.count_nonzero(product)
 
```
doctr/models/detection/differentiable_binarization/base.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -224,7 +224,7 @@ class _DBNet:
         padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
 
         # Fill the mask with 1 on the new padded polygon
-        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
 
         # Get min/max to recover polygon after distance computation
         xmin = padded_polygon[:, 0].min()
@@ -269,7 +269,6 @@ class _DBNet:
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
             raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -280,10 +279,8 @@ class _DBNet:
 
         h: int
         w: int
-
-
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -343,17 +340,12 @@ class _DBNet:
             if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                 seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                 continue
-            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
             # Draw on both thresh map and thresh mask
             poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
                 poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
             )
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            thresh_target = thresh_target.transpose((0, 2, 3, 1))
-            thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
 
         thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
 
```
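With the TensorFlow backend removed, `build_target` drops its `channels_last` switch: targets are always built channels-first, matching the `(N, C, H, W)` maps the torch models emit, so the trailing transposes disappear. A small sketch of the resulting contract, with shapes invented for illustration:

```python
import numpy as np

# The torch model emits out_map of shape (batch, num_classes, H, W); callers now
# pass out_map.shape[1:] and get channels-first targets back, with no transpose.
batch, num_classes, h, w = 2, 1, 1024, 1024
output_shape = (num_classes, h, w)  # what build_target receives
seg_target = np.zeros((batch, *output_shape), dtype=np.uint8)
assert seg_target.shape == (batch, num_classes, h, w)
```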
doctr/models/detection/differentiable_binarization/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -215,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -261,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, out_map.shape[1:]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -285,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
```
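The `dice_map` line is DBNet's differentiable binarization: a steep sigmoid of `prob_map - thresh_map` approximates a hard threshold while keeping gradients. A quick numeric check with invented values:

```python
import torch

prob = torch.tensor([0.20, 0.45, 0.90])
thresh = torch.full((3,), 0.5)
binary = 1 / (1 + torch.exp(-50.0 * (prob - thresh)))
print(binary)  # ~[0.0000, 0.0759, 1.0000]: pixels far from the threshold saturate
```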
doctr/models/detection/fast/base.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -153,14 +153,12 @@ class _FAST(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target, mask and shrunken text kernel
@@ -172,10 +170,8 @@ class _FAST(BaseModel):
 
         h: int
         w: int
-
-
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -235,14 +231,8 @@ class _FAST(BaseModel):
             if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                 seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                 continue
-            cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+            cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
             # draw the original polygon on the segmentation target
-            cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+            cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask, shrunken_kernel
```
doctr/models/detection/fast/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -206,7 +206,7 @@ class FAST(_FAST, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -238,7 +238,7 @@ class FAST(_FAST, nn.Module):
         Returns:
             A loss tensor
         """
-        targets = self.build_target(target, out_map.shape[1:]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -303,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
     for module in model.modules():
        if hasattr(module, "reparameterize_layer"):
-            module.reparameterize_layer()
+            module.reparameterize_layer()  # type: ignore[operator]
 
    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
@@ -315,7 +315,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
            factor = child.weight / torch.sqrt(child.running_var + child.eps)  # type: ignore
            last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
-           last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
+           last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)  # type: ignore[operator]
            model._modules[last_conv_name] = last_conv  # type: ignore[index]
            model._modules[name] = nn.Identity()
            last_conv = None
```
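The bias line that gains the `# type: ignore[operator]` is standard conv-BN folding: scale the conv weights by `gamma / sqrt(running_var + eps)` and shift the bias by the matching amount, so the fused conv reproduces conv followed by BN. A self-contained sketch of the same identity, with layer sizes invented:

```python
import torch
from torch import nn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(8).eval()  # eval() so running statistics are used

factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, 3, padding=1, bias=True)
fused.weight = nn.Parameter(conv.weight * factor.reshape(8, 1, 1, 1))
fused.bias = nn.Parameter((conv.bias - bn.running_mean) * factor + bn.bias)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```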
doctr/models/detection/linknet/base.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -156,14 +156,12 @@ class _LinkNet(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target and the mask
@@ -175,10 +173,8 @@ class _LinkNet(BaseModel):
 
         h: int
         w: int
-
-
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -237,11 +233,6 @@ class _LinkNet(BaseModel):
             if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                 seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                 continue
-            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask
```
doctr/models/detection/linknet/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -193,7 +193,7 @@ class LinkNet(nn.Module, _LinkNet):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -230,7 +230,7 @@ class LinkNet(nn.Module, _LinkNet):
         Returns:
             A loss tensor
         """
-        _target, _mask = self.build_target(target, out_map.shape[1:]
+        _target, _mask = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
```
doctr/models/detection/predictor/pytorch.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -36,7 +36,7 @@ class DetectionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray
+        pages: list[np.ndarray],
         return_maps: bool = False,
         **kwargs: Any,
     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
```
doctr/models/detection/zoo.py CHANGED
```diff
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from typing import Any
 
-from doctr.
+from doctr.models.utils import _CompiledModule
 
 from .. import detection
 from ..detection.fast import reparameterize
@@ -16,30 +16,17 @@ __all__ = ["detection_predictor"]
 
 ARCHS: list[str]
 
-
-
-
-
-
-
-
-
-
-
-
-]
-elif is_torch_available():
-    ARCHS = [
-        "db_resnet34",
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
+ARCHS = [
+    "db_resnet34",
+    "db_resnet50",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+    "fast_tiny",
+    "fast_small",
+    "fast_base",
+]
 
 
 def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
@@ -56,12 +43,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
         if isinstance(_model, detection.FAST):
             _model = reparameterize(_model)
     else:
-
-
-        # Adding the type for torch compiled models to the allowed architectures
-        from doctr.models.utils import _CompiledModule
-
-        allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST, _CompiledModule]
 
     if not isinstance(arch, tuple(allowed_archs)):
         raise ValueError(f"unknown architecture: {type(arch)}")
@@ -76,7 +59,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
-        PreProcessor(_model.cfg["input_shape"][
+        PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
         _model,
     )
     return predictor
```
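With the `is_tf_available()` branch gone, `ARCHS` is one flat list and any entry passes straight through to the public factory. A minimal usage sketch (random page for illustration; `pretrained=True` downloads weights):

```python
import numpy as np
from doctr.models import detection_predictor

model = detection_predictor("fast_base", pretrained=True, assume_straight_pages=True)
page = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
out = model([page])  # one dict per page, mapping class name to an (N, 5) box array
```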