python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +0 -5
  3. doctr/datasets/datasets/__init__.py +1 -6
  4. doctr/datasets/datasets/pytorch.py +2 -2
  5. doctr/datasets/generator/__init__.py +1 -6
  6. doctr/datasets/vocabs.py +0 -2
  7. doctr/file_utils.py +2 -101
  8. doctr/io/image/__init__.py +1 -7
  9. doctr/io/image/pytorch.py +1 -1
  10. doctr/models/_utils.py +3 -3
  11. doctr/models/classification/magc_resnet/__init__.py +1 -6
  12. doctr/models/classification/magc_resnet/pytorch.py +2 -2
  13. doctr/models/classification/mobilenet/__init__.py +1 -6
  14. doctr/models/classification/predictor/__init__.py +1 -6
  15. doctr/models/classification/predictor/pytorch.py +1 -1
  16. doctr/models/classification/resnet/__init__.py +1 -6
  17. doctr/models/classification/textnet/__init__.py +1 -6
  18. doctr/models/classification/textnet/pytorch.py +1 -1
  19. doctr/models/classification/vgg/__init__.py +1 -6
  20. doctr/models/classification/vip/__init__.py +1 -4
  21. doctr/models/classification/vip/layers/__init__.py +1 -4
  22. doctr/models/classification/vip/layers/pytorch.py +1 -1
  23. doctr/models/classification/vit/__init__.py +1 -6
  24. doctr/models/classification/vit/pytorch.py +2 -2
  25. doctr/models/classification/zoo.py +6 -11
  26. doctr/models/detection/_utils/__init__.py +1 -6
  27. doctr/models/detection/core.py +1 -1
  28. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  29. doctr/models/detection/differentiable_binarization/base.py +4 -12
  30. doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
  31. doctr/models/detection/fast/__init__.py +1 -6
  32. doctr/models/detection/fast/base.py +4 -14
  33. doctr/models/detection/fast/pytorch.py +4 -4
  34. doctr/models/detection/linknet/__init__.py +1 -6
  35. doctr/models/detection/linknet/base.py +3 -12
  36. doctr/models/detection/linknet/pytorch.py +2 -2
  37. doctr/models/detection/predictor/__init__.py +1 -6
  38. doctr/models/detection/predictor/pytorch.py +1 -1
  39. doctr/models/detection/zoo.py +15 -32
  40. doctr/models/factory/hub.py +8 -21
  41. doctr/models/kie_predictor/__init__.py +1 -6
  42. doctr/models/kie_predictor/pytorch.py +2 -6
  43. doctr/models/modules/layers/__init__.py +1 -6
  44. doctr/models/modules/layers/pytorch.py +3 -3
  45. doctr/models/modules/transformer/__init__.py +1 -6
  46. doctr/models/modules/transformer/pytorch.py +2 -2
  47. doctr/models/modules/vision_transformer/__init__.py +1 -6
  48. doctr/models/predictor/__init__.py +1 -6
  49. doctr/models/predictor/base.py +3 -8
  50. doctr/models/predictor/pytorch.py +2 -5
  51. doctr/models/preprocessor/__init__.py +1 -6
  52. doctr/models/preprocessor/pytorch.py +27 -32
  53. doctr/models/recognition/crnn/__init__.py +1 -6
  54. doctr/models/recognition/crnn/pytorch.py +6 -6
  55. doctr/models/recognition/master/__init__.py +1 -6
  56. doctr/models/recognition/master/pytorch.py +5 -5
  57. doctr/models/recognition/parseq/__init__.py +1 -6
  58. doctr/models/recognition/parseq/pytorch.py +5 -5
  59. doctr/models/recognition/predictor/__init__.py +1 -6
  60. doctr/models/recognition/predictor/_utils.py +7 -16
  61. doctr/models/recognition/predictor/pytorch.py +1 -2
  62. doctr/models/recognition/sar/__init__.py +1 -6
  63. doctr/models/recognition/sar/pytorch.py +3 -3
  64. doctr/models/recognition/viptr/__init__.py +1 -4
  65. doctr/models/recognition/viptr/pytorch.py +3 -3
  66. doctr/models/recognition/vitstr/__init__.py +1 -6
  67. doctr/models/recognition/vitstr/pytorch.py +3 -3
  68. doctr/models/recognition/zoo.py +13 -13
  69. doctr/models/utils/__init__.py +1 -6
  70. doctr/models/utils/pytorch.py +1 -1
  71. doctr/transforms/functional/__init__.py +1 -6
  72. doctr/transforms/functional/pytorch.py +4 -4
  73. doctr/transforms/modules/__init__.py +1 -7
  74. doctr/transforms/modules/base.py +26 -92
  75. doctr/transforms/modules/pytorch.py +28 -26
  76. doctr/utils/geometry.py +6 -10
  77. doctr/utils/visualization.py +1 -1
  78. doctr/version.py +1 -1
  79. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
  80. python_doctr-1.0.0.dist-info/RECORD +149 -0
  81. doctr/datasets/datasets/tensorflow.py +0 -59
  82. doctr/datasets/generator/tensorflow.py +0 -58
  83. doctr/datasets/loader.py +0 -94
  84. doctr/io/image/tensorflow.py +0 -101
  85. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  86. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  87. doctr/models/classification/predictor/tensorflow.py +0 -60
  88. doctr/models/classification/resnet/tensorflow.py +0 -418
  89. doctr/models/classification/textnet/tensorflow.py +0 -275
  90. doctr/models/classification/vgg/tensorflow.py +0 -125
  91. doctr/models/classification/vit/tensorflow.py +0 -201
  92. doctr/models/detection/_utils/tensorflow.py +0 -34
  93. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  94. doctr/models/detection/fast/tensorflow.py +0 -427
  95. doctr/models/detection/linknet/tensorflow.py +0 -377
  96. doctr/models/detection/predictor/tensorflow.py +0 -70
  97. doctr/models/kie_predictor/tensorflow.py +0 -187
  98. doctr/models/modules/layers/tensorflow.py +0 -171
  99. doctr/models/modules/transformer/tensorflow.py +0 -235
  100. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  101. doctr/models/predictor/tensorflow.py +0 -155
  102. doctr/models/preprocessor/tensorflow.py +0 -122
  103. doctr/models/recognition/crnn/tensorflow.py +0 -317
  104. doctr/models/recognition/master/tensorflow.py +0 -320
  105. doctr/models/recognition/parseq/tensorflow.py +0 -516
  106. doctr/models/recognition/predictor/tensorflow.py +0 -79
  107. doctr/models/recognition/sar/tensorflow.py +0 -423
  108. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  109. doctr/models/utils/tensorflow.py +0 -189
  110. doctr/transforms/functional/tensorflow.py +0 -254
  111. doctr/transforms/modules/tensorflow.py +0 -562
  112. python_doctr-0.12.0.dist-info/RECORD +0 -180
  113. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
  114. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
  115. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  116. {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -1,189 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import logging
7
- from collections.abc import Callable
8
- from typing import Any
9
-
10
- import tensorflow as tf
11
- import tf2onnx
12
- import validators
13
- from tensorflow.keras import Model, layers
14
-
15
- from doctr.utils.data import download_from_url
16
-
17
# NOTE(review): this sets the "tensorflow" logger to DEBUG as an import-time side effect for every
# importer of this module — confirm this is intended rather than a leftover debugging setting.
logging.getLogger("tensorflow").setLevel(logging.DEBUG)


# Public API of this module
__all__ = [
    "load_pretrained_params",
    "_build_model",
    "conv_sequence",
    "IntermediateLayerGetter",
    "export_model_to_onnx",
    "_copy_tensor",
    "_bf16_to_float32",
]
29
-
30
-
31
def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
    """Return a copy of ``x`` produced by ``tf.identity``."""
    duplicate = tf.identity(x)
    return duplicate
33
-
34
-
35
def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
    """Upcast a bfloat16 tensor to float32; tensors of any other dtype are returned untouched.

    This keeps the tensor convertible to numpy, which has no native bfloat16 type.
    """
    if x.dtype != tf.bfloat16:
        return x
    return tf.cast(x, tf.float32)
38
-
39
-
40
def _build_model(model: Model):
    """Build a model by running it once on an all-zeros batch of size 1

    Args:
        model: the model to be built; ``model.cfg["input_shape"]`` gives the per-sample input shape
    """
    dummy_shape = (1, *model.cfg["input_shape"])
    dummy_batch = tf.zeros(dummy_shape)
    model(dummy_batch, training=False)
47
-
48
-
49
def load_pretrained_params(
    model: Model,
    path_or_url: str | None = None,
    hash_prefix: str | None = None,
    skip_mismatch: bool = False,
    **kwargs: Any,
) -> None:
    """Load a checkpoint (local file or remote URL) into a keras model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")

    Args:
        model: the keras model to be loaded
        path_or_url: local path or URL of the checkpoint; when None, the model keeps its initialization
        hash_prefix: first characters of SHA256 expected hash (used only for downloads)
        skip_mismatch: skip loading layers with mismatched shapes
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
    """
    if path_or_url is None:
        logging.warning("No model URL or Path provided, using default initialization.")
        return

    if validators.url(path_or_url):
        # Remote checkpoint: download it, cached under the "models" subdirectory
        weights_file = download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
    else:
        # Anything that is not a URL is used as a local path as-is
        weights_file = path_or_url

    model.load_weights(weights_file, skip_mismatch=skip_mismatch)
80
-
81
-
82
def conv_sequence(
    out_channels: int,
    activation: str | Callable | None = None,
    bn: bool = False,
    padding: str = "same",
    kernel_initializer: str = "he_normal",
    **kwargs: Any,
) -> list[layers.Layer]:
    """Build a Conv2D layer, optionally followed by BatchNorm and an activation

    >>> from tensorflow.keras import Sequential
    >>> from doctr.models import conv_sequence
    >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))

    Args:
        out_channels: number of output channels
        activation: activation to be used, as a name or a callable (default: no activation)
        bn: should a batch normalization layer be added
        padding: padding scheme
        kernel_initializer: kernel initializer
        **kwargs: additional arguments to be passed to the convolutional layer

    Returns:
        list of layers
    """
    # A bias right before BatchNorm is redundant, so it defaults to off in that case (caller may override)
    kwargs.setdefault("use_bias", not bn)

    if not bn:
        # No BatchNorm: the activation is fused into the convolution itself
        kwargs["activation"] = activation
        return [layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs)]

    # With BatchNorm: linear conv -> BN -> (optional) separate activation layer
    kwargs["activation"] = None
    sequence = [
        layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs),
        layers.BatchNormalization(),
    ]
    # The activation may be given as a string ('relu') or a function (tf.nn.relu)
    if isinstance(activation, str) or callable(activation):
        sequence.append(layers.Activation(activation))
    return sequence
