python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138):
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,369 +0,0 @@ doctr/models/detection/linknet/tensorflow.py
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers, losses
14
-
15
- from doctr.file_utils import CLASS_NAME
16
- from doctr.models.classification import resnet18, resnet34, resnet50
17
- from doctr.models.utils import (
18
- IntermediateLayerGetter,
19
- _bf16_to_float32,
20
- _build_model,
21
- conv_sequence,
22
- load_pretrained_params,
23
- )
24
- from doctr.utils.repr import NestedObject
25
-
26
- from .base import LinkNetPostProcessor, _LinkNet
27
-
28
- __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
29
-
30
- default_cfgs: dict[str, dict[str, Any]] = {
31
- "linknet_resnet18": {
32
- "mean": (0.798, 0.785, 0.772),
33
- "std": (0.264, 0.2749, 0.287),
34
- "input_shape": (1024, 1024, 3),
35
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
36
- },
37
- "linknet_resnet34": {
38
- "mean": (0.798, 0.785, 0.772),
39
- "std": (0.264, 0.2749, 0.287),
40
- "input_shape": (1024, 1024, 3),
41
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
42
- },
43
- "linknet_resnet50": {
44
- "mean": (0.798, 0.785, 0.772),
45
- "std": (0.264, 0.2749, 0.287),
46
- "input_shape": (1024, 1024, 3),
47
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
48
- },
49
- }
50
-
51
-
52
def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
    """Build one LinkNet decoder stage.

    The stage squeezes the input to ``in_chan // 4`` channels with a 1x1 convolution block, resamples it
    with a 3x3 transposed convolution of the requested stride (followed by batch norm + ReLU), and finally
    projects it to ``out_chan`` channels with a second 1x1 convolution block.

    Args:
        in_chan: number of channels of the incoming feature map
        out_chan: number of channels produced by the stage
        stride: stride of the transposed convolution
        **kwargs: extra keyword arguments, forwarded to the first ``conv_sequence`` only (e.g. ``input_shape``)

    Returns:
        a Keras ``Sequential`` model implementing the stage
    """
    squeezed = in_chan // 4
    # Layers are instantiated in the exact same order as the stack they form, so that auto-generated
    # layer names (and hence checkpoint matching) stay stable.
    stage = [*conv_sequence(squeezed, "relu", True, kernel_size=1, **kwargs)]
    stage.append(
        layers.Conv2DTranspose(
            filters=squeezed,
            kernel_size=3,
            strides=stride,
            padding="same",
            use_bias=False,
            kernel_initializer="he_normal",
        )
    )
    stage += [layers.BatchNormalization(), layers.Activation("relu")]
    stage.extend(conv_sequence(out_chan, "relu", True, kernel_size=1))
    return Sequential(stage)
68
-
69
-
70
class LinkNetFPN(Model, NestedObject):
    """LinkNet decoder: progressively merges backbone feature maps, starting from the last one.

    Args:
        out_chans: number of channels of the final decoded feature map
        in_shapes: shapes of the backbone feature maps (last dimension = channels); they are consumed
            in reverse order, the last entry feeding the first decoder
    """

    def __init__(
        self,
        out_chans: int,
        in_shapes: list[tuple[int, ...]],
    ) -> None:
        super().__init__()
        self.out_chans = out_chans
        stage_shapes = in_shapes[::-1]
        in_channels = [shape[-1] for shape in stage_shapes]
        # Each decoder outputs the channel count expected by the following (reversed) stage,
        # and the very last decoder outputs `out_chans`
        out_channels = in_channels[1:] + [out_chans]
        last_stage = len(stage_shapes) - 1
        # Every decoder upsamples by 2, except the last one which keeps the resolution
        self.decoders = [
            decoder_block(c_in, c_out, 1 if idx == last_stage else 2, input_shape=shape)
            for idx, (c_in, c_out, shape) in enumerate(zip(in_channels, out_channels, stage_shapes))
        ]

    def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
        """Decode the feature maps, adding each decoder output to the next map in reverse order (skip link)."""
        fused = 0
        for fmap, decoder in zip(x[::-1], self.decoders):
            fused = decoder(fused + fmap, **kwargs)
        return fused

    def extra_repr(self) -> str:
        return f"out_chans={self.out_chans}"
