python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170):
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,377 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
-
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
- from tensorflow.keras import Model, Sequential, layers, losses
14
-
15
- from doctr.file_utils import CLASS_NAME
16
- from doctr.models.classification import resnet18, resnet34, resnet50
17
- from doctr.models.utils import (
18
- IntermediateLayerGetter,
19
- _bf16_to_float32,
20
- _build_model,
21
- conv_sequence,
22
- load_pretrained_params,
23
- )
24
- from doctr.utils.repr import NestedObject
25
-
26
- from .base import LinkNetPostProcessor, _LinkNet
27
-
28
- __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
29
-
30
- default_cfgs: dict[str, dict[str, Any]] = {
31
- "linknet_resnet18": {
32
- "mean": (0.798, 0.785, 0.772),
33
- "std": (0.264, 0.2749, 0.287),
34
- "input_shape": (1024, 1024, 3),
35
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
36
- },
37
- "linknet_resnet34": {
38
- "mean": (0.798, 0.785, 0.772),
39
- "std": (0.264, 0.2749, 0.287),
40
- "input_shape": (1024, 1024, 3),
41
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
42
- },
43
- "linknet_resnet50": {
44
- "mean": (0.798, 0.785, 0.772),
45
- "std": (0.264, 0.2749, 0.287),
46
- "input_shape": (1024, 1024, 3),
47
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
48
- },
49
- }
50
-
51
-
52
- def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
53
- """Creates a LinkNet decoder block"""
54
- return Sequential([
55
- *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
56
- layers.Conv2DTranspose(
57
- filters=in_chan // 4,
58
- kernel_size=3,
59
- strides=stride,
60
- padding="same",
61
- use_bias=False,
62
- kernel_initializer="he_normal",
63
- ),
64
- layers.BatchNormalization(),
65
- layers.Activation("relu"),
66
- *conv_sequence(out_chan, "relu", True, kernel_size=1),
67
- ])
68
-
69
-
70
- class LinkNetFPN(Model, NestedObject):
71
- """LinkNet Decoder module"""
72
-
73
- def __init__(
74
- self,
75
- out_chans: int,
76
- in_shapes: list[tuple[int, ...]],
77
- ) -> None:
78
- super().__init__()
79
- self.out_chans = out_chans
80
- strides = [2] * (len(in_shapes) - 1) + [1]
81
- i_chans = [s[-1] for s in in_shapes[::-1]]
82
- o_chans = i_chans[1:] + [out_chans]
83
- self.decoders = [
84
- decoder_block(in_chan, out_chan, s, input_shape=in_shape)
85
- for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
86
- ]
87
-
88
- def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
89
- out = 0
90
- for decoder, fmap in zip(self.decoders, x[::-1]):
91
- out = decoder(out + fmap, **kwargs)
92
- return out
93
-
94
- def extra_repr(self) -> str:
95
- return f"out_chans={self.out_chans}"
96
-
97
-
98
- class LinkNet(_LinkNet, Model):
99
- """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
100
- <https://arxiv.org/pdf/1707.03718.pdf>`_.
101
-
102
- Args:
103
- feature extractor: the backbone serving as feature extractor
104
- fpn_channels: number of channels each extracted feature maps is mapped to
105
- bin_thresh: threshold for binarization of the output feature map
106
- box_thresh: minimal objectness score to consider a box
107
- assume_straight_pages: if True, fit straight bounding boxes only
108
- exportable: onnx exportable returns only logits
109
- cfg: the configuration dict of the model
110
- class_names: list of class names
111
- """
112
-
113
- _children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
114
-
115
- def __init__(
116
- self,
117
- feat_extractor: IntermediateLayerGetter,
118
- fpn_channels: int = 64,
119
- bin_thresh: float = 0.1,
120
- box_thresh: float = 0.1,
121
- assume_straight_pages: bool = True,
122
- exportable: bool = False,
123
- cfg: dict[str, Any] | None = None,
124
- class_names: list[str] = [CLASS_NAME],
125
- ) -> None:
126
- super().__init__(cfg=cfg)
127
-
128
- self.class_names = class_names
129
- num_classes: int = len(self.class_names)
130
-
131
- self.exportable = exportable
132
- self.assume_straight_pages = assume_straight_pages
133
-
134
- self.feat_extractor = feat_extractor
135
-
136
- self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
137
- self.fpn.build(self.feat_extractor.output_shape)
138
-
139
- self.classifier = Sequential([
140
- layers.Conv2DTranspose(
141
- filters=32,
142
- kernel_size=3,
143
- strides=2,
144
- padding="same",
145
- use_bias=False,
146
- kernel_initializer="he_normal",
147
- input_shape=self.fpn.decoders[-1].output_shape[1:],
148
- ),
149
- layers.BatchNormalization(),
150
- layers.Activation("relu"),
151
- *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
152
- layers.Conv2DTranspose(
153
- filters=num_classes,
154
- kernel_size=2,
155
- strides=2,
156
- padding="same",
157
- use_bias=True,
158
- kernel_initializer="he_normal",
159
- ),
160
- ])
161
-
162
- self.postprocessor = LinkNetPostProcessor(
163
- assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
164
- )
165
-
166
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
167
- """Load pretrained parameters onto the model
168
-
169
- Args:
170
- path_or_url: the path or URL to the model parameters (checkpoint)
171
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
172
- """
173
- load_pretrained_params(self, path_or_url, **kwargs)
174
-
175
- def compute_loss(
176
- self,
177
- out_map: tf.Tensor,
178
- target: list[dict[str, np.ndarray]],
179
- gamma: float = 2.0,
180
- alpha: float = 0.5,
181
- eps: float = 1e-8,
182
- ) -> tf.Tensor:
183
- """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
184
- <https://github.com/tensorflow/addons/>`_.
185
-
186
- Args:
187
- out_map: output feature map of the model of shape N x H x W x 1
188
- target: list of dictionary where each dict has a `boxes` and a `flags` entry
189
- gamma: modulating factor in the focal loss formula
190
- alpha: balancing factor in the focal loss formula
191
- eps: epsilon factor in dice loss
192
-
193
- Returns:
194
- A loss tensor
195
- """
196
- seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
197
- seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
198
- seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
199
- seg_mask = tf.cast(seg_mask, tf.float32)
200
-
201
- bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
202
- proba_map = tf.sigmoid(out_map)
203
-
204
- # Focal loss
205
- if gamma < 0:
206
- raise ValueError("Value of gamma should be greater than or equal to zero.")
207
- # Convert logits to prob, compute gamma factor
208
- p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map))
209
- alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
210
- # Unreduced loss
211
- focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
212
- # Class reduced
213
- focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
214
-
215
- # Compute dice loss for each class
216
- dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map
217
- # Class-reduced dice loss
218
- inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
219
- cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
220
- dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
221
-
222
- return focal_loss + dice_loss
223
-
224
- def call(
225
- self,
226
- x: tf.Tensor,
227
- target: list[dict[str, np.ndarray]] | None = None,
228
- return_model_output: bool = False,
229
- return_preds: bool = False,
230
- **kwargs: Any,
231
- ) -> dict[str, Any]:
232
- feat_maps = self.feat_extractor(x, **kwargs)
233
- logits = self.fpn(feat_maps, **kwargs)
234
- logits = self.classifier(logits, **kwargs)
235
-
236
- out: dict[str, tf.Tensor] = {}
237
- if self.exportable:
238
- out["logits"] = logits
239
- return out
240
-
241
- if return_model_output or target is None or return_preds:
242
- prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
243
-
244
- if return_model_output:
245
- out["out_map"] = prob_map
246
-
247
- if target is None or return_preds:
248
- # Post-process boxes
249
- out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
250
-
251
- if target is not None:
252
- loss = self.compute_loss(logits, target)
253
- out["loss"] = loss
254
-
255
- return out
256
-
257
-
258
- def _linknet(
259
- arch: str,
260
- pretrained: bool,
261
- backbone_fn,
262
- fpn_layers: list[str],
263
- pretrained_backbone: bool = True,
264
- input_shape: tuple[int, int, int] | None = None,
265
- **kwargs: Any,
266
- ) -> LinkNet:
267
- pretrained_backbone = pretrained_backbone and not pretrained
268
-
269
- # Patch the config
270
- _cfg = deepcopy(default_cfgs[arch])
271
- _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]
272
- if not kwargs.get("class_names", None):
273
- kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
274
- else:
275
- kwargs["class_names"] = sorted(kwargs["class_names"])
276
-
277
- # Feature extractor
278
- feat_extractor = IntermediateLayerGetter(
279
- backbone_fn(
280
- pretrained=pretrained_backbone,
281
- include_top=False,
282
- input_shape=_cfg["input_shape"],
283
- ),
284
- fpn_layers,
285
- )
286
-
287
- # Build the model
288
- model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
289
- _build_model(model)
290
-
291
- # Load pretrained parameters
292
- if pretrained:
293
- # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
294
- model.from_pretrained(
295
- _cfg["url"],
296
- skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
297
- )
298
-
299
- return model
300
-
301
-
302
- def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
303
- """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
304
- <https://arxiv.org/pdf/1707.03718.pdf>`_.
305
-
306
- >>> import tensorflow as tf
307
- >>> from doctr.models import linknet_resnet18
308
- >>> model = linknet_resnet18(pretrained=True)
309
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
310
- >>> out = model(input_tensor)
311
-
312
- Args:
313
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
314
- **kwargs: keyword arguments of the LinkNet architecture
315
-
316
- Returns:
317
- text detection architecture
318
- """
319
- return _linknet(
320
- "linknet_resnet18",
321
- pretrained,
322
- resnet18,
323
- ["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"],
324
- **kwargs,
325
- )
326
-
327
-
328
- def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
329
- """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
330
- <https://arxiv.org/pdf/1707.03718.pdf>`_.
331
-
332
- >>> import tensorflow as tf
333
- >>> from doctr.models import linknet_resnet34
334
- >>> model = linknet_resnet34(pretrained=True)
335
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
336
- >>> out = model(input_tensor)
337
-
338
- Args:
339
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
340
- **kwargs: keyword arguments of the LinkNet architecture
341
-
342
- Returns:
343
- text detection architecture
344
- """
345
- return _linknet(
346
- "linknet_resnet34",
347
- pretrained,
348
- resnet34,
349
- ["resnet_block_2", "resnet_block_6", "resnet_block_12", "resnet_block_15"],
350
- **kwargs,
351
- )
352
-
353
-
354
- def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
355
- """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
356
- <https://arxiv.org/pdf/1707.03718.pdf>`_.
357
-
358
- >>> import tensorflow as tf
359
- >>> from doctr.models import linknet_resnet50
360
- >>> model = linknet_resnet50(pretrained=True)
361
- >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
362
- >>> out = model(input_tensor)
363
-
364
- Args:
365
- pretrained (bool): If True, returns a model pre-trained on our text detection dataset
366
- **kwargs: keyword arguments of the LinkNet architecture
367
-
368
- Returns:
369
- text detection architecture
370
- """
371
- return _linknet(
372
- "linknet_resnet50",
373
- pretrained,
374
- resnet50,
375
- ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
376
- **kwargs,
377
- )
@@ -1,70 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from typing import Any
7
-
8
- import numpy as np
9
- import tensorflow as tf
10
- from tensorflow.keras import Model
11
-
12
- from doctr.models.detection._utils import _remove_padding
13
- from doctr.models.preprocessor import PreProcessor
14
- from doctr.utils.repr import NestedObject
15
-
16
- __all__ = ["DetectionPredictor"]
17
-
18
-
19
- class DetectionPredictor(NestedObject):
20
- """Implements an object able to localize text elements in a document
21
-
22
- Args:
23
- pre_processor: transform inputs for easier batched model inference
24
- model: core detection architecture
25
- """
26
-
27
- _children_names: list[str] = ["pre_processor", "model"]
28
-
29
- def __init__(
30
- self,
31
- pre_processor: PreProcessor,
32
- model: Model,
33
- ) -> None:
34
- self.pre_processor = pre_processor
35
- self.model = model
36
-
37
- def __call__(
38
- self,
39
- pages: list[np.ndarray | tf.Tensor],
40
- return_maps: bool = False,
41
- **kwargs: Any,
42
- ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
43
- # Extract parameters from the preprocessor
44
- preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
45
- symmetric_pad = self.pre_processor.resize.symmetric_pad
46
- assume_straight_pages = self.model.assume_straight_pages
47
-
48
- # Dimension check
49
- if any(page.ndim != 3 for page in pages):
50
- raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
51
-
52
- processed_batches = self.pre_processor(pages)
53
- predicted_batches = [
54
- self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs)
55
- for batch in processed_batches
56
- ]
57
-
58
- # Remove padding from loc predictions
59
- preds = _remove_padding(
60
- pages,
61
- [pred for batch in predicted_batches for pred in batch["preds"]],
62
- preserve_aspect_ratio=preserve_aspect_ratio,
63
- symmetric_pad=symmetric_pad,
64
- assume_straight_pages=assume_straight_pages,
65
- )
66
-
67
- if return_maps:
68
- seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
69
- return preds, seg_maps
70
- return preds
@@ -1,187 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from typing import Any
7
-
8
- import numpy as np
9
- import tensorflow as tf
10
-
11
- from doctr.io.elements import Document
12
- from doctr.models._utils import get_language, invert_data_structure
13
- from doctr.models.detection.predictor import DetectionPredictor
14
- from doctr.models.recognition.predictor import RecognitionPredictor
15
- from doctr.utils.geometry import detach_scores
16
- from doctr.utils.repr import NestedObject
17
-
18
- from .base import _KIEPredictor
19
-
20
- __all__ = ["KIEPredictor"]
21
-
22
-
23
- class KIEPredictor(NestedObject, _KIEPredictor):
24
- """Implements an object able to localize and identify text elements in a set of documents
25
-
26
- Args:
27
- det_predictor: detection module
28
- reco_predictor: recognition module
29
- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
30
- without rotated textual elements.
31
- straighten_pages: if True, estimates the page general orientation based on the median line orientation.
32
- Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
33
- accordingly. Doing so will improve performances for documents with page-uniform rotations.
34
- detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
35
- page. Doing so will slightly deteriorate the overall latency.
36
- detect_language: if True, the language prediction will be added to the predictions for each
37
- page. Doing so will slightly deteriorate the overall latency.
38
- **kwargs: keyword args of `DocumentBuilder`
39
- """
40
-
41
- _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
42
-
43
- def __init__(
44
- self,
45
- det_predictor: DetectionPredictor,
46
- reco_predictor: RecognitionPredictor,
47
- assume_straight_pages: bool = True,
48
- straighten_pages: bool = False,
49
- preserve_aspect_ratio: bool = True,
50
- symmetric_pad: bool = True,
51
- detect_orientation: bool = False,
52
- detect_language: bool = False,
53
- **kwargs: Any,
54
- ) -> None:
55
- self.det_predictor = det_predictor
56
- self.reco_predictor = reco_predictor
57
- _KIEPredictor.__init__(
58
- self,
59
- assume_straight_pages,
60
- straighten_pages,
61
- preserve_aspect_ratio,
62
- symmetric_pad,
63
- detect_orientation,
64
- **kwargs,
65
- )
66
- self.detect_orientation = detect_orientation
67
- self.detect_language = detect_language
68
-
69
- def __call__(
70
- self,
71
- pages: list[np.ndarray | tf.Tensor],
72
- **kwargs: Any,
73
- ) -> Document:
74
- # Dimension check
75
- if any(page.ndim != 3 for page in pages):
76
- raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
77
-
78
- origin_page_shapes = [page.shape[:2] for page in pages]
79
-
80
- # Localize text elements
81
- loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
82
-
83
- # Detect document rotation and rotate pages
84
- seg_maps = [
85
- np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
86
- np.uint8
87
- )
88
- for out_map in out_maps
89
- ]
90
- if self.detect_orientation:
91
- general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
92
- orientations = [
93
- {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
94
- ]
95
- else:
96
- orientations = None
97
- general_pages_orientations = None
98
- origin_pages_orientations = None
99
- if self.straighten_pages:
100
- pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
101
- # update page shapes after straightening
102
- origin_page_shapes = [page.shape[:2] for page in pages]
103
-
104
- # Forward again to get predictions on straight pages
105
- loc_preds = self.det_predictor(pages, **kwargs)
106
-
107
- dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
108
-
109
- # Detach objectness scores from loc_preds
110
- objectness_scores = {}
111
- for class_name, det_preds in dict_loc_preds.items():
112
- _loc_preds, _scores = detach_scores(det_preds)
113
- dict_loc_preds[class_name] = _loc_preds
114
- objectness_scores[class_name] = _scores
115
-
116
- # Apply hooks to loc_preds if any
117
- for hook in self.hooks:
118
- dict_loc_preds = hook(dict_loc_preds)
119
-
120
- # Crop images
121
- crops = {}
122
- for class_name in dict_loc_preds.keys():
123
- crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
124
- pages,
125
- dict_loc_preds[class_name],
126
- channels_last=True,
127
- assume_straight_pages=self.assume_straight_pages,
128
- assume_horizontal=self._page_orientation_disabled,
129
- )
130
-
131
- # Rectify crop orientation
132
- crop_orientations: Any = {}
133
- if not self.assume_straight_pages:
134
- for class_name in dict_loc_preds.keys():
135
- crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
136
- crops[class_name], dict_loc_preds[class_name]
137
- )
138
- crop_orientations[class_name] = [
139
- {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
140
- ]
141
-
142
- # Identify character sequences
143
- word_preds = {
144
- k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
145
- for k, crop_value in crops.items()
146
- }
147
- if not crop_orientations:
148
- crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
149
-
150
- boxes: dict = {}
151
- text_preds: dict = {}
152
- word_crop_orientations: dict = {}
153
- for class_name in dict_loc_preds.keys():
154
- boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
155
- dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
156
- )
157
-
158
- boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
159
- objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
160
- text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
161
- crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
162
-
163
- if self.detect_language:
164
- languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
165
- languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
166
- else:
167
- languages_dict = None
168
-
169
- out = self.doc_builder(
170
- pages,
171
- boxes_per_page,
172
- objectness_scores_per_page,
173
- text_preds_per_page,
174
- origin_page_shapes,
175
- crop_orientations_per_page,
176
- orientations,
177
- languages_dict,
178
- )
179
- return out
180
-
181
- @staticmethod
182
- def get_text(text_pred: dict) -> str:
183
- text = []
184
- for value in text_pred.values():
185
- text += [item[0] for item in value]
186
-
187
- return " ".join(text)