121
-
122
-
123
class IntermediateLayerGetter(Model):
    """Wrap a model so that it outputs the feature maps of selected intermediate layers

    >>> from tensorflow.keras.applications import ResNet50
    >>> from doctr.models import IntermediateLayerGetter
    >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
    >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)

    Args:
        model: the model to extract feature maps from
        layer_names: the list of layers to retrieve the feature map from, in output order
    """

    def __init__(self, model: Model, layer_names: list[str]) -> None:
        feature_maps = []
        for name in layer_names:
            # Output tensor of the layer at its first call site (node index 0)
            feature_maps.append(model.get_layer(name).get_output_at(0))
        super().__init__(model.input, outputs=feature_maps)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
142
-
143
-
144
def export_model_to_onnx(
    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
) -> tuple[str, list[str]]:
    """Export model to ONNX format.

    >>> import tensorflow as tf
    >>> from doctr.models.classification import resnet18
    >>> from doctr.models.utils import export_model_to_onnx
    >>> model = resnet18(pretrained=True, include_top=True)
    >>> export_model_to_onnx(model, "my_model",
    >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])

    Args:
        model: the keras model to be exported
        model_name: the name for the exported model (file name without extension)
        dummy_input: the input signature of the model
        kwargs: additional arguments to be passed to tf2onnx (e.g. ``large_model=True``)

    Returns:
        the path to the exported model (``.onnx``, or ``.zip`` for large models) and a list with the
        output layer names
    """
    large_model = kwargs.get("large_model", False)
    # Models which are too large (weights > 2GB while converting to ONNX) are exported with external
    # tensor storage: the graph and the weights are stored separately in a zip archive
    export_path = f"{model_name}.zip" if large_model else f"{model_name}.onnx"

    # Remember whether the user forces tf.functions to run eagerly. This must be read with
    # `functions_run_eagerly()`: `tf.executing_eagerly()` is a different setting (normally True in TF2),
    # and restoring from it would leave eager execution of tf.functions switched on globally.
    previous_run_eagerly = tf.config.functions_run_eagerly()
    # Force eager execution to avoid issues with tf2onnx
    tf.config.run_functions_eagerly(True)
    try:
        model_proto, _ = tf2onnx.convert.from_keras(
            model,
            input_signature=dummy_input,
            output_path=export_path,
            **kwargs,
        )
    finally:
        # Restore the user's setting even when the conversion fails
        tf.config.run_functions_eagerly(previous_run_eagerly)

    # Get the output layer names
    output = [n.name for n in model_proto.graph.output]

    logging.info("Model exported to %s", export_path)
    return export_path, output
