python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,397 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from collections.abc import Callable
7
- from copy import deepcopy
8
- from typing import Any
9
-
10
- import tensorflow as tf
11
- from tensorflow.keras import layers
12
- from tensorflow.keras.applications import ResNet50
13
- from tensorflow.keras.models import Sequential
14
-
15
- from doctr.datasets import VOCABS
16
-
17
- from ...utils import _build_model, conv_sequence, load_pretrained_params
18
-
19
- __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
20
-
21
-
22
- default_cfgs: dict[str, dict[str, Any]] = {
23
- "resnet18": {
24
- "mean": (0.694, 0.695, 0.693),
25
- "std": (0.299, 0.296, 0.301),
26
- "input_shape": (32, 32, 3),
27
- "classes": list(VOCABS["french"]),
28
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
29
- },
30
- "resnet31": {
31
- "mean": (0.694, 0.695, 0.693),
32
- "std": (0.299, 0.296, 0.301),
33
- "input_shape": (32, 32, 3),
34
- "classes": list(VOCABS["french"]),
35
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
36
- },
37
- "resnet34": {
38
- "mean": (0.694, 0.695, 0.693),
39
- "std": (0.299, 0.296, 0.301),
40
- "input_shape": (32, 32, 3),
41
- "classes": list(VOCABS["french"]),
42
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
43
- },
44
- "resnet50": {
45
- "mean": (0.694, 0.695, 0.693),
46
- "std": (0.299, 0.296, 0.301),
47
- "input_shape": (32, 32, 3),
48
- "classes": list(VOCABS["french"]),
49
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
50
- },
51
- "resnet34_wide": {
52
- "mean": (0.694, 0.695, 0.693),
53
- "std": (0.299, 0.296, 0.301),
54
- "input_shape": (32, 32, 3),
55
- "classes": list(VOCABS["french"]),
56
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
57
- },
58
- }
59
-
60
-
61
- class ResnetBlock(layers.Layer):
62
- """Implements a resnet31 block with shortcut
63
-
64
- Args:
65
- conv_shortcut: Use of shortcut
66
- output_channels: number of channels to use in Conv2D
67
- kernel_size: size of square kernels
68
- strides: strides to use in the first convolution of the block
69
- """
70
-
71
- def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None:
72
- super().__init__(**kwargs)
73
- if conv_shortcut:
74
- self.shortcut = Sequential([
75
- layers.Conv2D(
76
- filters=output_channels,
77
- strides=strides,
78
- padding="same",
79
- kernel_size=1,
80
- use_bias=False,
81
- kernel_initializer="he_normal",
82
- ),
83
- layers.BatchNormalization(),
84
- ])
85
- else:
86
- self.shortcut = layers.Lambda(lambda x: x)
87
- self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides))
88
- self.act = layers.Activation("relu")
89
-
90
- @staticmethod
91
- def conv_resnetblock(
92
- output_channels: int,
93
- kernel_size: int,
94
- strides: int = 1,
95
- ) -> list[layers.Layer]:
96
- return [
97
- *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
98
- *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
99
- ]
100
-
101
- def call(self, inputs: tf.Tensor) -> tf.Tensor:
102
- clone = self.shortcut(inputs)
103
- conv_out = self.conv_block(inputs)
104
- out = self.act(clone + conv_out)
105
-
106
- return out
107
-
108
-
109
- def resnet_stage(
110
- num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
111
- ) -> list[layers.Layer]:
112
- _layers: list[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
113
-
114
- for _ in range(1, num_blocks):
115
- _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
116
-
117
- return _layers
118
-
119
-
120
- class ResNet(Sequential):
121
- """Implements a ResNet architecture
122
-
123
- Args:
124
- num_blocks: number of resnet block in each stage
125
- output_channels: number of channels in each stage
126
- stage_downsample: whether the first residual block of a stage should downsample
127
- stage_conv: whether to add a conv_sequence after each stage
128
- stage_pooling: pooling to add after each stage (if None, no pooling)
129
- origin_stem: whether to use the orginal ResNet stem or ResNet-31's
130
- stem_channels: number of output channels of the stem convolutions
131
- attn_module: attention module to use in each stage
132
- include_top: whether the classifier head should be instantiated
133
- num_classes: number of output classes
134
- input_shape: shape of inputs
135
- """
136
-
137
- def __init__(
138
- self,
139
- num_blocks: list[int],
140
- output_channels: list[int],
141
- stage_downsample: list[bool],
142
- stage_conv: list[bool],
143
- stage_pooling: list[tuple[int, int] | None],
144
- origin_stem: bool = True,
145
- stem_channels: int = 64,
146
- attn_module: Callable[[int], layers.Layer] | None = None,
147
- include_top: bool = True,
148
- num_classes: int = 1000,
149
- cfg: dict[str, Any] | None = None,
150
- input_shape: tuple[int, int, int] | None = None,
151
- ) -> None:
152
- inplanes = stem_channels
153
- if origin_stem:
154
- _layers = [
155
- *conv_sequence(inplanes, "relu", True, kernel_size=7, strides=2, input_shape=input_shape),
156
- layers.MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
157
- ]
158
- else:
159
- _layers = [
160
- *conv_sequence(inplanes // 2, "relu", True, kernel_size=3, input_shape=input_shape),
161
- *conv_sequence(inplanes, "relu", True, kernel_size=3),
162
- layers.MaxPool2D(pool_size=2, strides=2, padding="valid"),
163
- ]
164
-
165
- for n_blocks, out_chan, down, conv, pool in zip(
166
- num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling
167
- ):
168
- _layers.extend(resnet_stage(n_blocks, out_chan, out_chan != inplanes, down))
169
- if attn_module is not None:
170
- _layers.append(attn_module(out_chan))
171
- if conv:
172
- _layers.extend(conv_sequence(out_chan, activation="relu", bn=True, kernel_size=3))
173
- if pool:
174
- _layers.append(layers.MaxPool2D(pool_size=pool, strides=pool, padding="valid"))
175
- inplanes = out_chan
176
-
177
- if include_top:
178
- _layers.extend([
179
- layers.GlobalAveragePooling2D(),
180
- layers.Dense(num_classes),
181
- ])
182
-
183
- super().__init__(_layers)
184
- self.cfg = cfg
185
-
186
-
187
- def _resnet(
188
- arch: str,
189
- pretrained: bool,
190
- num_blocks: list[int],
191
- output_channels: list[int],
192
- stage_downsample: list[bool],
193
- stage_conv: list[bool],
194
- stage_pooling: list[tuple[int, int] | None],
195
- origin_stem: bool = True,
196
- **kwargs: Any,
197
- ) -> ResNet:
198
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
199
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
200
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
201
-
202
- _cfg = deepcopy(default_cfgs[arch])
203
- _cfg["num_classes"] = kwargs["num_classes"]
204
- _cfg["classes"] = kwargs["classes"]
205
- _cfg["input_shape"] = kwargs["input_shape"]
206
- kwargs.pop("classes")
207
-
208
- # Build the model
209
- model = ResNet(
210
- num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
211
- )
212
- _build_model(model)
213
-
214
- # Load pretrained parameters
215
- if pretrained:
216
- # The number of classes is not the same as the number of classes in the pretrained model =>
217
- # skip the mismatching layers for fine tuning
218
- load_pretrained_params(
219
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
220
- )
221
-
222
- return model
223
-
224
-
225
- def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
226
- """Resnet-18 architecture as described in `"Deep Residual Learning for Image Recognition",
227
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
228
-
229
- >>> import tensorflow as tf
230
- >>> from doctr.models import resnet18
231
- >>> model = resnet18(pretrained=False)
232
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
233
- >>> out = model(input_tensor)
234
-
235
- Args:
236
- pretrained: boolean, True if model is pretrained
237
- **kwargs: keyword arguments of the ResNet architecture
238
-
239
- Returns:
240
- A classification model
241
- """
242
- return _resnet(
243
- "resnet18",
244
- pretrained,
245
- [2, 2, 2, 2],
246
- [64, 128, 256, 512],
247
- [False, True, True, True],
248
- [False] * 4,
249
- [None] * 4,
250
- True,
251
- **kwargs,
252
- )
253
-
254
-
255
- def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
256
- """Resnet31 architecture with rectangular pooling windows as described in
257
- `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition",
258
- <https://arxiv.org/pdf/1811.00751.pdf>`_. Downsizing: (H, W) --> (H/8, W/4)
259
-
260
- >>> import tensorflow as tf
261
- >>> from doctr.models import resnet31
262
- >>> model = resnet31(pretrained=False)
263
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
264
- >>> out = model(input_tensor)
265
-
266
- Args:
267
- pretrained: boolean, True if model is pretrained
268
- **kwargs: keyword arguments of the ResNet architecture
269
-
270
- Returns:
271
- A classification model
272
- """
273
- return _resnet(
274
- "resnet31",
275
- pretrained,
276
- [1, 2, 5, 3],
277
- [256, 256, 512, 512],
278
- [False] * 4,
279
- [True] * 4,
280
- [(2, 2), (2, 1), None, None],
281
- False,
282
- stem_channels=128,
283
- **kwargs,
284
- )
285
-
286
-
287
- def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
288
- """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
289
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
290
-
291
- >>> import tensorflow as tf
292
- >>> from doctr.models import resnet34
293
- >>> model = resnet34(pretrained=False)
294
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
295
- >>> out = model(input_tensor)
296
-
297
- Args:
298
- pretrained: boolean, True if model is pretrained
299
- **kwargs: keyword arguments of the ResNet architecture
300
-
301
- Returns:
302
- A classification model
303
- """
304
- return _resnet(
305
- "resnet34",
306
- pretrained,
307
- [3, 4, 6, 3],
308
- [64, 128, 256, 512],
309
- [False, True, True, True],
310
- [False] * 4,
311
- [None] * 4,
312
- True,
313
- **kwargs,
314
- )
315
-
316
-
317
- def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
318
- """Resnet-50 architecture as described in `"Deep Residual Learning for Image Recognition",
319
- <https://arxiv.org/pdf/1512.03385.pdf>`_.
320
-
321
- >>> import tensorflow as tf
322
- >>> from doctr.models import resnet50
323
- >>> model = resnet50(pretrained=False)
324
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
325
- >>> out = model(input_tensor)
326
-
327
- Args:
328
- pretrained: boolean, True if model is pretrained
329
- **kwargs: keyword arguments of the ResNet architecture
330
-
331
- Returns:
332
- A classification model
333
- """
334
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
335
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"])
336
- kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"])
337
-
338
- _cfg = deepcopy(default_cfgs["resnet50"])
339
- _cfg["num_classes"] = kwargs["num_classes"]
340
- _cfg["classes"] = kwargs["classes"]
341
- _cfg["input_shape"] = kwargs["input_shape"]
342
- kwargs.pop("classes")
343
-
344
- model = ResNet50(
345
- weights=None,
346
- include_top=True,
347
- pooling=True,
348
- input_shape=kwargs["input_shape"],
349
- classes=kwargs["num_classes"],
350
- classifier_activation=None,
351
- )
352
-
353
- model.cfg = _cfg
354
- _build_model(model)
355
-
356
- # Load pretrained parameters
357
- if pretrained:
358
- # The number of classes is not the same as the number of classes in the pretrained model =>
359
- # skip the mismatching layers for fine tuning
360
- load_pretrained_params(
361
- model,
362
- default_cfgs["resnet50"]["url"],
363
- skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
364
- )
365
-
366
- return model
367
-
368
-
369
- def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
370
- """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
371
- <https://arxiv.org/pdf/1512.03385.pdf>`_ with twice as many output channels for each stage.
372
-
373
- >>> import tensorflow as tf
374
- >>> from doctr.models import resnet34_wide
375
- >>> model = resnet34_wide(pretrained=False)
376
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
377
- >>> out = model(input_tensor)
378
-
379
- Args:
380
- pretrained: boolean, True if model is pretrained
381
- **kwargs: keyword arguments of the ResNet architecture
382
-
383
- Returns:
384
- A classification model
385
- """
386
- return _resnet(
387
- "resnet34_wide",
388
- pretrained,
389
- [3, 4, 6, 3],
390
- [128, 256, 512, 1024],
391
- [False, True, True, True],
392
- [False] * 4,
393
- [None] * 4,
394
- True,
395
- stem_channels=128,
396
- **kwargs,
397
- )
@@ -1,266 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
-
7
- from copy import deepcopy
8
- from typing import Any
9
-
10
- from tensorflow.keras import Sequential, layers
11
-
12
- from doctr.datasets import VOCABS
13
-
14
- from ...modules.layers.tensorflow import FASTConvLayer
15
- from ...utils import _build_model, conv_sequence, load_pretrained_params
16
-
17
- __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
18
-
19
- default_cfgs: dict[str, dict[str, Any]] = {
20
- "textnet_tiny": {
21
- "mean": (0.694, 0.695, 0.693),
22
- "std": (0.299, 0.296, 0.301),
23
- "input_shape": (32, 32, 3),
24
- "classes": list(VOCABS["french"]),
25
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
26
- },
27
- "textnet_small": {
28
- "mean": (0.694, 0.695, 0.693),
29
- "std": (0.299, 0.296, 0.301),
30
- "input_shape": (32, 32, 3),
31
- "classes": list(VOCABS["french"]),
32
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
33
- },
34
- "textnet_base": {
35
- "mean": (0.694, 0.695, 0.693),
36
- "std": (0.299, 0.296, 0.301),
37
- "input_shape": (32, 32, 3),
38
- "classes": list(VOCABS["french"]),
39
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
40
- },
41
- }
42
-
43
-
44
- class TextNet(Sequential):
45
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
46
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
47
- Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
48
-
49
- Args:
50
- stages (list[dict[str, list[int]]]): list of dictionaries containing the parameters of each stage.
51
- include_top (bool, optional): Whether to include the classifier head. Defaults to True.
52
- num_classes (int, optional): Number of output classes. Defaults to 1000.
53
- cfg (dict[str, Any], optional): Additional configuration. Defaults to None.
54
- """
55
-
56
- def __init__(
57
- self,
58
- stages: list[dict[str, list[int]]],
59
- input_shape: tuple[int, int, int] = (32, 32, 3),
60
- num_classes: int = 1000,
61
- include_top: bool = True,
62
- cfg: dict[str, Any] | None = None,
63
- ) -> None:
64
- _layers = [
65
- *conv_sequence(
66
- out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
67
- ),
68
- *[
69
- Sequential(
70
- [
71
- FASTConvLayer(**params) # type: ignore[arg-type]
72
- for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
73
- ],
74
- name=f"stage_{i}",
75
- )
76
- for i, stage in enumerate(stages)
77
- ],
78
- ]
79
-
80
- if include_top:
81
- _layers.append(
82
- Sequential(
83
- [
84
- layers.AveragePooling2D(1),
85
- layers.Flatten(),
86
- layers.Dense(num_classes),
87
- ],
88
- name="classifier",
89
- )
90
- )
91
-
92
- super().__init__(_layers)
93
- self.cfg = cfg
94
-
95
-
96
- def _textnet(
97
- arch: str,
98
- pretrained: bool,
99
- **kwargs: Any,
100
- ) -> TextNet:
101
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
102
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
103
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
104
-
105
- _cfg = deepcopy(default_cfgs[arch])
106
- _cfg["num_classes"] = kwargs["num_classes"]
107
- _cfg["input_shape"] = kwargs["input_shape"]
108
- _cfg["classes"] = kwargs["classes"]
109
- kwargs.pop("classes")
110
-
111
- # Build the model
112
- model = TextNet(cfg=_cfg, **kwargs)
113
- _build_model(model)
114
-
115
- # Load pretrained parameters
116
- if pretrained:
117
- # The number of classes is not the same as the number of classes in the pretrained model =>
118
- # skip the mismatching layers for fine tuning
119
- load_pretrained_params(
120
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
121
- )
122
-
123
- return model
124
-
125
-
126
- def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
127
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
128
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
129
- Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
130
-
131
- >>> import tensorflow as tf
132
- >>> from doctr.models import textnet_tiny
133
- >>> model = textnet_tiny(pretrained=False)
134
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
135
- >>> out = model(input_tensor)
136
-
137
- Args:
138
- pretrained: boolean, True if model is pretrained
139
- **kwargs: keyword arguments of the TextNet architecture
140
-
141
- Returns:
142
- A textnet tiny model
143
- """
144
- return _textnet(
145
- "textnet_tiny",
146
- pretrained,
147
- stages=[
148
- {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
149
- {
150
- "in_channels": [64, 128, 128, 128],
151
- "out_channels": [128] * 4,
152
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
153
- "stride": [2, 1, 1, 1],
154
- },
155
- {
156
- "in_channels": [128, 256, 256, 256],
157
- "out_channels": [256] * 4,
158
- "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
159
- "stride": [2, 1, 1, 1],
160
- },
161
- {
162
- "in_channels": [256, 512, 512, 512],
163
- "out_channels": [512] * 4,
164
- "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
165
- "stride": [2, 1, 1, 1],
166
- },
167
- ],
168
- **kwargs,
169
- )
170
-
171
-
172
- def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
173
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
174
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
175
- Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
176
-
177
- >>> import tensorflow as tf
178
- >>> from doctr.models import textnet_small
179
- >>> model = textnet_small(pretrained=False)
180
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
181
- >>> out = model(input_tensor)
182
-
183
- Args:
184
- pretrained: boolean, True if model is pretrained
185
- **kwargs: keyword arguments of the TextNet architecture
186
-
187
- Returns:
188
- A TextNet small model
189
- """
190
- return _textnet(
191
- "textnet_small",
192
- pretrained,
193
- stages=[
194
- {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
195
- {
196
- "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
197
- "out_channels": [128] * 8,
198
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
199
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
200
- },
201
- {
202
- "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
203
- "out_channels": [256] * 8,
204
- "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
205
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
206
- },
207
- {
208
- "in_channels": [256, 512, 512, 512, 512],
209
- "out_channels": [512] * 5,
210
- "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
211
- "stride": [2, 1, 1, 1, 1],
212
- },
213
- ],
214
- **kwargs,
215
- )
216
-
217
-
218
- def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
219
- """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
220
- Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
221
- Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
222
-
223
- >>> import tensorflow as tf
224
- >>> from doctr.models import textnet_base
225
- >>> model = textnet_base(pretrained=False)
226
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
227
- >>> out = model(input_tensor)
228
-
229
- Args:
230
- pretrained: boolean, True if model is pretrained
231
- **kwargs: keyword arguments of the TextNet architecture
232
-
233
- Returns:
234
- A TextNet base model
235
- """
236
- return _textnet(
237
- "textnet_base",
238
- pretrained,
239
- stages=[
240
- {
241
- "in_channels": [64] * 10,
242
- "out_channels": [64] * 10,
243
- "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
244
- "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
245
- },
246
- {
247
- "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
248
- "out_channels": [128] * 10,
249
- "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
250
- "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
251
- },
252
- {
253
- "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
254
- "out_channels": [256] * 8,
255
- "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
256
- "stride": [2, 1, 1, 1, 1, 1, 1, 1],
257
- },
258
- {
259
- "in_channels": [256, 512, 512, 512, 512],
260
- "out_channels": [512] * 5,
261
- "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
262
- "stride": [2, 1, 1, 1, 1],
263
- },
264
- ],
265
- **kwargs,
266
- )