python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +0 -5
  3. doctr/datasets/datasets/__init__.py +1 -6
  4. doctr/datasets/datasets/pytorch.py +2 -2
  5. doctr/datasets/generator/__init__.py +1 -6
  6. doctr/datasets/vocabs.py +0 -2
  7. doctr/file_utils.py +2 -101
  8. doctr/io/image/__init__.py +1 -7
  9. doctr/io/image/pytorch.py +1 -1
  10. doctr/models/_utils.py +3 -3
  11. doctr/models/classification/magc_resnet/__init__.py +1 -6
  12. doctr/models/classification/magc_resnet/pytorch.py +2 -2
  13. doctr/models/classification/mobilenet/__init__.py +1 -6
  14. doctr/models/classification/predictor/__init__.py +1 -6
  15. doctr/models/classification/predictor/pytorch.py +1 -1
  16. doctr/models/classification/resnet/__init__.py +1 -6
  17. doctr/models/classification/textnet/__init__.py +1 -6
  18. doctr/models/classification/textnet/pytorch.py +1 -1
  19. doctr/models/classification/vgg/__init__.py +1 -6
  20. doctr/models/classification/vip/__init__.py +1 -4
  21. doctr/models/classification/vip/layers/__init__.py +1 -4
  22. doctr/models/classification/vip/layers/pytorch.py +1 -1
  23. doctr/models/classification/vit/__init__.py +1 -6
  24. doctr/models/classification/vit/pytorch.py +2 -2
  25. doctr/models/classification/zoo.py +6 -11
  26. doctr/models/detection/_utils/__init__.py +1 -6
  27. doctr/models/detection/core.py +1 -1
  28. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  29. doctr/models/detection/differentiable_binarization/base.py +4 -12
  30. doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
  31. doctr/models/detection/fast/__init__.py +1 -6
  32. doctr/models/detection/fast/base.py +4 -14
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/linknet/__init__.py +1 -6
  35. doctr/models/detection/linknet/base.py +3 -12
  36. doctr/models/detection/linknet/pytorch.py +2 -2
  37. doctr/models/detection/predictor/__init__.py +1 -6
  38. doctr/models/detection/predictor/pytorch.py +1 -1
  39. doctr/models/detection/zoo.py +15 -32
  40. doctr/models/factory/hub.py +8 -21
  41. doctr/models/kie_predictor/__init__.py +1 -6
  42. doctr/models/kie_predictor/pytorch.py +2 -6
  43. doctr/models/modules/layers/__init__.py +1 -6
  44. doctr/models/modules/layers/pytorch.py +3 -3
  45. doctr/models/modules/transformer/__init__.py +1 -6
  46. doctr/models/modules/transformer/pytorch.py +2 -2
  47. doctr/models/modules/vision_transformer/__init__.py +1 -6
  48. doctr/models/predictor/__init__.py +1 -6
  49. doctr/models/predictor/base.py +3 -8
  50. doctr/models/predictor/pytorch.py +2 -5
  51. doctr/models/preprocessor/__init__.py +1 -6
  52. doctr/models/preprocessor/pytorch.py +27 -32
  53. doctr/models/recognition/crnn/__init__.py +1 -6
  54. doctr/models/recognition/crnn/pytorch.py +6 -6
  55. doctr/models/recognition/master/__init__.py +1 -6
  56. doctr/models/recognition/master/pytorch.py +5 -5
  57. doctr/models/recognition/parseq/__init__.py +1 -6
  58. doctr/models/recognition/parseq/pytorch.py +5 -5
  59. doctr/models/recognition/predictor/__init__.py +1 -6
  60. doctr/models/recognition/predictor/_utils.py +7 -16
  61. doctr/models/recognition/predictor/pytorch.py +1 -2
  62. doctr/models/recognition/sar/__init__.py +1 -6
  63. doctr/models/recognition/sar/pytorch.py +3 -3
  64. doctr/models/recognition/viptr/__init__.py +1 -4
  65. doctr/models/recognition/viptr/pytorch.py +3 -3
  66. doctr/models/recognition/vitstr/__init__.py +1 -6
  67. doctr/models/recognition/vitstr/pytorch.py +3 -3
  68. doctr/models/recognition/zoo.py +13 -13
  69. doctr/models/utils/__init__.py +1 -6
  70. doctr/models/utils/pytorch.py +1 -1
  71. doctr/transforms/functional/__init__.py +1 -6
  72. doctr/transforms/functional/pytorch.py +4 -4
  73. doctr/transforms/modules/__init__.py +1 -7
  74. doctr/transforms/modules/base.py +26 -92
  75. doctr/transforms/modules/pytorch.py +28 -26
  76. doctr/utils/geometry.py +6 -10
  77. doctr/utils/visualization.py +1 -1
  78. doctr/version.py +1 -1
  79. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
  80. python_doctr-1.0.0.dist-info/RECORD +149 -0
  81. doctr/datasets/datasets/tensorflow.py +0 -59
  82. doctr/datasets/generator/tensorflow.py +0 -58
  83. doctr/datasets/loader.py +0 -94
  84. doctr/io/image/tensorflow.py +0 -101
  85. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  86. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  87. doctr/models/classification/predictor/tensorflow.py +0 -60
  88. doctr/models/classification/resnet/tensorflow.py +0 -418
  89. doctr/models/classification/textnet/tensorflow.py +0 -275
  90. doctr/models/classification/vgg/tensorflow.py +0 -125
  91. doctr/models/classification/vit/tensorflow.py +0 -201
  92. doctr/models/detection/_utils/tensorflow.py +0 -34
  93. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  94. doctr/models/detection/fast/tensorflow.py +0 -427
  95. doctr/models/detection/linknet/tensorflow.py +0 -377
  96. doctr/models/detection/predictor/tensorflow.py +0 -70
  97. doctr/models/kie_predictor/tensorflow.py +0 -187
  98. doctr/models/modules/layers/tensorflow.py +0 -171
  99. doctr/models/modules/transformer/tensorflow.py +0 -235
  100. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  101. doctr/models/predictor/tensorflow.py +0 -155
  102. doctr/models/preprocessor/tensorflow.py +0 -122
  103. doctr/models/recognition/crnn/tensorflow.py +0 -317
  104. doctr/models/recognition/master/tensorflow.py +0 -320
  105. doctr/models/recognition/parseq/tensorflow.py +0 -516
  106. doctr/models/recognition/predictor/tensorflow.py +0 -79
  107. doctr/models/recognition/sar/tensorflow.py +0 -423
  108. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  109. doctr/models/utils/tensorflow.py +0 -189
  110. doctr/transforms/functional/tensorflow.py +0 -254
  111. doctr/transforms/modules/tensorflow.py +0 -562
  112. python_doctr-0.12.0.dist-info/RECORD +0 -180
  113. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
  114. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
  115. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  116. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,418 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import types