@@ -1,254 +0,0 @@
1
- # Copyright (C) 2021-2025, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import math
7
- import random
8
- from collections.abc import Iterable
9
- from copy import deepcopy
10
-
11
- import numpy as np
12
- import tensorflow as tf
13
-
14
- from doctr.utils.geometry import compute_expanded_shape, rotate_abs_geoms
15
-
16
- from .base import create_shadow_mask, crop_boxes
17
-
18
# Public API of this module (the `_gaussian_filter` helper is not exported)
__all__ = ["invert_colors", "rotate_sample", "crop_detection", "random_shadow", "rotated_img_tensor"]
19
-
20
-
21
def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor:
    """Convert an image to grayscale, scale it by a random RGB factor, then invert it

    Args:
        img : tf.Tensor, the image (or batch of images when 4-D) to invert, RGB channels last
        min_val : lower bound of the random scaling factor, drawn uniformly in [min_val, 1)

    Returns:
        the inverted image
    """
    gray = tf.image.rgb_to_grayscale(img)

    # One random 3-channel factor per image: (N, 1, 1, 3) for a batch, (1, 1, 3) for a single image
    if img.ndim == 4:
        shift_shape = [img.shape[0], 1, 1, 3]
    else:
        shift_shape = [1, 1, 3]
    rgb_shift = tf.random.uniform(shape=shift_shape, minval=min_val, maxval=1)

    # Scale the single gray channel by the factor (broadcast to 3 channels), then invert
    if gray.dtype == tf.uint8:
        scaled = tf.cast(tf.cast(gray, dtype=rgb_shift.dtype) * rgb_shift, dtype=tf.uint8)
        return 255 - scaled
    scaled = gray * tf.cast(rgb_shift, dtype=gray.dtype)
    # Non-uint8 images are assumed to be in [0, 1]
    return 1 - scaled