96
-
97
-
98
class LinkNet(_LinkNet, Model):
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    Args:
        feat_extractor: the backbone serving as feature extractor
        fpn_channels: number of channels each extracted feature maps is mapped to
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model
        class_names: list of class names (defaults to ``[CLASS_NAME]``); the list is copied
    """

    _children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]

    def __init__(
        self,
        feat_extractor: IntermediateLayerGetter,
        fpn_channels: int = 64,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
        class_names: list[str] | None = None,
    ) -> None:
        super().__init__(cfg=cfg)

        # Avoid a shared mutable default and never alias a caller-owned list
        self.class_names = list(class_names) if class_names is not None else [CLASS_NAME]
        num_classes: int = len(self.class_names)

        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        self.feat_extractor = feat_extractor

        self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
        self.fpn.build(self.feat_extractor.output_shape)

        # Segmentation head: two transposed convolutions with stride 2 upsample the decoded map,
        # the last one producing one channel per class
        self.classifier = Sequential([
            layers.Conv2DTranspose(
                filters=32,
                kernel_size=3,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer="he_normal",
                input_shape=self.fpn.decoders[-1].output_shape[1:],
            ),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
            layers.Conv2DTranspose(
                filters=num_classes,
                kernel_size=2,
                strides=2,
                padding="same",
                use_bias=True,
                kernel_initializer="he_normal",
            ),
        ])

        self.postprocessor = LinkNetPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

    def compute_loss(
        self,
        out_map: tf.Tensor,
        target: list[dict[str, np.ndarray]],
        gamma: float = 2.0,
        alpha: float = 0.5,
        eps: float = 1e-8,
    ) -> tf.Tensor:
        """Compute the LinkNet loss: focal loss + dice loss, both restricted to the unmasked pixels.
        Focal loss implementation based on `tensorflow/addons <https://github.com/tensorflow/addons/>`_.

        Args:
            out_map: output logits of the model, channels-last with one channel per class (N x H x W x C)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            gamma: modulating factor in the focal loss formula
            alpha: balancing factor in the focal loss formula
            eps: epsilon factor in dice loss

        Returns:
            A scalar float32 loss tensor

        Raises:
            ValueError: if ``gamma`` is negative
        """
        if gamma < 0:
            raise ValueError("Value of gamma should be greater than or equal to zero.")

        # Compute the loss in float32: under mixed precision the logits can be float16/bfloat16, and TF
        # does not auto-promote dtypes, so mixing them with the float32 mask below would fail.
        out_map = tf.cast(out_map, tf.float32)

        seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
        seg_target = tf.convert_to_tensor(seg_target, dtype=tf.float32)
        # Binary (0/1) mask of the pixels that contribute to the loss
        seg_mask = tf.cast(tf.convert_to_tensor(seg_mask, dtype=tf.bool), tf.float32)

        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
        proba_map = tf.sigmoid(out_map)

        # Focal loss: the (1 - p_t) ** gamma factor down-weights well-classified pixels
        p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map))
        alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        # Mean over the valid pixels. The mask is binary, so any non-empty mask already sums to >= 1;
        # clamping only prevents a 0 / 0 = NaN when every pixel is masked out.
        valid_count = tf.maximum(tf.reduce_sum(seg_mask, (0, 1, 2, 3)), 1.0)
        focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / valid_count

        # Dice loss computed per class, then averaged over classes
        dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map
        inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
        cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
        dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))

        return focal_loss + dice_loss

    def call(
        self,
        x: tf.Tensor,
        target: list[dict[str, np.ndarray]] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Run the model on a batch of images.

        Args:
            x: batch of preprocessed images
            target: optional list of ground-truth dicts; when given, the loss is returned under "loss"
            return_model_output: if True, the sigmoid probability map is returned under "out_map"
            return_preds: if True, post-processed predictions are returned under "preds" even with a target
            **kwargs: forwarded to the feature extractor, the decoder and the classifier

        Returns:
            a dict with keys among "logits" (exportable mode only), "out_map", "preds" and "loss"
        """
        feat_maps = self.feat_extractor(x, **kwargs)
        logits = self.fpn(feat_maps, **kwargs)
        logits = self.classifier(logits, **kwargs)

        out: dict[str, Any] = {}
        if self.exportable:
            # Export mode: only the raw logits are returned
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            # _bf16_to_float32 presumably upcasts bfloat16 outputs before they reach numpy post-processing
            prob_map = _bf16_to_float32(tf.math.sigmoid(logits))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes, keyed by class name
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            out["loss"] = self.compute_loss(logits, target)

        return out