7
- from collections.abc import Callable
8
- from copy import deepcopy
9
- from typing import Any
10
-
11
- import tensorflow as tf
12
- from tensorflow.keras import layers
13
- from tensorflow.keras.applications import ResNet50
14
- from tensorflow.keras.models import Sequential
15
-
16
- from doctr.datasets import VOCABS
17
-
18
- from ...utils import _build_model, conv_sequence, load_pretrained_params
19
-
20
- __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
21
-
22
-
23
- default_cfgs: dict[str, dict[str, Any]] = {
24
- "resnet18": {
25
- "mean": (0.694, 0.695, 0.693),
26
- "std": (0.299, 0.296, 0.301),
27
- "input_shape": (32, 32, 3),
28
- "classes": list(VOCABS["french"]),
29
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
30
- },
31
- "resnet31": {
32
- "mean": (0.694, 0.695, 0.693),
33
- "std": (0.299, 0.296, 0.301),
34
- "input_shape": (32, 32, 3),
35
- "classes": list(VOCABS["french"]),
36
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
37
- },
38
- "resnet34": {
39
- "mean": (0.694, 0.695, 0.693),
40
- "std": (0.299, 0.296, 0.301),
41
- "input_shape": (32, 32, 3),
42
- "classes": list(VOCABS["french"]),
43
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
44
- },
45
- "resnet50": {
46
- "mean": (0.694, 0.695, 0.693),
47
- "std": (0.299, 0.296, 0.301),
48
- "input_shape": (32, 32, 3),
49
- "classes": list(VOCABS["french"]),
50
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
51
- },
52
- "resnet34_wide": {
53
- "mean": (0.694, 0.695, 0.693),
54
- "std": (0.299, 0.296, 0.301),
55
- "input_shape": (32, 32, 3),
56
- "classes": list(VOCABS["french"]),
57
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
58
- },
59
- }
60
-
61
-
62
- class ResnetBlock(layers.Layer):
63
- """Implements a resnet31 block with shortcut
64
-
65
- Args:
66
- conv_shortcut: Use of shortcut
67
- output_channels: number of channels to use in Conv2D
68
- kernel_size: size of square kernels
69
- strides: strides to use in the first convolution of the block
70
- """
71
-
72
- def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None:
73
- super().__init__(**kwargs)
74
- if conv_shortcut:
75
- self.shortcut = Sequential([
76
- layers.Conv2D(
77
- filters=output_channels,
78
- strides=strides,
79
- padding="same",
80
- kernel_size=1,
81
- use_bias=False,
82
- kernel_initializer="he_normal",
83
- ),
84
- layers.BatchNormalization(),
85
- ])
86
- else:
87
- self.shortcut = layers.Lambda(lambda x: x)
88
- self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides))
89
- self.act = layers.Activation("relu")
90
-
91
- @staticmethod
92
- def conv_resnetblock(
93
- output_channels: int,
94
- kernel_size: int,
95
- strides: int = 1,
96
- ) -> list[layers.Layer]:
97
- return [
98
- *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
99
- *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
100
- ]
101
-
102
- def call(self, inputs: tf.Tensor) -> tf.Tensor:
103
- clone = self.shortcut(inputs)
104
- conv_out = self.conv_block(inputs)
105
- out = self.act(clone + conv_out)
106
-
107
- return out
108
-
109
-
110
- def resnet_stage(
111
- num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
112
- ) -> list[layers.Layer]:
113
- _layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
114
-
115
- for _ in range(1, num_blocks):
116
- _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
117
-
118
- return _layers
119
-
120
-
121
- class ResNet(Sequential):
122
- """Implements a ResNet architecture
123
-
124
- Args:
125
- num_blocks: number of resnet block in each stage
126
- output_channels: number of channels in each stage
127
- stage_downsample: whether the first residual block of a stage should downsample
128
- stage_conv: whether to add a conv_sequence after each stage
129
- stage_pooling: pooling to add after each stage (if None, no pooling)
130
- origin_stem: whether to use the original ResNet stem or ResNet-31's
131
- stem_channels: number of output channels of the stem convolutions
132
- attn_module: attention module to use in each stage
133
- include_top: whether the classifier head should be instantiated
134
- num_classes: number of output classes
135
- input_shape: shape of inputs
136
- """
137
-
138
- def __init__(
139
- self,
140
- num_blocks: list[int],
141
- output_channels: list[int],
142
- stage_downsample: list[bool],
143
- stage_conv: list[bool],
144
- stage_pooling: list[tuple[int, int] | None],
145
- origin_stem: bool = True,
146
- stem_channels: int = 64,
147
- attn_module: Callable[[int], layers.Layer] | None = None,
148
- include_top: bool = True,
149
- num_classes: int = 1000,
150
- cfg: dict[str, Any] | None = None,
151
- input_shape: tuple[int, int, int] | None = None,
152
- ) -> None:
153
- inplanes = stem_channels
154
- if origin_stem:
155
- _layers = [
156
- *conv_sequence(inplanes, "relu", True, kernel_size=7, strides=2, input_shape=input_shape),
157
- layers.MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
158
- ]
159
- else:
160
- _layers = [
161
- *conv_sequence(inplanes // 2, "relu", True, kernel_size=3, input_shape=input_shape),
162
- *conv_sequence(inplanes, "relu", True, kernel_size=3),
163
- layers.MaxPool2D(pool_size=2, strides=2, padding="valid"),
164
- ]
165
-
166
- for n_blocks, out_chan, down, conv, pool in zip(
167
- num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling
168
- ):
169
- _layers.extend(resnet_stage(n_blocks, out_chan, out_chan != inplanes, down))
170
- if attn_module is not None:
171
- _layers.append(attn_module(out_chan))
172
- if conv:
173
- _layers.extend(conv_sequence(out_chan, activation="relu", bn=True, kernel_size=3))
174
- if pool:
175
- _layers.append(layers.MaxPool2D(pool_size=pool, strides=pool, padding="valid"))
176
- inplanes = out_chan
177
-
178
- if include_top:
179
- _layers.extend([
180
- layers.GlobalAveragePooling2D(),
181
- layers.Dense(num_classes),
182
- ])
183
-
184
- super().__init__(_layers)
185
- self.cfg = cfg
186
-
187
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
188
- """Load pretrained parameters onto the model
189
-
190
- Args:
191
- path_or_url: the path or URL to the model parameters (checkpoint)
192
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
193
- """
194
- load_pretrained_params(self, path_or_url, **kwargs)
195
-
196
-
197
- def _resnet(
198
- arch: str,
199
- pretrained: bool,
200
- num_blocks: list[int],
201
- output_channels: list[int],
202
- stage_downsample: list[bool],
203
- stage_conv: list[bool],
204
- stage_pooling: list[tuple[int, int] | None],
205
- origin_stem: bool = True,
206
- **kwargs: Any,
207
- ) -> ResNet:
208
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
209
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
210
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
211
-
212
- _cfg = deepcopy(default_cfgs[arch])
213
- _cfg["num_classes"] = kwargs["num_classes"]
214
- _cfg["classes"] = kwargs["classes"]
215
- _cfg["input_shape"] = kwargs["input_shape"]
216
- kwargs.pop("classes")
217
-
218
- # Build the model
219
- model = ResNet(
220
- num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
221
- )
222
- _build_model(model)
223
-
224
- # Load pretrained parameters
225
- if pretrained:
226
- # The number of classes is not the same as the number of classes in the pretrained model =>
227
- # skip the mismatching layers for fine tuning
228
- model.from_pretrained(
229
- default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
230
- )
231
-
232
- return model
233
-
234
-
235
- def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
236
- """Resnet-18 architecture as described in `"Deep Residual Learning for Image Recognition",
237
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
238
-
239
- >>> import tensorflow as tf
240
- >>> from doctr.models import resnet18
241
- >>> model = resnet18(pretrained=False)
242
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
243
- >>> out = model(input_tensor)
244
-
245
- Args:
246
- pretrained: boolean, True if model is pretrained
247
- **kwargs: keyword arguments of the ResNet architecture
248
-
249
- Returns:
250
- A classification model
251
- """
252
- return _resnet(
253
- "resnet18",
254
- pretrained,
255
- [2, 2, 2, 2],
256
- [64, 128, 256, 512],
257
- [False, True, True, True],
258
- [False] * 4,
259
- [None] * 4,
260
- True,
261
- **kwargs,
262
- )
263
-
264
-
265
- def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
266
- """Resnet31 architecture with rectangular pooling windows as described in
267
- `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition",
268
- <https://arxiv.org/pdf/1811.00751.pdf>`_. Downsizing: (H, W) --> (H/8, W/4)
269
-
270
- >>> import tensorflow as tf
271
- >>> from doctr.models import resnet31
272
- >>> model = resnet31(pretrained=False)
273
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
274
- >>> out = model(input_tensor)
275
-
276
- Args:
277
- pretrained: boolean, True if model is pretrained
278
- **kwargs: keyword arguments of the ResNet architecture
279
-
280
- Returns:
281
- A classification model
282
- """
283
- return _resnet(
284
- "resnet31",
285
- pretrained,
286
- [1, 2, 5, 3],
287
- [256, 256, 512, 512],
288
- [False] * 4,
289
- [True] * 4,
290
- [(2, 2), (2, 1), None, None],
291
- False,
292
- stem_channels=128,
293
- **kwargs,
294
- )
295
-
296
-
297
- def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
298
- """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
299
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
300
-
301
- >>> import tensorflow as tf
302
- >>> from doctr.models import resnet34
303
- >>> model = resnet34(pretrained=False)
304
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
305
- >>> out = model(input_tensor)
306
-
307
- Args:
308
- pretrained: boolean, True if model is pretrained
309
- **kwargs: keyword arguments of the ResNet architecture
310
-
311
- Returns:
312
- A classification model
313
- """
314
- return _resnet(
315
- "resnet34",
316
- pretrained,
317
- [3, 4, 6, 3],
318
- [64, 128, 256, 512],
319
- [False, True, True, True],
320
- [False] * 4,
321
- [None] * 4,
322
- True,
323
- **kwargs,
324
- )
325
-
326
-
327
- def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
328
- """Resnet-50 architecture as described in `"Deep Residual Learning for Image Recognition",
329
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
330
-
331
- >>> import tensorflow as tf
332
- >>> from doctr.models import resnet50
333
- >>> model = resnet50(pretrained=False)
334
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
335
- >>> out = model(input_tensor)
336
-
337
- Args:
338
- pretrained: boolean, True if model is pretrained
339
- **kwargs: keyword arguments of the ResNet architecture
340
-
341
- Returns:
342
- A classification model
343
- """
344
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
345
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"])
346
- kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"])
347
-
348
- _cfg = deepcopy(default_cfgs["resnet50"])
349
- _cfg["num_classes"] = kwargs["num_classes"]
350
- _cfg["classes"] = kwargs["classes"]
351
- _cfg["input_shape"] = kwargs["input_shape"]
352
- kwargs.pop("classes")
353
-
354
- model = ResNet50(
355
- weights=None,
356
- include_top=True,
357
- pooling=True,
358
- input_shape=kwargs["input_shape"],
359
- classes=kwargs["num_classes"],
360
- classifier_activation=None,
361
- )
362
-
363
- # monkeypatch the model to allow for loading pretrained parameters
364
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: # noqa: D417
365
- """Load pretrained parameters onto the model
366
-
367
- Args:
368
- path_or_url: the path or URL to the model parameters (checkpoint)
369
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
370
- """
371
- load_pretrained_params(self, path_or_url, **kwargs)
372
-
373
- model.from_pretrained = types.MethodType(from_pretrained, model) # Bind method to the instance
374
-
375
- model.cfg = _cfg
376
- _build_model(model)
377
-
378
- # Load pretrained parameters
379
- if pretrained:
380
- # The number of classes is not the same as the number of classes in the pretrained model =>
381
- # skip the mismatching layers for fine tuning
382
- model.from_pretrained(
383
- default_cfgs["resnet50"]["url"],
384
- skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
385
- )
386
-
387
- return model
388
-
389
-
390
- def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
391
- """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
392
- <https://arxiv.org/pdf/1512.03385.pdf>`_ with twice as many output channels for each stage.
393
-
394
- >>> import tensorflow as tf
395
- >>> from doctr.models import resnet34_wide
396
- >>> model = resnet34_wide(pretrained=False)
397
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
398
- >>> out = model(input_tensor)
399
-
400
- Args:
401
- pretrained: boolean, True if model is pretrained
402
- **kwargs: keyword arguments of the ResNet architecture
403
-
404
- Returns:
405
- A classification model
406
- """
407
- return _resnet(
408
- "resnet34_wide",
409
- pretrained,
410
- [3, 4, 6, 3],
411
- [128, 256, 512, 1024],
412
- [False, True, True, True],
413
- [False] * 4,
414
- [None] * 4,
415
- True,
416
- stem_channels=128,
417
- **kwargs,
418
- )
@@ -1,275 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
-
7
- from copy import deepcopy
8
- from typing import Any
9
-
10
- from tensorflow.keras import Sequential, layers
11
-
12
- from doctr.datasets import VOCABS
13
-
14
- from ...modules.layers.tensorflow import FASTConvLayer
15
- from ...utils import _build_model, conv_sequence, load_pretrained_params
16
-
17
- __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
18
-
19
- default_cfgs: dict[str, dict[str, Any]] = {
20
- "textnet_tiny": {
21
- "mean": (0.694, 0.695, 0.693),
22
- "std": (0.299, 0.296, 0.301),
23
- "input_shape": (32, 32, 3),
24
- "classes": list(VOCABS["french"]),
25
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
26
- },
27
- "textnet_small": {
28
- "mean": (0.694, 0.695, 0.693),
29
- "std": (0.299, 0.296, 0.301),
30
- "input_shape": (32, 32, 3),
31
- "classes": list(VOCABS["french"]),
32
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
33
- },
34
- "textnet_base": {
35
- "mean": (0.694, 0.695, 0.693),
36
- "std": (0.299, 0.296, 0.301),
37
- "input_shape": (32, 32, 3),
38
- "classes": list(VOCABS["french"]),
39
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
40
- },
41
- }
42
-
43
-
44
- class TextNet(Sequential):
45
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
46
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
47
- Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
48
-
49
- Args:
50
- stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
51
- include_top (bool, optional): Whether to include the classifier head. Defaults to True.
52
- num_classes (int, optional): Number of output classes. Defaults to 1000.
53
- cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
54
- """
55
-
56
- def __init__(
57
- self,
58
- stages: list[dict[str, list[int]]],
59
- input_shape: tuple[int, int, int] = (32, 32, 3),
60
- num_classes: int = 1000,
61
- include_top: bool = True,
62
- cfg: dict[str, Any] | None = None,
63
- ) -> None:
64
- _layers = [
65
- *conv_sequence(
66
- out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
67
- ),
68
- *[
69
- Sequential(
70
- [
71
- FASTConvLayer(**params) # type: ignore[arg-type]
72
- for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
73
- ],
74
- name=f"stage_{i}",
75
- )
76
- for i, stage in enumerate(stages)
77
- ],
78
- ]
79
-
80
- if include_top:
81
- _layers.append(
82
- Sequential(
83
- [
84
- layers.AveragePooling2D(1),
85
- layers.Flatten(),
86
- layers.Dense(num_classes),
87
- ],
88
- name="classifier",
89
- )
90
- )
91
-
92
- super().__init__(_layers)
93
- self.cfg = cfg
94
-
95
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
96
- """Load pretrained parameters onto the model
97
-
98
- Args:
99
- path_or_url: the path or URL to the model parameters (checkpoint)
100
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
101
- """
102
- load_pretrained_params(self, path_or_url, **kwargs)
103
-
104
-
105
- def _textnet(
106
- arch: str,
107
- pretrained: bool,
108
- **kwargs: Any,
109
- ) -> TextNet:
110
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
111
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
112
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
113
-
114
- _cfg = deepcopy(default_cfgs[arch])
115
- _cfg["num_classes"] = kwargs["num_classes"]
116
- _cfg["input_shape"] = kwargs["input_shape"]
117
- _cfg["classes"] = kwargs["classes"]
118
- kwargs.pop("classes")
119
-
120
- # Build the model
121
- model = TextNet(cfg=_cfg, **kwargs)
122
- _build_model(model)
123
-
124
- # Load pretrained parameters
125
- if pretrained:
126
- # The number of classes is not the same as the number of classes in the pretrained model =>
127
- # skip the mismatching layers for fine tuning
128
- model.from_pretrained(
129
- default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
130
- )
131
-
132
- return model
133
-
134
-
135
- def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
136
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
137
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
138
- Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
139
-
140
- >>> import tensorflow as tf
141
- >>> from doctr.models import textnet_tiny
142
- >>> model = textnet_tiny(pretrained=False)
143
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
144
- >>> out = model(input_tensor)
145
-
146
- Args:
147
- pretrained: boolean, True if model is pretrained
148
- **kwargs: keyword arguments of the TextNet architecture
149
-
150
- Returns:
151
- A TextNet tiny model
152
- """
153
- return _textnet(
154
- "textnet_tiny",
155
- pretrained,
156
- stages=[
157
- {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
158
- {
159
- "in_channels": [64, 128, 128, 128],
160
- "out_channels": [128] * 4,
161
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
162
- "stride": [2, 1, 1, 1],
163
- },
164
- {
165
- "in_channels": [128, 256, 256, 256],
166
- "out_channels": [256] * 4,
167
- "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
168
- "stride": [2, 1, 1, 1],
169
- },
170
- {
171
- "in_channels": [256, 512, 512, 512],
172
- "out_channels": [512] * 4,
173
- "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
174
- "stride": [2, 1, 1, 1],
175
- },
176
- ],
177
- **kwargs,
178
- )
179
-
180
-
181
- def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
182
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
183
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
184
- Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
185
-
186
- >>> import tensorflow as tf
187
- >>> from doctr.models import textnet_small
188
- >>> model = textnet_small(pretrained=False)
189
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
190
- >>> out = model(input_tensor)
191
-
192
- Args:
193
- pretrained: boolean, True if model is pretrained
194
- **kwargs: keyword arguments of the TextNet architecture
195
-
196
- Returns:
197
- A TextNet small model
198
- """
199
- return _textnet(
200
- "textnet_small",
201
- pretrained,
202
- stages=[
203
- {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
204
- {
205
- "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
206
- "out_channels": [128] * 8,
207
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
208
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
209
- },
210
- {
211
- "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
212
- "out_channels": [256] * 8,
213
- "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
214
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
215
- },
216
- {
217
- "in_channels": [256, 512, 512, 512, 512],
218
- "out_channels": [512] * 5,
219
- "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
220
- "stride": [2, 1, 1, 1, 1],
221
- },
222
- ],
223
- **kwargs,
224
- )
225
-
226
-
227
- def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
228
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
229
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
230
- Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
231
-
232
- >>> import tensorflow as tf
233
- >>> from doctr.models import textnet_base
234
- >>> model = textnet_base(pretrained=False)
235
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
236
- >>> out = model(input_tensor)
237
-
238
- Args:
239
- pretrained: boolean, True if model is pretrained
240
- **kwargs: keyword arguments of the TextNet architecture
241
-
242
- Returns:
243
- A TextNet base model
244
- """
245
- return _textnet(
246
- "textnet_base",
247
- pretrained,
248
- stages=[
249
- {
250
- "in_channels": [64] * 10,
251
- "out_channels": [64] * 10,
252
- "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
253
- "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
254
- },
255
- {
256
- "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
257
- "out_channels": [128] * 10,
258
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
259
- "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
260
- },
261
- {
262
- "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
263
- "out_channels": [256] * 8,
264
- "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
265
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
266
- },
267
- {
268
- "in_channels": [256, 512, 512, 512, 512],
269
- "out_channels": [512] * 5,
270
- "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
271
- "stride": [2, 1, 1, 1, 1],
272
- },
273
- ],
274
- **kwargs,
275
- )