43
-
44
-
45
def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf.Tensor:
    """Rotate image around the center, interpolation=NEAREST, pad with 0 (black)

    Args:
        img: image to rotate, of shape (H, W, C)
        angle: angle in degrees. +: counter-clockwise, -: clockwise
        expand: whether the image should be padded before the rotation, so that the output matches the
            shape returned by `compute_expanded_shape` (rounded to whole pixels)

    Returns:
        the rotated image (tensor)
    """
    # Compute the expanded padding
    h_crop, w_crop = 0, 0
    if expand:
        exp_h, exp_w = compute_expanded_shape(img.shape[:-1], angle)
        h_diff, w_diff = int(math.ceil(exp_h - img.shape[0])), int(math.ceil(exp_w - img.shape[1]))
        h_pad, w_pad = max(h_diff, 0), max(w_diff, 0)
        exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]]))
        # Rows/columns to trim so the output matches the expanded shape. A dimension whose expanded size
        # is smaller than the input (negative diff) received no padding and must be cropped instead, so
        # both dimensions use `max(..., 0)`; `min` would make the width crop always 0.
        h_crop = int(round(max(exp_img.shape[0] - exp_h, 0)))
        w_crop = int(round(max(exp_img.shape[1] - exp_w, 0)))
    else:
        exp_img = img

    # Compute the rotation matrix (maps output pixels to input pixels around the image center)
    height, width = tf.cast(tf.shape(exp_img)[0], tf.float32), tf.cast(tf.shape(exp_img)[1], tf.float32)
    cos_angle, sin_angle = tf.math.cos(angle * math.pi / 180.0), tf.math.sin(angle * math.pi / 180.0)
    x_offset = ((width - 1) - (cos_angle * (width - 1) - sin_angle * (height - 1))) / 2.0
    y_offset = ((height - 1) - (sin_angle * (width - 1) + cos_angle * (height - 1))) / 2.0

    rotation_matrix = tf.convert_to_tensor(
        [cos_angle, -sin_angle, x_offset, sin_angle, cos_angle, y_offset, 0.0, 0.0],
        dtype=tf.float32,
    )
    # Rotate the image
    rotated_img = tf.squeeze(
        tf.raw_ops.ImageProjectiveTransformV3(
            images=exp_img[None],  # Add a batch dimension for compatibility with ImageProjectiveTransformV3
            transforms=rotation_matrix[None],  # Add a batch dimension for compatibility with ImageProjectiveTransformV3
            output_shape=tf.shape(exp_img)[:2],
            interpolation="NEAREST",
            fill_mode="CONSTANT",
            fill_value=tf.constant(0.0, dtype=tf.float32),
        )
    )
    # Center-crop the rest: `-n // 2` equals `-ceil(n / 2)`, so both ends together remove exactly n pixels
    if h_crop > 0 or w_crop > 0:
        h_slice = slice(h_crop // 2, -h_crop // 2) if h_crop > 0 else slice(rotated_img.shape[0])
        w_slice = slice(w_crop // 2, -w_crop // 2) if w_crop > 0 else slice(rotated_img.shape[1])
        rotated_img = rotated_img[h_slice, w_slice]

    return rotated_img
95
-
96
-
97
def rotate_sample(
    img: tf.Tensor,
    geoms: np.ndarray,
    angle: float,
    expand: bool = False,
) -> tuple[tf.Tensor, np.ndarray]:
    """Rotate an image and its geometries around the image center (NEAREST, black padding)

    Args:
        img: image to rotate
        geoms: array of geometries of shape (N, 4) or (N, 4, 2), relative (max <= 1) or absolute
        angle: angle in degrees. +: counter-clockwise, -: clockwise
        expand: whether the image should be padded before the rotation

    Returns:
        A tuple of rotated img (tensor), rotated polygons (np array of shape (N, 4, 2), relative coords)
    """
    rotated_img = rotated_img_tensor(img, angle, expand)

    # Work on a copy so the caller's array is not modified
    abs_geoms = deepcopy(geoms)
    is_box = abs_geoms.shape[1:] == (4,)
    is_poly = abs_geoms.shape[1:] == (4, 2)
    if not (is_box or is_poly):
        raise AssertionError

    # Convert relative coordinates to pixels
    if np.max(abs_geoms) <= 1:
        x_scale, y_scale = img.shape[1], img.shape[0]
        if is_box:
            abs_geoms[:, [0, 2]] *= x_scale
            abs_geoms[:, [1, 3]] *= y_scale
        else:
            abs_geoms[..., 0] *= x_scale
            abs_geoms[..., 1] *= y_scale

    # Boxes (xmin, ymin, xmax, ymax) and polygons are both rotated into (4, 2) polygons
    rotated_geoms: np.ndarray = rotate_abs_geoms(abs_geoms, angle, img.shape[:-1], expand).astype(np.float32)

    # Return relative coordinates so the labels stay valid if the image is resized afterwards
    out_h, out_w = rotated_img.shape[0], rotated_img.shape[1]
    rotated_geoms[..., 0] /= out_w
    rotated_geoms[..., 1] /= out_h

    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
138
-
139
-
140
def crop_detection(
    img: tf.Tensor, boxes: np.ndarray, crop_box: tuple[float, float, float, float]
) -> tuple[tf.Tensor, np.ndarray]:
    """Crop an image and its associated bboxes

    Args:
        img: image to crop, of shape (H, W, C)
        boxes: array of boxes to clip, absolute (int) or relative (float); may be empty
        crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords.

    Returns:
        A tuple of cropped image, cropped boxes, where the image is not resized.

    Raises:
        AssertionError: if any coordinate of `crop_box` is outside [0, 1]
    """
    if any(val < 0 or val > 1 for val in crop_box):
        raise AssertionError("coordinates of arg `crop_box` should be relative")
    h, w = img.shape[:2]
    xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1)))
    xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1)))
    cropped_img = tf.image.crop_to_bounding_box(img, ymin, xmin, ymax - ymin, xmax - xmin)

    # An image without annotations has no boxes to crop, and `boxes.max()` raises on an empty array
    if boxes.size == 0:
        return cropped_img, boxes

    # Relative boxes (max <= 1) are cropped with the relative crop box, absolute ones with its pixel version
    boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax))

    return cropped_img, boxes
