python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,11 +1,13 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


+ import types
+ from collections.abc import Callable
  from copy import deepcopy
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any

  from torch import nn
  from torchvision.models.resnet import BasicBlock
@@ -21,7 +23,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
  __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "resnet18": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
@@ -60,9 +62,9 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
  }


- def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
+ def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> list[nn.Module]:
  """Build a ResNet stage"""
- _layers: List[nn.Module] = []
+ _layers: list[nn.Module] = []

  in_chan = in_channels
  s = stride
@@ -84,7 +86,6 @@ class ResNet(nn.Sequential):
  Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

  Args:
- ----
  num_blocks: number of resnet block in each stage
  output_channels: number of channels in each stage
  stage_conv: whether to add a conv_sequence after each stage
@@ -98,19 +99,19 @@ class ResNet(nn.Sequential):

  def __init__(
  self,
- num_blocks: List[int],
- output_channels: List[int],
- stage_stride: List[int],
- stage_conv: List[bool],
- stage_pooling: List[Optional[Tuple[int, int]]],
+ num_blocks: list[int],
+ output_channels: list[int],
+ stage_stride: list[int],
+ stage_conv: list[bool],
+ stage_pooling: list[tuple[int, int] | None],
  origin_stem: bool = True,
  stem_channels: int = 64,
- attn_module: Optional[Callable[[int], nn.Module]] = None,
+ attn_module: Callable[[int], nn.Module] | None = None,
  include_top: bool = True,
  num_classes: int = 1000,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  ) -> None:
- _layers: List[nn.Module]
+ _layers: list[nn.Module]
  if origin_stem:
  _layers = [
  *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2),
@@ -152,16 +153,25 @@ class ResNet(nn.Sequential):
  nn.init.constant_(m.weight, 1)
  nn.init.constant_(m.bias, 0)

+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+

  def _resnet(
  arch: str,
  pretrained: bool,
- num_blocks: List[int],
- output_channels: List[int],
- stage_stride: List[int],
- stage_conv: List[bool],
- stage_pooling: List[Optional[Tuple[int, int]]],
- ignore_keys: Optional[List[str]] = None,
+ num_blocks: list[int],
+ output_channels: list[int],
+ stage_stride: list[int],
+ stage_conv: list[bool],
+ stage_pooling: list[tuple[int, int] | None],
+ ignore_keys: list[str] | None = None,
  **kwargs: Any,
  ) -> ResNet:
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -179,7 +189,7 @@ def _resnet(
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # remove the last layer weights
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

  return model

@@ -188,7 +198,7 @@ def _tv_resnet(
  arch: str,
  pretrained: bool,
  arch_fn,
- ignore_keys: Optional[List[str]] = None,
+ ignore_keys: list[str] | None = None,
  **kwargs: Any,
  ) -> TVResNet:
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -201,12 +211,25 @@ def _tv_resnet(

  # Build the model
  model = arch_fn(**kwargs, weights=None)
- # Load pretrained parameters
+
+ # monkeypatch the model to allow for loading pretrained parameters
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+
+ # Bind method to the instance
+ model.from_pretrained = types.MethodType(from_pretrained, model)
+
  if pretrained:
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # remove the last layer weights
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

  model.cfg = _cfg

@@ -224,12 +247,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A resnet18 model
  """
  return _tv_resnet(
@@ -253,12 +274,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A resnet31 model
  """
  return _resnet(
@@ -287,12 +306,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A resnet34 model
  """
  return _tv_resnet(
@@ -315,12 +332,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A resnet34_wide model
  """
  return _resnet(
@@ -349,12 +364,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A resnet50 model
  """
  return _tv_resnet(
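
The ResNet backbone now exposes a from_pretrained method that simply delegates to load_pretrained_params, so a checkpoint can be loaded onto an already-built model without importing the utility directly. A minimal usage sketch, assuming a local checkpoint (the path below is a placeholder, not a published artifact):

    from doctr.models.classification import resnet31

    # Build the backbone without fetching the default weights
    model = resnet31(pretrained=False)

    # Load parameters from a local path or a URL; extra keyword arguments
    # (e.g. ignore_keys) are forwarded to doctr.models.utils.load_pretrained_params
    model.from_pretrained("path/to/resnet31_checkpoint.pt")
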
@@ -1,10 +1,12 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

+ import types
+ from collections.abc import Callable
  from copy import deepcopy
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import layers
@@ -18,7 +20,7 @@ from ...utils import _build_model, conv_sequence, load_pretrained_params
  __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "resnet18": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
@@ -61,7 +63,6 @@ class ResnetBlock(layers.Layer):
  """Implements a resnet31 block with shortcut

  Args:
- ----
  conv_shortcut: Use of shortcut
  output_channels: number of channels to use in Conv2D
  kernel_size: size of square kernels
@@ -92,7 +93,7 @@ class ResnetBlock(layers.Layer):
  output_channels: int,
  kernel_size: int,
  strides: int = 1,
- ) -> List[layers.Layer]:
+ ) -> list[layers.Layer]:
  return [
  *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
  *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
@@ -108,8 +109,8 @@ class ResnetBlock(layers.Layer):

  def resnet_stage(
  num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
- ) -> List[layers.Layer]:
- _layers: List[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
+ ) -> list[layers.Layer]:
+ _layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]

  for _ in range(1, num_blocks):
  _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
@@ -121,7 +122,6 @@ class ResNet(Sequential):
  """Implements a ResNet architecture

  Args:
- ----
  num_blocks: number of resnet block in each stage
  output_channels: number of channels in each stage
  stage_downsample: whether the first residual block of a stage should downsample
@@ -137,18 +137,18 @@ class ResNet(Sequential):

  def __init__(
  self,
- num_blocks: List[int],
- output_channels: List[int],
- stage_downsample: List[bool],
- stage_conv: List[bool],
- stage_pooling: List[Optional[Tuple[int, int]]],
+ num_blocks: list[int],
+ output_channels: list[int],
+ stage_downsample: list[bool],
+ stage_conv: list[bool],
+ stage_pooling: list[tuple[int, int] | None],
  origin_stem: bool = True,
  stem_channels: int = 64,
- attn_module: Optional[Callable[[int], layers.Layer]] = None,
+ attn_module: Callable[[int], layers.Layer] | None = None,
  include_top: bool = True,
  num_classes: int = 1000,
- cfg: Optional[Dict[str, Any]] = None,
- input_shape: Optional[Tuple[int, int, int]] = None,
+ cfg: dict[str, Any] | None = None,
+ input_shape: tuple[int, int, int] | None = None,
  ) -> None:
  inplanes = stem_channels
  if origin_stem:
@@ -184,15 +184,24 @@ class ResNet(Sequential):
  super().__init__(_layers)
  self.cfg = cfg

+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+

  def _resnet(
  arch: str,
  pretrained: bool,
- num_blocks: List[int],
- output_channels: List[int],
- stage_downsample: List[bool],
- stage_conv: List[bool],
- stage_pooling: List[Optional[Tuple[int, int]]],
+ num_blocks: list[int],
+ output_channels: list[int],
+ stage_downsample: list[bool],
+ stage_conv: list[bool],
+ stage_pooling: list[tuple[int, int] | None],
  origin_stem: bool = True,
  **kwargs: Any,
  ) -> ResNet:
@@ -216,8 +225,8 @@ def _resnet(
  if pretrained:
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # skip the mismatching layers for fine tuning
- load_pretrained_params(
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+ model.from_pretrained(
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
  )

  return model
@@ -234,12 +243,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A classification model
  """
  return _resnet(
@@ -267,12 +274,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A classification model
  """
  return _resnet(
@@ -300,12 +305,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A classification model
  """
  return _resnet(
@@ -332,12 +335,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A classification model
  """
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
@@ -359,6 +360,18 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
  classifier_activation=None,
  )

+ # monkeypatch the model to allow for loading pretrained parameters
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+
+ model.from_pretrained = types.MethodType(from_pretrained, model)  # Bind method to the instance
+
  model.cfg = _cfg
  _build_model(model)

@@ -366,8 +379,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
  if pretrained:
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # skip the mismatching layers for fine tuning
- load_pretrained_params(
- model,
+ model.from_pretrained(
  default_cfgs["resnet50"]["url"],
  skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
  )
@@ -386,12 +398,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the ResNet architecture

  Returns:
- -------
  A classification model
  """
  return _resnet(
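
Models built from third-party classes, such as the torchvision backbones returned by _tv_resnet and the Keras resnet50 above, cannot define from_pretrained on a doctr class, so the method is bound to the instance at build time with types.MethodType. A standalone sketch of that binding pattern, using an illustrative stand-in class rather than a real model:

    import types


    class ThirdPartyModel:
        """Stand-in for a torchvision or Keras model class that doctr does not subclass."""


    def from_pretrained(self, path_or_url: str, **kwargs) -> None:
        # In doctr this delegates to doctr.models.utils.load_pretrained_params
        print(f"loading weights from {path_or_url} onto {type(self).__name__}")


    model = ThirdPartyModel()
    # Bind the function to this single instance so it can be called like a method
    model.from_pretrained = types.MethodType(from_pretrained, model)
    model.from_pretrained("path/to/checkpoint.pt")
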
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
+ if is_torch_available():
+ from .pytorch import *
+ elif is_tf_available():
  from .tensorflow import *
- elif is_torch_available():
- from .pytorch import * # type: ignore[assignment]
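
The backend guard above is also reordered: is_torch_available is now checked before is_tf_available, so when both frameworks are installed the PyTorch implementation is imported and TensorFlow becomes the fallback. A rough way to see which backend an installation would resolve to, using the same helpers the package itself imports:

    from doctr.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        backend = "pytorch"  # preferred in 0.12.0 when both frameworks are present
    elif is_tf_available():
        backend = "tensorflow"  # fallback
    else:
        backend = None
    print(backend)
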
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


  from copy import deepcopy
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  from torch import nn

@@ -16,7 +16,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params

  __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "textnet_tiny": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
@@ -47,22 +47,21 @@ class TextNet(nn.Sequential):
  Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.

  Args:
- ----
- stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+ stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
  include_top (bool, optional): Whether to include the classifier head. Defaults to True.
  num_classes (int, optional): Number of output classes. Defaults to 1000.
- cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+ cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
  """

  def __init__(
  self,
- stages: List[Dict[str, List[int]]],
- input_shape: Tuple[int, int, int] = (3, 32, 32),
+ stages: list[dict[str, list[int]]],
+ input_shape: tuple[int, int, int] = (3, 32, 32),
  num_classes: int = 1000,
  include_top: bool = True,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  ) -> None:
- _layers: List[nn.Module] = [
+ _layers: list[nn.Module] = [
  *conv_sequence_pt(
  in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
  ),
@@ -94,11 +93,20 @@ class TextNet(nn.Sequential):
  nn.init.constant_(m.weight, 1)
  nn.init.constant_(m.bias, 0)

+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+

  def _textnet(
  arch: str,
  pretrained: bool,
- ignore_keys: Optional[List[str]] = None,
+ ignore_keys: list[str] | None = None,
  **kwargs: Any,
  ) -> TextNet:
  kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -116,7 +124,7 @@ def _textnet(
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # remove the last layer weights
  _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

  model.cfg = _cfg

@@ -135,12 +143,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A textnet tiny model
  """
  return _textnet(
@@ -184,12 +190,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A TextNet small model
  """
  return _textnet(
@@ -233,12 +237,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A TextNet base model
  """
  return _textnet(
@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


  from copy import deepcopy
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  from tensorflow.keras import Sequential, layers

@@ -16,7 +16,7 @@ from ...utils import _build_model, conv_sequence, load_pretrained_params

  __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "textnet_tiny": {
  "mean": (0.694, 0.695, 0.693),
  "std": (0.299, 0.296, 0.301),
@@ -47,20 +47,19 @@ class TextNet(Sequential):
  Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.

  Args:
- ----
- stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+ stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
  include_top (bool, optional): Whether to include the classifier head. Defaults to True.
  num_classes (int, optional): Number of output classes. Defaults to 1000.
- cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+ cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
  """

  def __init__(
  self,
- stages: List[Dict[str, List[int]]],
- input_shape: Tuple[int, int, int] = (32, 32, 3),
+ stages: list[dict[str, list[int]]],
+ input_shape: tuple[int, int, int] = (32, 32, 3),
  num_classes: int = 1000,
  include_top: bool = True,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  ) -> None:
  _layers = [
  *conv_sequence(
@@ -93,6 +92,15 @@ class TextNet(Sequential):
  super().__init__(_layers)
  self.cfg = cfg

+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+

  def _textnet(
  arch: str,
@@ -117,8 +125,8 @@ def _textnet(
  if pretrained:
  # The number of classes is not the same as the number of classes in the pretrained model =>
  # skip the mismatching layers for fine tuning
- load_pretrained_params(
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
+ model.from_pretrained(
+ default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
  )

  return model
@@ -136,12 +144,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A textnet tiny model
  """
  return _textnet(
@@ -184,12 +190,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A TextNet small model
  """
  return _textnet(
@@ -232,12 +236,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
  >>> out = model(input_tensor)

  Args:
- ----
  pretrained: boolean, True if model is pretrained
  **kwargs: keyword arguments of the TextNet architecture

  Returns:
- -------
  A TextNet base model
  """
  return _textnet(
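
The same typing cleanup runs through every file shown here: Dict, List, Optional and Tuple from typing are replaced by built-in generics and the X | None union syntax (PEP 585 and PEP 604), and Callable now comes from collections.abc. An illustrative before/after signature, not taken verbatim from any one file, which requires Python 3.10+ (or from __future__ import annotations):

    from collections.abc import Callable
    from typing import Any

    # 0.10.0 style:
    #   def build(cfg: Optional[Dict[str, Any]] = None, ignore_keys: Optional[List[str]] = None) -> None: ...
    # 0.12.0 style:
    def build(
        cfg: dict[str, Any] | None = None,
        ignore_keys: list[str] | None = None,
        attn_module: Callable[[int], object] | None = None,
    ) -> None:
        ...
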
@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
- from .tensorflow import *
- elif is_torch_available():
+ if is_torch_available():
  from .pytorch import *
+ elif is_tf_available():
+ from .tensorflow import *