247
-
248
-
249
def _linknet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: list[str],
    pretrained_backbone: bool = True,
    input_shape: tuple[int, int, int] | None = None,
    **kwargs: Any,
) -> LinkNet:
    """Instantiate a LinkNet model for one of the architectures listed in ``default_cfgs``.

    Args:
        arch: key of ``default_cfgs`` describing the architecture
        pretrained: whether to load the pretrained weights of the full model
        backbone_fn: callable building the backbone (e.g. ``resnet18``)
        fpn_layers: names of the backbone layers whose outputs feed the decoder
        pretrained_backbone: whether to load pretrained backbone weights (ignored when ``pretrained`` is True)
        input_shape: optional input shape overriding the architecture default
        **kwargs: forwarded to ``LinkNet``; a non-empty ``class_names`` is sorted first

    Returns:
        the built LinkNet model
    """
    reference_cfg = default_cfgs[arch]
    # The full-model checkpoint already holds backbone weights: do not load them twice
    load_backbone_weights = pretrained_backbone and not pretrained

    _cfg = deepcopy(reference_cfg)
    _cfg["input_shape"] = input_shape or reference_cfg["input_shape"]

    requested_classes = kwargs.get("class_names", None)
    kwargs["class_names"] = (
        sorted(requested_classes) if requested_classes else _cfg.get("class_names", [CLASS_NAME])
    )

    backbone = backbone_fn(
        pretrained=load_backbone_weights,
        include_top=False,
        input_shape=_cfg["input_shape"],
    )
    feat_extractor = IntermediateLayerGetter(backbone, fpn_layers)

    model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)

    if pretrained:
        # Custom class names mean a differently-shaped head: skip mismatching layers so it can be fine-tuned
        classes_differ = kwargs["class_names"] != reference_cfg.get("class_names", [CLASS_NAME])
        load_pretrained_params(model, _cfg["url"], skip_mismatch=classes_differ)

    return model