163
-
164
-
165
def _gaussian_filter(
    img: tf.Tensor,
    kernel_size: int | Iterable[int],
    sigma: float,
    mode: str | None = None,
    pad_value: int | None = 0,
):
    """Apply Gaussian filter to image.
    Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py

    Args:
        img: image to filter of shape (N, H, W, C)
        kernel_size: kernel size of the filter (a single int is used for both dimensions)
        sigma: standard deviation of the Gaussian filter
        mode: padding mode, one of "CONSTANT", "REFLECT", "SYMMETRIC" (case-insensitive, None means "CONSTANT")
        pad_value: value to pad the image with in "CONSTANT" mode (None means 0)

    Returns:
        A tensor of shape (N, H, W, C)

    Raises:
        AssertionError: if `mode` is not a supported padding mode
    """
    ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
    sigma = tf.convert_to_tensor(tf.broadcast_to(sigma, [2]), dtype=img.dtype)
    # Normalize first, then validate: checking before normalizing rejected the default `mode=None` and
    # any lowercase mode name. An explicit raise (not `assert`) also survives `python -O`.
    mode = "CONSTANT" if mode is None else str.upper(mode)
    if mode not in ("CONSTANT", "REFLECT", "SYMMETRIC"):
        raise AssertionError("mode should be one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'")
    constant_values = (
        tf.zeros([], dtype=img.dtype) if pad_value is None else tf.convert_to_tensor(pad_value, dtype=img.dtype)
    )

    def kernel1d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType):
        # Centered sample positions; for even sizes they are shifted by 0.5 to stay symmetric
        x = tf.range(ksize, dtype=dtype)
        x = x - tf.cast(tf.math.floordiv(ksize, 2), dtype=dtype)
        x = x + tf.where(tf.math.equal(tf.math.mod(ksize, 2), 0), tf.cast(0.5, dtype), 0)
        g = tf.math.exp(-(tf.math.pow(x, 2) / (2 * tf.math.pow(sigma, 2))))
        g = g / tf.reduce_sum(g)
        return g

    def kernel2d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType):
        # Separable Gaussian: outer product of the two 1D kernels
        kernel_x = kernel1d(ksize[0], sigma[0], dtype)
        kernel_y = kernel1d(ksize[1], sigma[1], dtype)
        return tf.matmul(
            tf.expand_dims(kernel_x, axis=-1),
            tf.transpose(tf.expand_dims(kernel_y, axis=-1)),
        )

    g = kernel2d(ksize, sigma, img.dtype)
    # Pad the image so the VALID convolution preserves the spatial size
    height, width = ksize[0], ksize[1]
    paddings = [
        [0, 0],
        [(height - 1) // 2, height - 1 - (height - 1) // 2],
        [(width - 1) // 2, width - 1 - (width - 1) // 2],
        [0, 0],
    ]
    img = tf.pad(img, paddings, mode=mode, constant_values=constant_values)

    # Same kernel for every channel, applied depthwise
    channel = tf.shape(img)[-1]
    shape = tf.concat([ksize, tf.constant([1, 1], ksize.dtype)], axis=0)
    g = tf.reshape(g, shape)
    shape = tf.concat([ksize, [channel], tf.constant([1], ksize.dtype)], axis=0)
    g = tf.broadcast_to(g, shape)
    return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")
226
-
227
-
228
def random_shadow(img: tf.Tensor, opacity_range: tuple[float, float], **kwargs) -> tf.Tensor:
    """Blend a randomly generated, blurred shadow into an image

    Args:
        img: image to modify (assumed channels last, its first two dims are used as the mask size)
        opacity_range: the minimum and maximum desired opacity of the shadow
        **kwargs: additional arguments to pass to `create_shadow_mask`

    Returns:
        shadowed image
    """
    mask = create_shadow_mask(img.shape[:2], **kwargs)

    opacity = np.random.uniform(*opacity_range)
    # Inverted mask with a trailing channel axis, used as a per-pixel light multiplier
    light = 1 - tf.convert_to_tensor(mask[..., None], dtype=tf.float32)

    # Blur the mask so the shadow edges look natural: kernel size in [7, 14], sigma in [0.5, 5.0]
    kernel = 7 + int(2 * 4 * random.random())
    blur_sigma = random.uniform(0.5, 5.0)
    light = _gaussian_filter(light[tf.newaxis, ...], kernel_size=kernel, sigma=blur_sigma, mode="REFLECT")

    # Mix the shadowed image (weight `opacity`) with the untouched one, then drop the batch axis
    blended = opacity * light * img + (1 - opacity) * img
    return tf.squeeze(blended, axis=0)