python-doctr 0.11.0 → 1.0.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
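The headline change is the removal of the entire TensorFlow backend: every `tensorflow.py` module listed above (items 104-134, including the TF-only `doctr/datasets/loader.py`) is deleted, leaving PyTorch as the sole backend. New additions include the VIP classification backbone, the VIPTR recognition model, and a COCO-Text dataset. The high-level entry point is backend-agnostic from the user's perspective and survives the upgrade unchanged; a minimal usage sketch, assuming a local image file:

# Minimal post-upgrade sketch: same high-level API, now PyTorch-only.
# "sample.jpg" is a hypothetical input file.
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_images(["sample.jpg"])
predictor = ocr_predictor(pretrained=True)  # downloads PyTorch checkpoints
result = predictor(doc)
print(result.export())  # nested dict: pages -> blocks -> lines -> words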
--- a/doctr/models/utils/tensorflow.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-import logging
-from collections.abc import Callable
-from typing import Any
-
-import tensorflow as tf
-import tf2onnx
-from tensorflow.keras import Model, layers
-
-from doctr.utils.data import download_from_url
-
-logging.getLogger("tensorflow").setLevel(logging.DEBUG)
-
-
-__all__ = [
-    "load_pretrained_params",
-    "_build_model",
-    "conv_sequence",
-    "IntermediateLayerGetter",
-    "export_model_to_onnx",
-    "_copy_tensor",
-    "_bf16_to_float32",
-]
-
-
-def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
-    return tf.identity(x)
-
-
-def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
-    # Convert bfloat16 to float32 for numpy compatibility
-    return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
-
-
-def _build_model(model: Model):
-    """Build a model by calling it once with dummy input
-
-    Args:
-        model: the model to be built
-    """
-    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
-
-
-def load_pretrained_params(
-    model: Model,
-    url: str | None = None,
-    hash_prefix: str | None = None,
-    skip_mismatch: bool = False,
-    **kwargs: Any,
-) -> None:
-    """Load a set of parameters onto a model
-
-    >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
-
-    Args:
-        model: the keras model to be loaded
-        url: URL of the zipped set of parameters
-        hash_prefix: first characters of SHA256 expected hash
-        skip_mismatch: skip loading layers with mismatched shapes
-        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
-    """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-        # Load weights
-        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
-
-
-def conv_sequence(
-    out_channels: int,
-    activation: str | Callable | None = None,
-    bn: bool = False,
-    padding: str = "same",
-    kernel_initializer: str = "he_normal",
-    **kwargs: Any,
-) -> list[layers.Layer]:
-    """Builds a convolutional-based layer sequence
-
-    >>> from tensorflow.keras import Sequential
-    >>> from doctr.models import conv_sequence
-    >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
-
-    Args:
-        out_channels: number of output channels
-        activation: activation to be used (default: no activation)
-        bn: should a batch normalization layer be added
-        padding: padding scheme
-        kernel_initializer: kernel initializer
-        **kwargs: additional arguments to be passed to the convolutional layer
-
-    Returns:
-        list of layers
-    """
-    # No bias before batch norm
-    kwargs["use_bias"] = kwargs.get("use_bias", not bn)
-    # Add activation directly to the conv if there is no BN
-    kwargs["activation"] = activation if not bn else None
-    conv_seq = [layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs)]
-
-    if bn:
-        conv_seq.append(layers.BatchNormalization())
-
-    if (isinstance(activation, str) or callable(activation)) and bn:
-        # Activation function can either be a string or a function ('relu' or tf.nn.relu)
-        conv_seq.append(layers.Activation(activation))
-
-    return conv_seq
-
-
-class IntermediateLayerGetter(Model):
-    """Implements an intermediate layer getter
-
-    >>> from tensorflow.keras.applications import ResNet50
-    >>> from doctr.models import IntermediateLayerGetter
-    >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
-    >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
-
-    Args:
-        model: the model to extract feature maps from
-        layer_names: the list of layers to retrieve the feature map from
-    """
-
-    def __init__(self, model: Model, layer_names: list[str]) -> None:
-        intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
-        super().__init__(model.input, outputs=intermediate_fmaps)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
-) -> tuple[str, list[str]]:
-    """Export model to ONNX format.
-
-    >>> import tensorflow as tf
-    >>> from doctr.models.classification import resnet18
-    >>> from doctr.models.utils import export_model_to_onnx
-    >>> model = resnet18(pretrained=True, include_top=True)
-    >>> export_model_to_onnx(model, "my_model",
-    >>>     dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
-
-    Args:
-        model: the keras model to be exported
-        model_name: the name for the exported model
-        dummy_input: the dummy input to the model
-        kwargs: additional arguments to be passed to tf2onnx
-
-    Returns:
-        the path to the exported model and a list with the output layer names
-    """
-    # Record the user's eager mode
-    eager_mode = tf.executing_eagerly()
-    # Set eager mode to True to avoid issues with tf2onnx
-    tf.config.run_functions_eagerly(True)
-    large_model = kwargs.get("large_model", False)
-    model_proto, _ = tf2onnx.convert.from_keras(
-        model,
-        input_signature=dummy_input,
-        output_path=f"{model_name}.zip" if large_model else f"{model_name}.onnx",
-        **kwargs,
-    )
-    # Get the output layer names
-    output = [n.name for n in model_proto.graph.output]
-
-    # Reset the eager mode to the user's mode
-    tf.config.run_functions_eagerly(eager_mode)
-
-    # Models that are too large (weights > 2GB when converting to ONNX) need external
-    # tensor storage, where the graph and weights are stored separately in an archive
-    if large_model:
-        logging.info(f"Model exported to {model_name}.zip")
-        return f"{model_name}.zip", output
-
-    logging.info(f"Model exported to {model_name}.onnx")
-    return f"{model_name}.onnx", output
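With the tf2onnx-based exporter above gone, ONNX export in 1.0.0 goes through the PyTorch utilities (`doctr/models/utils/pytorch.py` is modified, not removed, per the file list). A rough sketch of the same operation with plain `torch.onnx.export`; the output name and dynamic batch axis are illustrative assumptions, not doctr's exact helper:

# Generic ONNX export sketch for the remaining PyTorch backend, mirroring
# what the deleted TF helper did with tf2onnx. Names are illustrative.
import torch

def export_to_onnx(model: torch.nn.Module, model_name: str, dummy_input: torch.Tensor) -> str:
    model.eval()
    output_path = f"{model_name}.onnx"
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["logits"],  # assumed output name
        dynamic_axes={"input": {0: "batch_size"}},  # allow variable batch size
    )
    return output_path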
--- a/doctr/transforms/functional/tensorflow.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# Copyright (C) 2021-2025, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-import math
-import random
-from collections.abc import Iterable
-from copy import deepcopy
-
-import numpy as np
-import tensorflow as tf
-
-from doctr.utils.geometry import compute_expanded_shape, rotate_abs_geoms
-
-from .base import create_shadow_mask, crop_boxes
-
-__all__ = ["invert_colors", "rotate_sample", "crop_detection", "random_shadow", "rotated_img_tensor"]
-
-
-def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
-    """Invert the colors of an image
-
-    Args:
-        img: tf.Tensor, the image to invert
-        min_val: minimum value of the random shift
-
-    Returns:
-        the inverted image
-    """
-    out = tf.image.rgb_to_grayscale(img)  # Convert to gray
-    # Random RGB shift
-    shift_shape = [img.shape[0], 1, 1, 3] if img.ndim == 4 else [1, 1, 3]
-    rgb_shift = tf.random.uniform(shape=shift_shape, minval=min_val, maxval=1)
-    # Apply the random shift
-    if out.dtype == tf.uint8:
-        out = tf.cast(tf.cast(out, dtype=rgb_shift.dtype) * rgb_shift, dtype=tf.uint8)
-    else:
-        out *= tf.cast(rgb_shift, dtype=out.dtype)
-    # Invert the colors
-    out = 255 - out if out.dtype == tf.uint8 else 1 - out
-    return out
-
-
-def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf.Tensor:
-    """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
-
-    Args:
-        img: image to rotate
-        angle: angle in degrees. +: counter-clockwise, -: clockwise
-        expand: whether the image should be padded before the rotation
-
-    Returns:
-        the rotated image (tensor)
-    """
-    # Compute the expanded padding
-    h_crop, w_crop = 0, 0
-    if expand:
-        exp_h, exp_w = compute_expanded_shape(img.shape[:-1], angle)
-        h_diff, w_diff = int(math.ceil(exp_h - img.shape[0])), int(math.ceil(exp_w - img.shape[1]))
-        h_pad, w_pad = max(h_diff, 0), max(w_diff, 0)
-        exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]]))
-        h_crop, w_crop = int(round(max(exp_img.shape[0] - exp_h, 0))), int(round(min(exp_img.shape[1] - exp_w, 0)))
-    else:
-        exp_img = img
-
-    # Compute the rotation matrix
-    height, width = tf.cast(tf.shape(exp_img)[0], tf.float32), tf.cast(tf.shape(exp_img)[1], tf.float32)
-    cos_angle, sin_angle = tf.math.cos(angle * math.pi / 180.0), tf.math.sin(angle * math.pi / 180.0)
-    x_offset = ((width - 1) - (cos_angle * (width - 1) - sin_angle * (height - 1))) / 2.0
-    y_offset = ((height - 1) - (sin_angle * (width - 1) + cos_angle * (height - 1))) / 2.0
-
-    rotation_matrix = tf.convert_to_tensor(
-        [cos_angle, -sin_angle, x_offset, sin_angle, cos_angle, y_offset, 0.0, 0.0],
-        dtype=tf.float32,
-    )
-    # Rotate the image
-    rotated_img = tf.squeeze(
-        tf.raw_ops.ImageProjectiveTransformV3(
-            images=exp_img[None],  # Add a batch dimension for compatibility with ImageProjectiveTransformV3
-            transforms=rotation_matrix[None],  # Add a batch dimension for compatibility with ImageProjectiveTransformV3
-            output_shape=tf.shape(exp_img)[:2],
-            interpolation="NEAREST",
-            fill_mode="CONSTANT",
-            fill_value=tf.constant(0.0, dtype=tf.float32),
-        )
-    )
-    # Crop the rest
-    if h_crop > 0 or w_crop > 0:
-        h_slice = slice(h_crop // 2, -h_crop // 2) if h_crop > 0 else slice(rotated_img.shape[0])
-        w_slice = slice(-w_crop // 2, -w_crop // 2) if w_crop > 0 else slice(rotated_img.shape[1])
-        rotated_img = rotated_img[h_slice, w_slice]
-
-    return rotated_img
-
-
-def rotate_sample(
-    img: tf.Tensor,
-    geoms: np.ndarray,
-    angle: float,
-    expand: bool = False,
-) -> tuple[tf.Tensor, np.ndarray]:
-    """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)
-
-    Args:
-        img: image to rotate
-        geoms: array of geometries of shape (N, 4) or (N, 4, 2)
-        angle: angle in degrees. +: counter-clockwise, -: clockwise
-        expand: whether the image should be padded before the rotation
-
-    Returns:
-        A tuple of rotated img (tensor), rotated boxes (np array)
-    """
-    # Rotate the image
-    rotated_img = rotated_img_tensor(img, angle, expand)
-
-    # Get absolute coords
-    _geoms = deepcopy(geoms)
-    if _geoms.shape[1:] == (4,):
-        if np.max(_geoms) <= 1:
-            _geoms[:, [0, 2]] *= img.shape[1]
-            _geoms[:, [1, 3]] *= img.shape[0]
-    elif _geoms.shape[1:] == (4, 2):
-        if np.max(_geoms) <= 1:
-            _geoms[..., 0] *= img.shape[1]
-            _geoms[..., 1] *= img.shape[0]
-    else:
-        raise AssertionError
-
-    # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon
-    rotated_geoms: np.ndarray = rotate_abs_geoms(_geoms, angle, img.shape[:-1], expand).astype(np.float32)
-
-    # Always return relative boxes to avoid label confusion when resizing is performed afterwards
-    rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1]
-    rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0]
-
-    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
-
-
-def crop_detection(
-    img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
-) -> tuple[tf.Tensor, np.ndarray]:
-    """Crop an image and its associated bboxes
-
-    Args:
-        img: image to crop
-        boxes: array of boxes to clip, absolute (int) or relative (float)
-        crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in relative coords
-
-    Returns:
-        A tuple of cropped image, cropped boxes, where the image is not resized.
-    """
-    if any(val < 0 or val > 1 for val in crop_box):
-        raise AssertionError("coordinates of arg `crop_box` should be relative")
-    h, w = img.shape[:2]
-    xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1)))
-    xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1)))
-    cropped_img = tf.image.crop_to_bounding_box(img, ymin, xmin, ymax - ymin, xmax - xmin)
-    # Crop the boxes
-    boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax))
-
-    return cropped_img, boxes
-
-
-def _gaussian_filter(
-    img: tf.Tensor,
-    kernel_size: int | Iterable[int],
-    sigma: float,
-    mode: str | None = None,
-    pad_value: int = 0,
-):
-    """Apply Gaussian filter to image.
-    Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py
-
-    Args:
-        img: image to filter of shape (N, H, W, C)
-        kernel_size: kernel size of the filter
-        sigma: standard deviation of the Gaussian filter
-        mode: padding mode, one of "CONSTANT", "REFLECT", "SYMMETRIC"
-        pad_value: value to pad the image with
-
-    Returns:
-        A tensor of shape (N, H, W, C)
-    """
-    ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
-    sigma = tf.convert_to_tensor(tf.broadcast_to(sigma, [2]), dtype=img.dtype)
-    # Normalize the mode before validating it, so the None default passes
-    mode = "CONSTANT" if mode is None else str.upper(mode)
-    assert mode in ("CONSTANT", "REFLECT", "SYMMETRIC"), "mode should be one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'"
-    constant_values = (
-        tf.zeros([], dtype=img.dtype) if pad_value is None else tf.convert_to_tensor(pad_value, dtype=img.dtype)
-    )
-
-    def kernel1d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType):
-        x = tf.range(ksize, dtype=dtype)
-        x = x - tf.cast(tf.math.floordiv(ksize, 2), dtype=dtype)
-        x = x + tf.where(tf.math.equal(tf.math.mod(ksize, 2), 0), tf.cast(0.5, dtype), 0)
-        g = tf.math.exp(-(tf.math.pow(x, 2) / (2 * tf.math.pow(sigma, 2))))
-        g = g / tf.reduce_sum(g)
-        return g
-
-    def kernel2d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType):
-        kernel_x = kernel1d(ksize[0], sigma[0], dtype)
-        kernel_y = kernel1d(ksize[1], sigma[1], dtype)
-        return tf.matmul(
-            tf.expand_dims(kernel_x, axis=-1),
-            tf.transpose(tf.expand_dims(kernel_y, axis=-1)),
-        )
-
-    g = kernel2d(ksize, sigma, img.dtype)
-    # Pad the image
-    height, width = ksize[0], ksize[1]
-    paddings = [
-        [0, 0],
-        [(height - 1) // 2, height - 1 - (height - 1) // 2],
-        [(width - 1) // 2, width - 1 - (width - 1) // 2],
-        [0, 0],
-    ]
-    img = tf.pad(img, paddings, mode=mode, constant_values=constant_values)
-
-    channel = tf.shape(img)[-1]
-    shape = tf.concat([ksize, tf.constant([1, 1], ksize.dtype)], axis=0)
-    g = tf.reshape(g, shape)
-    shape = tf.concat([ksize, [channel], tf.constant([1], ksize.dtype)], axis=0)
-    g = tf.broadcast_to(g, shape)
-    return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
-
-
-def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
-    """Apply a random shadow to a given image
-
-    Args:
-        img: image to modify
-        opacity_range: the minimum and maximum desired opacity of the shadow
-        **kwargs: additional arguments to pass to `create_shadow_mask`
-
-    Returns:
-        shadowed image
-    """
-    shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
-
-    opacity = np.random.uniform(*opacity_range)
-    shadow_tensor = 1 - tf.convert_to_tensor(shadow_mask[..., None], dtype=tf.float32)
-
-    # Add some blur to make it believable
-    k = 7 + int(2 * 4 * random.random())
-    sigma = random.uniform(0.5, 5.0)
-    shadow_tensor = _gaussian_filter(
-        shadow_tensor[tf.newaxis, ...],
-        kernel_size=k,
-        sigma=sigma,
-        mode="REFLECT",
-    )
-
-    return tf.squeeze(opacity * shadow_tensor * img + (1 - opacity) * img, axis=0)
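The hand-rolled depthwise-conv Gaussian filter above existed mainly to blur the shadow mask; on the PyTorch side the same effect is available out of the box. A rough equivalent of the blur step using torchvision, as an illustration rather than doctr's exact code (torchvision requires odd kernel sizes, so the kernel draw differs slightly from the TF version):

# Rough PyTorch equivalent of the shadow-blur step, using torchvision's
# built-in Gaussian blur instead of a hand-rolled depthwise convolution.
import random
import torch
from torchvision.transforms import GaussianBlur

def blur_shadow(shadow: torch.Tensor) -> torch.Tensor:
    # shadow: (C, H, W) float tensor in [0, 1]
    k = 7 + 2 * random.randint(0, 4)  # odd kernel size in {7, 9, ..., 15}
    sigma = random.uniform(0.5, 5.0)  # same sigma range as the TF helper
    return GaussianBlur(kernel_size=k, sigma=sigma)(shadow)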