python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0

doctr/models/classification/resnet/tensorflow.py
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -13,46 +14,46 @@ from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "resnet18": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
     },
     "resnet31": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
     },
     "resnet34": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
     },
     "resnet50": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
     },
     "resnet34_wide": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
     },
 }
 
@@ -61,7 +62,6 @@ class ResnetBlock(layers.Layer):
     """Implements a resnet31 block with shortcut
 
     Args:
-    ----
         conv_shortcut: Use of shortcut
         output_channels: number of channels to use in Conv2D
         kernel_size: size of square kernels
@@ -92,7 +92,7 @@ class ResnetBlock(layers.Layer):
         output_channels: int,
         kernel_size: int,
         strides: int = 1,
-    ) -> List[layers.Layer]:
+    ) -> list[layers.Layer]:
         return [
             *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
             *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
@@ -108,8 +108,8 @@ class ResnetBlock(layers.Layer):
 
 def resnet_stage(
     num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
-) -> List[layers.Layer]:
-    _layers: List[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
+) -> list[layers.Layer]:
+    _layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
 
     for _ in range(1, num_blocks):
         _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
@@ -121,7 +121,6 @@ class ResNet(Sequential):
     """Implements a ResNet architecture
 
     Args:
-    ----
         num_blocks: number of resnet block in each stage
         output_channels: number of channels in each stage
         stage_downsample: whether the first residual block of a stage should downsample
@@ -137,18 +136,18 @@ class ResNet(Sequential):
 
     def __init__(
         self,
-        num_blocks: List[int],
-        output_channels: List[int],
-        stage_downsample: List[bool],
-        stage_conv: List[bool],
-        stage_pooling: List[Optional[Tuple[int, int]]],
+        num_blocks: list[int],
+        output_channels: list[int],
+        stage_downsample: list[bool],
+        stage_conv: list[bool],
+        stage_pooling: list[tuple[int, int] | None],
         origin_stem: bool = True,
         stem_channels: int = 64,
-        attn_module: Optional[Callable[[int], layers.Layer]] = None,
+        attn_module: Callable[[int], layers.Layer] | None = None,
         include_top: bool = True,
         num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
+        cfg: dict[str, Any] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
     ) -> None:
         inplanes = stem_channels
         if origin_stem:
@@ -188,11 +187,11 @@ class ResNet(Sequential):
 def _resnet(
     arch: str,
     pretrained: bool,
-    num_blocks: List[int],
-    output_channels: List[int],
-    stage_downsample: List[bool],
-    stage_conv: List[bool],
-    stage_pooling: List[Optional[Tuple[int, int]]],
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_downsample: list[bool],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
     origin_stem: bool = True,
     **kwargs: Any,
 ) -> ResNet:
@@ -210,9 +209,15 @@ def _resnet(
     model = ResNet(
         num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -228,12 +233,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -261,12 +264,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -294,12 +295,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -326,12 +325,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
@@ -354,10 +351,17 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     )
 
     model.cfg = _cfg
+    _build_model(model)
 
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs["resnet50"]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            default_cfgs["resnet50"]["url"],
+            skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
+        )
 
     return model
 
@@ -373,12 +377,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
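
The `skip_mismatch` branch introduced above repeats for every TensorFlow classification backbone in this release. A minimal hedged sketch of what it enables, assuming the TensorFlow backend is active and that `num_classes` is the keyword forwarded through `**kwargs` as shown in `_resnet`:

    # Sketch only: reuse pretrained backbone weights while swapping the classification head.
    from doctr.models.classification import resnet18

    # Head matches the checkpoint (French vocab classes): all layers are loaded.
    model = resnet18(pretrained=True)

    # Custom head: num_classes differs from the checkpoint, so load_pretrained_params
    # runs with skip_mismatch=True and only the compatible layers are restored;
    # the new 10-class head stays randomly initialised for fine-tuning.
    finetune_model = resnet18(pretrained=True, num_classes=10)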

doctr/models/classification/textnet/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
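
The same swap, preferring the PyTorch implementation over TensorFlow when both are importable, is applied to the vgg and vit `__init__.py` modules further down. A hedged sketch of how to keep the old behaviour, assuming the `USE_TF` / `USE_TORCH` environment variables continue to work as in previous doctr releases (that mechanism is not part of this diff):

    import os

    # Assumption: selecting the backend explicitly before importing doctr, as in earlier releases.
    os.environ["USE_TF"] = "1"

    from doctr.models.classification import textnet_tiny

    # With USE_TF set this resolves to the Keras implementation; without it,
    # 0.11.0 now prefers the PyTorch one when both frameworks are installed.
    model = textnet_tiny(pretrained=True)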

doctr/models/classification/textnet/pytorch.py
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 from torch import nn
 
@@ -16,7 +16,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "textnet_tiny": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -47,22 +47,21 @@ class TextNet(nn.Sequential):
     Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
 
     Args:
-    ----
-        stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+        stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
         include_top (bool, optional): Whether to include the classifier head. Defaults to True.
         num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+        cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(
         self,
-        stages: List[Dict[str, List[int]]],
-        input_shape: Tuple[int, int, int] = (3, 32, 32),
+        stages: list[dict[str, list[int]]],
+        input_shape: tuple[int, int, int] = (3, 32, 32),
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers: List[nn.Module] = [
+        _layers: list[nn.Module] = [
             *conv_sequence_pt(
                 in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
             ),
@@ -98,7 +97,7 @@ class TextNet(nn.Sequential):
 def _textnet(
     arch: str,
     pretrained: bool,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> TextNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -135,12 +134,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A textnet tiny model
     """
     return _textnet(
@@ -184,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet small model
     """
     return _textnet(
@@ -233,12 +228,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet base model
     """
     return _textnet(
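
Most of the remaining churn in these modules is the same typing modernisation: `Dict`, `List`, `Tuple` and `Optional` from `typing` give way to built-in generics and `|` unions (PEP 585 / PEP 604). A small illustrative function in the new style, using a hypothetical name that is not part of doctr:

    from typing import Any


    def build_cfg_example(ignore_keys: list[str] | None = None, cfg: dict[str, Any] | None = None) -> dict[str, Any]:
        """Hypothetical helper, shown only for the annotation style used throughout 0.11.0."""
        # Old spelling: ignore_keys: Optional[List[str]] = None, cfg: Optional[Dict[str, Any]] = None
        return {"ignore_keys": ignore_keys or [], **(cfg or {})}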

doctr/models/classification/textnet/tensorflow.py
@@ -1,42 +1,42 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 from tensorflow.keras import Sequential, layers
 
 from doctr.datasets import VOCABS
 
 from ...modules.layers.tensorflow import FASTConvLayer
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "textnet_tiny": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
     },
     "textnet_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
     },
     "textnet_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
     },
 }
 
@@ -47,20 +47,19 @@ class TextNet(Sequential):
     Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
 
     Args:
-    ----
-        stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+        stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
         include_top (bool, optional): Whether to include the classifier head. Defaults to True.
         num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+        cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(
         self,
-        stages: List[Dict[str, List[int]]],
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        stages: list[dict[str, list[int]]],
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             *conv_sequence(
@@ -111,9 +110,15 @@ def _textnet(
 
     # Build the model
     model = TextNet(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -130,12 +135,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A textnet tiny model
     """
     return _textnet(
@@ -178,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet small model
     """
     return _textnet(
@@ -226,12 +227,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet base model
     """
     return _textnet(
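
Two TensorFlow-side changes recur in every backbone above: `_build_model(model)` now runs before any weight loading, and the checkpoint URLs switch from `.zip` archives to Keras `.weights.h5` files. The diff does not show `_build_model` itself, so the following is only a plain-Keras analogy of why a build step has to precede loading weights in that format:

    from tensorflow.keras import Sequential, layers

    model = Sequential([layers.Dense(10)])
    model.build(input_shape=(None, 4))     # create the variables first (roughly what _build_model is for)
    model.save_weights("demo.weights.h5")  # the .weights.h5 suffix matches the new checkpoint URLs
    model.load_weights("demo.weights.h5")  # loading expects the variables to already exist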

doctr/models/classification/vgg/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *

doctr/models/classification/vgg/pytorch.py
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from torch import nn
 from torchvision.models import vgg as tv_vgg
@@ -16,7 +16,7 @@ from ...utils import load_pretrained_params
 __all__ = ["vgg16_bn_r"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vgg16_bn_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -32,7 +32,7 @@ def _vgg(
     pretrained: bool,
     tv_arch: str,
     num_rect_pools: int = 3,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> tv_vgg.VGG:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -45,7 +45,7 @@ def _vgg(
 
     # Build the model
     model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None)
-    # List the MaxPool2d
+    # list the MaxPool2d
     pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
     # Replace their kernel with rectangular ones
     for idx in pool_idcs[-num_rect_pools:]:
@@ -77,12 +77,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         **kwargs: keyword arguments of the VGG architecture
 
     Returns:
-    -------
         VGG feature extractor
     """
     return _vgg(
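
On the PyTorch side, fine-tuning with a different head goes through the pre-existing `ignore_keys` parameter of `_vgg` (only its annotation changes in this diff), which doctr uses to drop the classifier weights from the checkpoint when the head size differs; that behaviour is assumed unchanged from 0.9.0 since it is not shown in these hunks. A short hedged sketch, assuming the PyTorch backend is active:

    from doctr.models.classification import vgg16_bn_r

    # Same head as the checkpoint: the full state_dict is loaded.
    model = vgg16_bn_r(pretrained=True)

    # Different head: the classifier keys are ignored, leaving a randomly
    # initialised 10-class head on top of the pretrained features.
    finetune_model = vgg16_bn_r(pretrained=True, num_classes=10)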

doctr/models/classification/vgg/tensorflow.py
@@ -1,28 +1,28 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["VGG", "vgg16_bn_r"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vgg16_bn_r": {
         "mean": (0.5, 0.5, 0.5),
         "std": (1.0, 1.0, 1.0),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
     },
 }
 
@@ -32,7 +32,6 @@ class VGG(Sequential):
     <https://arxiv.org/pdf/1409.1556.pdf>`_.
 
     Args:
-    ----
         num_blocks: number of convolutional block in each stage
         planes: number of output channels in each stage
         rect_pools: whether pooling square kernels should be replace with rectangular ones
@@ -43,13 +42,13 @@ class VGG(Sequential):
 
     def __init__(
         self,
-        num_blocks: List[int],
-        planes: List[int],
-        rect_pools: List[bool],
+        num_blocks: list[int],
+        planes: list[int],
+        rect_pools: list[bool],
         include_top: bool = False,
         num_classes: int = 1000,
-        input_shape: Optional[Tuple[int, int, int]] = None,
-        cfg: Optional[Dict[str, Any]] = None,
+        input_shape: tuple[int, int, int] | None = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = []
         # Specify input_shape only for the first layer
@@ -67,7 +66,7 @@ class VGG(Sequential):
 
 
 def _vgg(
-    arch: str, pretrained: bool, num_blocks: List[int], planes: List[int], rect_pools: List[bool], **kwargs: Any
+    arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
 ) -> VGG:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
     kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
@@ -81,9 +80,15 @@ def _vgg(
 
     # Build the model
     model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+        )
 
     return model
 
@@ -100,12 +105,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         **kwargs: keyword arguments of the VGG architecture
 
     Returns:
-    -------
         VGG feature extractor
     """
     return _vgg(

doctr/models/classification/vit/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]