python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170):
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,125 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from copy import deepcopy
7
- from typing import Any
8
-
9
- from tensorflow.keras import layers
10
- from tensorflow.keras.models import Sequential
11
-
12
- from doctr.datasets import VOCABS
13
-
14
- from ...utils import _build_model, conv_sequence, load_pretrained_params
15
-
16
- __all__ = ["VGG", "vgg16_bn_r"]
17
-
18
-
19
- default_cfgs: dict[str, dict[str, Any]] = {
20
- "vgg16_bn_r": {
21
- "mean": (0.5, 0.5, 0.5),
22
- "std": (1.0, 1.0, 1.0),
23
- "input_shape": (32, 32, 3),
24
- "classes": list(VOCABS["french"]),
25
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
26
- },
27
- }
28
-
29
-
30
- class VGG(Sequential):
31
- """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
32
- <https://arxiv.org/pdf/1409.1556.pdf>`_.
33
-
34
- Args:
35
- num_blocks: number of convolutional block in each stage
36
- planes: number of output channels in each stage
37
- rect_pools: whether pooling square kernels should be replace with rectangular ones
38
- include_top: whether the classifier head should be instantiated
39
- num_classes: number of output classes
40
- input_shape: shapes of the input tensor
41
- """
42
-
43
- def __init__(
44
- self,
45
- num_blocks: list[int],
46
- planes: list[int],
47
- rect_pools: list[bool],
48
- include_top: bool = False,
49
- num_classes: int = 1000,
50
- input_shape: tuple[int, int, int] | None = None,
51
- cfg: dict[str, Any] | None = None,
52
- ) -> None:
53
- _layers = []
54
- # Specify input_shape only for the first layer
55
- kwargs = {"input_shape": input_shape}
56
- for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools):
57
- for _ in range(nb_blocks):
58
- _layers.extend(conv_sequence(out_chan, "relu", True, kernel_size=3, **kwargs)) # type: ignore[arg-type]
59
- kwargs = {}
60
- _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2)))
61
-
62
- if include_top:
63
- _layers.extend([layers.GlobalAveragePooling2D(), layers.Dense(num_classes)])
64
- super().__init__(_layers)
65
- self.cfg = cfg
66
-
67
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
68
- """Load pretrained parameters onto the model
69
-
70
- Args:
71
- path_or_url: the path or URL to the model parameters (checkpoint)
72
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
73
- """
74
- load_pretrained_params(self, path_or_url, **kwargs)
75
-
76
-
77
- def _vgg(
78
- arch: str, pretrained: bool, num_blocks: list[int], planes: list[int], rect_pools: list[bool], **kwargs: Any
79
- ) -> VGG:
80
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
81
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
82
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
83
-
84
- _cfg = deepcopy(default_cfgs[arch])
85
- _cfg["num_classes"] = kwargs["num_classes"]
86
- _cfg["classes"] = kwargs["classes"]
87
- _cfg["input_shape"] = kwargs["input_shape"]
88
- kwargs.pop("classes")
89
-
90
- # Build the model
91
- model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
92
- _build_model(model)
93
-
94
- # Load pretrained parameters
95
- if pretrained:
96
- # The number of classes is not the same as the number of classes in the pretrained model =>
97
- # skip the mismatching layers for fine tuning
98
- model.from_pretrained(
99
- default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
100
- )
101
-
102
- return model
103
-
104
-
105
- def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
106
- """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
107
- <https://arxiv.org/pdf/1409.1556.pdf>`_, modified by adding batch normalization, rectangular pooling and a simpler
108
- classification head.
109
-
110
- >>> import tensorflow as tf
111
- >>> from doctr.models import vgg16_bn_r
112
- >>> model = vgg16_bn_r(pretrained=False)
113
- >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
114
- >>> out = model(input_tensor)
115
-
116
- Args:
117
- pretrained (bool): If True, returns a model pre-trained on ImageNet
118
- **kwargs: keyword arguments of the VGG architecture
119
-
120
- Returns:
121
- VGG feature extractor
122
- """
123
- return _vgg(
124
- "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs
125
- )
@@ -1,201 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- from copy import deepcopy
7
- from typing import Any
8
-
9
- import tensorflow as tf
10
- from tensorflow.keras import Sequential, layers
11
-
12
- from doctr.datasets import VOCABS
13
- from doctr.models.modules.transformer import EncoderBlock
14
- from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
15
- from doctr.utils.repr import NestedObject
16
-
17
- from ...utils import _build_model, load_pretrained_params
18
-
19
- __all__ = ["vit_s", "vit_b"]
20
-
21
-
22
- default_cfgs: dict[str, dict[str, Any]] = {
23
- "vit_s": {
24
- "mean": (0.694, 0.695, 0.693),
25
- "std": (0.299, 0.296, 0.301),
26
- "input_shape": (3, 32, 32),
27
- "classes": list(VOCABS["french"]),
28
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
29
- },
30
- "vit_b": {
31
- "mean": (0.694, 0.695, 0.693),
32
- "std": (0.299, 0.296, 0.301),
33
- "input_shape": (32, 32, 3),
34
- "classes": list(VOCABS["french"]),
35
- "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
36
- },
37
- }
38
-
39
-
40
- class ClassifierHead(layers.Layer, NestedObject):
41
- """Classifier head for Vision Transformer
42
-
43
- Args:
44
- num_classes: number of output classes
45
- """
46
-
47
- def __init__(self, num_classes: int) -> None:
48
- super().__init__()
49
-
50
- self.head = layers.Dense(num_classes, kernel_initializer="he_normal", name="dense")
51
-
52
- def call(self, x: tf.Tensor) -> tf.Tensor:
53
- # (batch_size, num_classes) cls token
54
- return self.head(x[:, 0])
55
-
56
-
57
- class VisionTransformer(Sequential):
58
- """VisionTransformer architecture as described in
59
- `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
60
- <https://arxiv.org/pdf/2010.11929.pdf>`_.
61
-
62
- Args:
63
- d_model: dimension of the transformer layers
64
- num_layers: number of transformer layers
65
- num_heads: number of attention heads
66
- ffd_ratio: multiplier for the hidden dimension of the feedforward layer
67
- patch_size: size of the patches
68
- input_shape: size of the input image
69
- dropout: dropout rate
70
- num_classes: number of output classes
71
- include_top: whether the classifier head should be instantiated
72
- """
73
-
74
- def __init__(
75
- self,
76
- d_model: int,
77
- num_layers: int,
78
- num_heads: int,
79
- ffd_ratio: int,
80
- patch_size: tuple[int, int] = (4, 4),
81
- input_shape: tuple[int, int, int] = (32, 32, 3),
82
- dropout: float = 0.0,
83
- num_classes: int = 1000,
84
- include_top: bool = True,
85
- cfg: dict[str, Any] | None = None,
86
- ) -> None:
87
- _layers = [
88
- PatchEmbedding(input_shape, d_model, patch_size),
89
- EncoderBlock(
90
- num_layers,
91
- num_heads,
92
- d_model,
93
- d_model * ffd_ratio,
94
- dropout,
95
- activation_fct=layers.Activation("gelu"),
96
- ),
97
- ]
98
- if include_top:
99
- _layers.append(ClassifierHead(num_classes))
100
-
101
- super().__init__(_layers)
102
- self.cfg = cfg
103
-
104
- def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
105
- """Load pretrained parameters onto the model
106
-
107
- Args:
108
- path_or_url: the path or URL to the model parameters (checkpoint)
109
- **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
110
- """
111
- load_pretrained_params(self, path_or_url, **kwargs)
112
-
113
-
114
- def _vit(
115
- arch: str,
116
- pretrained: bool,
117
- **kwargs: Any,
118
- ) -> VisionTransformer:
119
- kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
120
- kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
121
- kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
122
-
123
- _cfg = deepcopy(default_cfgs[arch])
124
- _cfg["num_classes"] = kwargs["num_classes"]
125
- _cfg["input_shape"] = kwargs["input_shape"]
126
- _cfg["classes"] = kwargs["classes"]
127
- kwargs.pop("classes")
128
-
129
- # Build the model
130
- model = VisionTransformer(cfg=_cfg, **kwargs)
131
- _build_model(model)
132
-
133
- # Load pretrained parameters
134
- if pretrained:
135
- # The number of classes is not the same as the number of classes in the pretrained model =>
136
- # skip the mismatching layers for fine tuning
137
- load_pretrained_params(
138
- model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
139
- )
140
-
141
- return model
142
-
143
-
144
- def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
145
- """VisionTransformer-S architecture
146
- `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
147
- <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)
148
-
149
- NOTE: unofficial config used in ViTSTR and ParSeq
150
-
151
- >>> import tensorflow as tf
152
- >>> from doctr.models import vit_s
153
- >>> model = vit_s(pretrained=False)
154
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
155
- >>> out = model(input_tensor)
156
-
157
- Args:
158
- pretrained: boolean, True if model is pretrained
159
- **kwargs: keyword arguments of the VisionTransformer architecture
160
-
161
- Returns:
162
- A feature extractor model
163
- """
164
- return _vit(
165
- "vit_s",
166
- pretrained,
167
- d_model=384,
168
- num_layers=12,
169
- num_heads=6,
170
- ffd_ratio=4,
171
- **kwargs,
172
- )
173
-
174
-
175
- def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
176
- """VisionTransformer-B architecture as described in
177
- `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
178
- <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)
179
-
180
- >>> import tensorflow as tf
181
- >>> from doctr.models import vit_b
182
- >>> model = vit_b(pretrained=False)
183
- >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
184
- >>> out = model(input_tensor)
185
-
186
- Args:
187
- pretrained: boolean, True if model is pretrained
188
- **kwargs: keyword arguments of the VisionTransformer architecture
189
-
190
- Returns:
191
- A feature extractor model
192
- """
193
- return _vit(
194
- "vit_b",
195
- pretrained,
196
- d_model=768,
197
- num_layers=12,
198
- num_heads=12,
199
- ffd_ratio=4,
200
- **kwargs,
201
- )
@@ -1,34 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import tensorflow as tf
7
-
8
- __all__ = ["erode", "dilate"]
9
-
10
-
11
- def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
12
- """Performs erosion on a given tensor
13
-
14
- Args:
15
- x: boolean tensor of shape (N, H, W, C)
16
- kernel_size: the size of the kernel to use for erosion
17
-
18
- Returns:
19
- the eroded tensor
20
- """
21
- return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
22
-
23
-
24
- def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
25
- """Performs dilation on a given tensor
26
-
27
- Args:
28
- x: boolean tensor of shape (N, H, W, C)
29
- kernel_size: the size of the kernel to use for dilation
30
-
31
- Returns:
32
- the dilated tensor
33
- """
34
- return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")