292
-
293
-
294
def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet text detector on a ResNet-18 backbone, following
    `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import linknet_resnet18
    >>> model = linknet_resnet18(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load weights pre-trained on the text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
        the text detection model
    """
    # Backbone layers whose outputs are fed to the LinkNet decoder
    feature_layers = ["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"]
    return _linknet("linknet_resnet18", pretrained, resnet18, feature_layers, **kwargs)
318
-
319
-
320
def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet text detector on a ResNet-34 backbone, following
    `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import linknet_resnet34
    >>> model = linknet_resnet34(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load weights pre-trained on the text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
        the text detection model
    """
    # Backbone layers whose outputs are fed to the LinkNet decoder
    feature_layers = ["resnet_block_2", "resnet_block_6", "resnet_block_12", "resnet_block_15"]
    return _linknet("linknet_resnet34", pretrained, resnet34, feature_layers, **kwargs)
344
-
345
-
346
def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet text detector on a ResNet-50 backbone, following
    `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import linknet_resnet50
    >>> model = linknet_resnet50(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load weights pre-trained on the text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
        the text detection model
    """
    # Backbone layers whose outputs are fed to the LinkNet decoder
    feature_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
    return _linknet("linknet_resnet50", pretrained, resnet50, feature_layers, **kwargs)
@@ -1,70 +0,0 @@ doctr/models/detection/predictor/tensorflow.py
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from typing import Any
7
-
8
- import numpy as np
9
- import tensorflow as tf
10
- from tensorflow.keras import Model
11
-
12
- from doctr.models.detection._utils import _remove_padding
13
- from doctr.models.preprocessor import PreProcessor
14
- from doctr.utils.repr import NestedObject
15
-
16
__all__ = ["DetectionPredictor"]


class DetectionPredictor(NestedObject):
    """Localize text elements in document pages.

    Args:
        pre_processor: transform inputs for easier batched model inference
        model: core detection architecture
    """

    _children_names: list[str] = ["pre_processor", "model"]

    def __init__(self, pre_processor: PreProcessor, model: Model) -> None:
        self.pre_processor = pre_processor
        self.model = model

    def __call__(
        self,
        pages: list[np.ndarray | tf.Tensor],
        return_maps: bool = False,
        **kwargs: Any,
    ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
        """Predict text locations on a list of pages.

        Args:
            pages: list of pages, each expected to be a 3-dimensional (multi-channel 2D) image
            return_maps: if True, also return the per-page "out_map" outputs of the model as numpy arrays
            **kwargs: forwarded to the detection model call

        Returns:
            the per-page predictions with padding removed, plus the list of maps when ``return_maps`` is True

        Raises:
            ValueError: if any page is not 3-dimensional
        """
        # Resizing settings are needed later to undo the padding applied by the pre-processor
        resize = self.pre_processor.resize
        preserve_aspect_ratio = resize.preserve_aspect_ratio
        symmetric_pad = resize.symmetric_pad
        assume_straight_pages = self.model.assume_straight_pages

        for page in pages:
            if page.ndim != 3:
                raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

        processed_batches = self.pre_processor(pages)
        batch_outputs = []
        for batch in processed_batches:
            batch_outputs.append(
                self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs)
            )

        flat_preds = [pred for output in batch_outputs for pred in output["preds"]]
        preds = _remove_padding(
            pages,
            flat_preds,
            preserve_aspect_ratio=preserve_aspect_ratio,
            symmetric_pad=symmetric_pad,
            assume_straight_pages=assume_straight_pages,
        )

        if not return_maps:
            return preds
        seg_maps = [page_map.numpy() for output in batch_outputs for page_map in output["out_map"]]
        return preds, seg_maps
@@ -1,187 +0,0 @@ doctr/models/kie_predictor/tensorflow.py
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from typing import Any
7
-
8
- import numpy as np
9
- import tensorflow as tf
10
-
11
- from doctr.io.elements import Document
12
- from doctr.models._utils import get_language, invert_data_structure
13
- from doctr.models.detection.predictor import DetectionPredictor
14
- from doctr.models.recognition.predictor import RecognitionPredictor
15
- from doctr.utils.geometry import detach_scores
16
- from doctr.utils.repr import NestedObject
17
-
18
- from .base import _KIEPredictor
19
-
20
__all__ = ["KIEPredictor"]


class KIEPredictor(NestedObject, _KIEPredictor):
    """Key information extraction: localize, classify and read text elements in a set of documents.

    Args:
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: forwarded to ``_KIEPredictor``
        symmetric_pad: forwarded to ``_KIEPredictor``
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        **kwargs: keyword args of `DocumentBuilder`
    """

    _children_names = ["det_predictor", "reco_predictor", "doc_builder"]

    def __init__(
        self,
        det_predictor: DetectionPredictor,
        reco_predictor: RecognitionPredictor,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        detect_language: bool = False,
        **kwargs: Any,
    ) -> None:
        self.det_predictor = det_predictor
        self.reco_predictor = reco_predictor
        _KIEPredictor.__init__(
            self,
            assume_straight_pages,
            straighten_pages,
            preserve_aspect_ratio,
            symmetric_pad,
            detect_orientation,
            **kwargs,
        )
        self.detect_orientation = detect_orientation
        self.detect_language = detect_language

    def __call__(
        self,
        pages: list[np.ndarray | tf.Tensor],
        **kwargs: Any,
    ) -> Document:
        """Run detection, optional page straightening, recognition and document assembly.

        Args:
            pages: list of pages, each expected to be a 3-dimensional (multi-channel 2D) image
            **kwargs: forwarded to the detection and recognition predictors; ``bin_thresh`` (default 0.3)
                is also used here to binarize the detection maps

        Returns:
            the assembled ``Document``

        Raises:
            ValueError: if any page is not 3-dimensional
        """
        for page in pages:
            if page.ndim != 3:
                raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

        origin_page_shapes = [page.shape[:2] for page in pages]

        # Localize text elements, keeping the raw maps for orientation estimation
        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

        # Binarize each map: a pixel is set to 255 when its highest class score exceeds the threshold
        bin_thresh = kwargs.get("bin_thresh", 0.3)
        seg_maps = []
        for out_map in out_maps:
            best_score = np.expand_dims(np.amax(out_map, axis=-1), axis=-1)
            seg_maps.append(np.where(best_score > bin_thresh, 255, 0).astype(np.uint8))

        orientations: list[dict[str, Any]] | None = None
        general_pages_orientations: Any = None
        origin_pages_orientations: Any = None
        if self.detect_orientation:
            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
            orientations = [{"value": angle, "confidence": None} for angle in origin_pages_orientations]

        if self.straighten_pages:
            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
            # Page shapes can change once pages are rotated
            origin_page_shapes = [page.shape[:2] for page in pages]
            # Detect again, this time on the straightened pages
            loc_preds = self.det_predictor(pages, **kwargs)

        dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore

        # Split objectness scores from the box coordinates, class by class
        objectness_scores = {}
        for class_name, det_preds in dict_loc_preds.items():
            dict_loc_preds[class_name], objectness_scores[class_name] = detach_scores(det_preds)

        for hook in self.hooks:
            dict_loc_preds = hook(dict_loc_preds)

        crops = {}
        for class_name in dict_loc_preds:
            crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                pages,
                dict_loc_preds[class_name],
                channels_last=True,
                assume_straight_pages=self.assume_straight_pages,
                assume_horizontal=self._page_orientation_disabled,
            )

        # Crops may be rotated when pages are not assumed straight: rectify them
        crop_orientations: Any = {}
        if not self.assume_straight_pages:
            for class_name in dict_loc_preds:
                crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
                    crops[class_name], dict_loc_preds[class_name]
                )
                crop_orientations[class_name] = [
                    {"value": orient[0], "confidence": orient[1]} for orient in word_orientations
                ]

        # Recognize character sequences, one flat batch of crops per class
        word_preds = {}
        for class_name, class_crops in crops.items():
            flat_crops = [crop for page_crops in class_crops for crop in page_crops]
            word_preds[class_name] = self.reco_predictor(flat_crops, **kwargs)

        if not crop_orientations:
            crop_orientations = {
                class_name: [{"value": 0, "confidence": None} for _ in class_preds]
                for class_name, class_preds in word_preds.items()
            }

        boxes: dict = {}
        text_preds: dict = {}
        word_crop_orientations: dict = {}
        for class_name in dict_loc_preds:
            boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
                dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
            )

        boxes_per_page: list[dict] = invert_data_structure(boxes)  # type: ignore[assignment]
        objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores)  # type: ignore[assignment]
        text_preds_per_page: list[dict] = invert_data_structure(text_preds)  # type: ignore[assignment]
        crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations)  # type: ignore[assignment]

        languages_dict: list[dict[str, Any]] | None = None
        if self.detect_language:
            languages_dict = []
            for page_text_preds in text_preds_per_page:
                language = get_language(self.get_text(page_text_preds))
                languages_dict.append({"value": language[0], "confidence": language[1]})

        return self.doc_builder(
            pages,
            boxes_per_page,
            objectness_scores_per_page,
            text_preds_per_page,
            origin_page_shapes,  # type: ignore[arg-type]
            crop_orientations_per_page,
            orientations,
            languages_dict,
        )

    @staticmethod
    def get_text(text_pred: dict) -> str:
        """Join the first element of every prediction of every class into one space-separated string."""
        return " ".join(item[0] for class_preds in text_pred.values() for item in class_preds)