python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the changes between the two published package versions as they appear in the public registry.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +10 -8
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +9 -8
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +5 -6
- doctr/datasets/ic13.py +6 -6
- doctr/datasets/iiit5k.py +10 -6
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +15 -7
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +4 -5
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +7 -6
- doctr/datasets/svt.py +6 -7
- doctr/datasets/synthtext.py +19 -7
- doctr/datasets/utils.py +41 -35
- doctr/datasets/vocabs.py +1107 -49
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +11 -7
- doctr/io/elements.py +96 -82
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +15 -23
- doctr/models/builder.py +30 -48
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +11 -15
- doctr/models/classification/magc_resnet/tensorflow.py +11 -14
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +20 -18
- doctr/models/classification/mobilenet/tensorflow.py +19 -23
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +7 -9
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +47 -34
- doctr/models/classification/resnet/tensorflow.py +45 -35
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +20 -18
- doctr/models/classification/textnet/tensorflow.py +19 -17
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +21 -8
- doctr/models/classification/vgg/tensorflow.py +20 -14
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +18 -15
- doctr/models/classification/vit/tensorflow.py +15 -12
- doctr/models/classification/zoo.py +23 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +10 -21
- doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
- doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +8 -17
- doctr/models/detection/fast/pytorch.py +37 -35
- doctr/models/detection/fast/tensorflow.py +24 -28
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +8 -18
- doctr/models/detection/linknet/pytorch.py +34 -28
- doctr/models/detection/linknet/tensorflow.py +24 -25
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +6 -10
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +19 -20
- doctr/models/kie_predictor/tensorflow.py +14 -15
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +55 -10
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +13 -14
- doctr/models/predictor/tensorflow.py +9 -10
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +30 -29
- doctr/models/recognition/crnn/tensorflow.py +21 -24
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +32 -25
- doctr/models/recognition/master/tensorflow.py +22 -25
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +47 -29
- doctr/models/recognition/parseq/tensorflow.py +29 -27
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +111 -52
- doctr/models/recognition/predictor/pytorch.py +9 -9
- doctr/models/recognition/predictor/tensorflow.py +8 -9
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +30 -22
- doctr/models/recognition/sar/tensorflow.py +22 -24
- doctr/models/recognition/utils.py +57 -53
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +28 -21
- doctr/models/recognition/vitstr/tensorflow.py +22 -23
- doctr/models/recognition/zoo.py +27 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +41 -34
- doctr/models/utils/tensorflow.py +31 -23
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +9 -13
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +17 -48
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
- python_doctr-0.12.0.dist-info/RECORD +180 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/resnet/pytorch.py

@@ -1,11 +1,13 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import types
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 from torch import nn
 from torchvision.models.resnet import BasicBlock
@@ -21,7 +23,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "resnet18": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -60,9 +62,9 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
 }
 
 
-def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) ->
+def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> list[nn.Module]:
     """Build a ResNet stage"""
-    _layers:
+    _layers: list[nn.Module] = []
 
     in_chan = in_channels
     s = stride
@@ -84,7 +86,6 @@ class ResNet(nn.Sequential):
     Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
 
     Args:
-    ----
         num_blocks: number of resnet block in each stage
         output_channels: number of channels in each stage
         stage_conv: whether to add a conv_sequence after each stage
@@ -98,19 +99,19 @@ class ResNet(nn.Sequential):
 
     def __init__(
         self,
-        num_blocks:
-        output_channels:
-        stage_stride:
-        stage_conv:
-        stage_pooling:
+        num_blocks: list[int],
+        output_channels: list[int],
+        stage_stride: list[int],
+        stage_conv: list[bool],
+        stage_pooling: list[tuple[int, int] | None],
         origin_stem: bool = True,
         stem_channels: int = 64,
-        attn_module:
+        attn_module: Callable[[int], nn.Module] | None = None,
         include_top: bool = True,
         num_classes: int = 1000,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers:
+        _layers: list[nn.Module]
         if origin_stem:
             _layers = [
                 *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2),
@@ -152,16 +153,25 @@ class ResNet(nn.Sequential):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _resnet(
     arch: str,
     pretrained: bool,
-    num_blocks:
-    output_channels:
-    stage_stride:
-    stage_conv:
-    stage_pooling:
-    ignore_keys:
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_stride: list[int],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -179,7 +189,7 @@ def _resnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -188,7 +198,7 @@ def _tv_resnet(
     arch: str,
     pretrained: bool,
     arch_fn,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> TVResNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -201,12 +211,25 @@ def _tv_resnet(
 
     # Build the model
     model = arch_fn(**kwargs, weights=None)
-
+
+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    # Bind method to the instance
+    model.from_pretrained = types.MethodType(from_pretrained, model)
+
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     model.cfg = _cfg
 
@@ -224,12 +247,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet18 model
     """
     return _tv_resnet(
@@ -253,12 +274,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet31 model
     """
     return _resnet(
@@ -287,12 +306,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34 model
     """
     return _tv_resnet(
@@ -315,12 +332,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet34_wide model
     """
     return _resnet(
@@ -349,12 +364,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A resnet50 model
     """
     return _tv_resnet(

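The functional change in this file is that pretrained weights are now loaded through a public `from_pretrained` method on the model rather than a bare `load_pretrained_params` call buried inside `_resnet`/`_tv_resnet`. Below is a minimal usage sketch, assuming python-doctr 0.12.0 with the PyTorch backend; the checkpoint path in the comment is a placeholder, not a file shipped by the package.

```python
# Sketch only: assumes python-doctr 0.12.0 with the PyTorch backend installed.
import torch

from doctr.models.classification import resnet31

# With a custom head, `pretrained=True` resolves to
# model.from_pretrained(<official checkpoint URL>, ignore_keys=...) as in the diff above.
model = resnet31(pretrained=True, num_classes=10)

# The same entry point is now public, so a locally fine-tuned checkpoint can be
# reloaded the same way ("my_checkpoint.pt" is a placeholder path):
# model.from_pretrained("my_checkpoint.pt")

with torch.inference_mode():
    out = model(torch.rand((1, 3, 32, 32)))
print(out.shape)  # expected: torch.Size([1, 10])
```
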
doctr/models/classification/resnet/tensorflow.py

@@ -1,10 +1,12 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+import types
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import layers
@@ -18,7 +20,7 @@ from ...utils import _build_model, conv_sequence, load_pretrained_params
 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
 
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "resnet18": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -61,7 +63,6 @@ class ResnetBlock(layers.Layer):
     """Implements a resnet31 block with shortcut
 
     Args:
-    ----
         conv_shortcut: Use of shortcut
         output_channels: number of channels to use in Conv2D
         kernel_size: size of square kernels
@@ -92,7 +93,7 @@ class ResnetBlock(layers.Layer):
         output_channels: int,
         kernel_size: int,
         strides: int = 1,
-    ) ->
+    ) -> list[layers.Layer]:
         return [
             *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
             *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
@@ -108,8 +109,8 @@ class ResnetBlock(layers.Layer):
 
 def resnet_stage(
     num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
-) ->
-    _layers:
+) -> list[layers.Layer]:
+    _layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
 
     for _ in range(1, num_blocks):
         _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
@@ -121,7 +122,6 @@ class ResNet(Sequential):
     """Implements a ResNet architecture
 
     Args:
-    ----
         num_blocks: number of resnet block in each stage
         output_channels: number of channels in each stage
         stage_downsample: whether the first residual block of a stage should downsample
@@ -137,18 +137,18 @@ class ResNet(Sequential):
 
     def __init__(
         self,
-        num_blocks:
-        output_channels:
-        stage_downsample:
-        stage_conv:
-        stage_pooling:
+        num_blocks: list[int],
+        output_channels: list[int],
+        stage_downsample: list[bool],
+        stage_conv: list[bool],
+        stage_pooling: list[tuple[int, int] | None],
         origin_stem: bool = True,
         stem_channels: int = 64,
-        attn_module:
+        attn_module: Callable[[int], layers.Layer] | None = None,
         include_top: bool = True,
         num_classes: int = 1000,
-        cfg:
-        input_shape:
+        cfg: dict[str, Any] | None = None,
+        input_shape: tuple[int, int, int] | None = None,
     ) -> None:
         inplanes = stem_channels
         if origin_stem:
@@ -184,15 +184,24 @@ class ResNet(Sequential):
         super().__init__(_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _resnet(
     arch: str,
     pretrained: bool,
-    num_blocks:
-    output_channels:
-    stage_downsample:
-    stage_conv:
-    stage_pooling:
+    num_blocks: list[int],
+    output_channels: list[int],
+    stage_downsample: list[bool],
+    stage_conv: list[bool],
+    stage_pooling: list[tuple[int, int] | None],
     origin_stem: bool = True,
     **kwargs: Any,
 ) -> ResNet:
@@ -216,8 +225,8 @@ def _resnet(
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # skip the mismatching layers for fine tuning
-
-
+        model.from_pretrained(
+            default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
         )
 
     return model
@@ -234,12 +243,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -267,12 +274,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -300,12 +305,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(
@@ -332,12 +335,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
@@ -359,6 +360,18 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
         classifier_activation=None,
     )
 
+    # monkeypatch the model to allow for loading pretrained parameters
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:  # noqa: D417
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
+    model.from_pretrained = types.MethodType(from_pretrained, model)  # Bind method to the instance
+
     model.cfg = _cfg
     _build_model(model)
 
@@ -366,8 +379,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # skip the mismatching layers for fine tuning
-
-            model,
+        model.from_pretrained(
             default_cfgs["resnet50"]["url"],
             skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
         )
@@ -386,12 +398,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the ResNet architecture
 
     Returns:
-    -------
         A classification model
     """
     return _resnet(

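Both backends attach `from_pretrained` to models whose class doctr does not own (the torchvision ResNets, the Keras ResNet50) by binding a plain function with `types.MethodType`. The pattern itself is standard Python; here is a self-contained sketch of what that binding does, deliberately using a toy class rather than doctr's own:

```python
import types


class ThirdPartyModel:
    """Stand-in for a model class we cannot subclass or edit."""

    def __init__(self) -> None:
        self.weights = "random"


def from_pretrained(self, path_or_url: str) -> None:
    # In doctr this delegates to load_pretrained_params(self, path_or_url, **kwargs);
    # here we only record where the weights came from.
    self.weights = f"loaded from {path_or_url}"


model = ThirdPartyModel()
# Bind the function to this single instance, exactly like the diff does for TVResNet / Keras ResNet50.
model.from_pretrained = types.MethodType(from_pretrained, model)

model.from_pretrained("https://example.com/checkpoint")
print(model.weights)  # -> "loaded from https://example.com/checkpoint"
```
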
doctr/models/classification/textnet/__init__.py

@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]

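Note the inverted backend priority: in 0.10.0 the TensorFlow implementations were imported first, in 0.12.0 PyTorch wins when both frameworks are installed. A quick way to see which backend a given environment resolves to, using the two helpers this `__init__.py` already imports (sketch only):

```python
from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the new import order: PyTorch first, TensorFlow as fallback.
if is_torch_available():
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    backend = None

print(f"doctr will use the {backend} implementations")
```
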
doctr/models/classification/textnet/pytorch.py

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 from torch import nn
 
@@ -16,7 +16,7 @@ from ...utils import conv_sequence_pt, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "textnet_tiny": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -47,22 +47,21 @@ class TextNet(nn.Sequential):
     Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
 
     Args:
-
-        stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+        stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
         include_top (bool, optional): Whether to include the classifier head. Defaults to True.
         num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (
+        cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(
         self,
-        stages:
-        input_shape:
+        stages: list[dict[str, list[int]]],
+        input_shape: tuple[int, int, int] = (3, 32, 32),
        num_classes: int = 1000,
         include_top: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
-        _layers:
+        _layers: list[nn.Module] = [
             *conv_sequence_pt(
                 in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
             ),
@@ -94,11 +93,20 @@ class TextNet(nn.Sequential):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _textnet(
     arch: str,
     pretrained: bool,
-    ignore_keys:
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> TextNet:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -116,7 +124,7 @@ def _textnet(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     model.cfg = _cfg
 
@@ -135,12 +143,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A textnet tiny model
     """
     return _textnet(
@@ -184,12 +190,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet small model
     """
     return _textnet(
@@ -233,12 +237,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet base model
     """
     return _textnet(

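Beyond the `from_pretrained` refactor, most of the churn in these files is the move from `typing.List`/`Dict`/`Optional` annotations to built-in generics and `|` unions (PEP 585/604), which evaluate natively on Python 3.10+ (or behind `from __future__ import annotations`). A small sketch of the before/after style, with a toy function (not doctr code) that mirrors the TextNet constructor signature:

```python
from typing import Any


# 0.10.0 style (extra typing imports):
#   from typing import Dict, List, Optional
#   def build(stages: List[Dict[str, List[int]]], cfg: Optional[Dict[str, Any]] = None) -> None: ...

# 0.12.0 style: built-in generics (PEP 585) and `|` unions (PEP 604), no typing.List/Dict/Optional.
def build(stages: list[dict[str, list[int]]], cfg: dict[str, Any] | None = None) -> None:
    """Toy function mirroring the TextNet constructor signature; not doctr code."""
    _ = (stages, cfg)


build([{"kernel_size": [3, 3], "stride": [1, 2]}])
```
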
doctr/models/classification/textnet/tensorflow.py

@@ -1,11 +1,11 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
 from copy import deepcopy
-from typing import Any
+from typing import Any
 
 from tensorflow.keras import Sequential, layers
 
@@ -16,7 +16,7 @@ from ...utils import _build_model, conv_sequence, load_pretrained_params
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
-default_cfgs:
+default_cfgs: dict[str, dict[str, Any]] = {
     "textnet_tiny": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -47,20 +47,19 @@ class TextNet(Sequential):
     Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
 
     Args:
-
-        stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+        stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
         include_top (bool, optional): Whether to include the classifier head. Defaults to True.
         num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (
+        cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(
         self,
-        stages:
-        input_shape:
+        stages: list[dict[str, list[int]]],
+        input_shape: tuple[int, int, int] = (32, 32, 3),
         num_classes: int = 1000,
         include_top: bool = True,
-        cfg:
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         _layers = [
             *conv_sequence(
@@ -93,6 +92,15 @@ class TextNet(Sequential):
         super().__init__(_layers)
         self.cfg = cfg
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
 
 def _textnet(
     arch: str,
@@ -117,8 +125,8 @@ def _textnet(
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # skip the mismatching layers for fine tuning
-
-
+        model.from_pretrained(
+            default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
         )
 
     return model
@@ -136,12 +144,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A textnet tiny model
     """
     return _textnet(
@@ -184,12 +190,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet small model
     """
     return _textnet(
@@ -232,12 +236,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained: boolean, True if model is pretrained
         **kwargs: keyword arguments of the TextNet architecture
 
     Returns:
-    -------
         A TextNet base model
     """
     return _textnet(
