keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  2. keras/quantizers/__init__.py +1 -0
  3. keras/src/applications/convnext.py +20 -20
  4. keras/src/applications/densenet.py +21 -21
  5. keras/src/applications/efficientnet.py +16 -16
  6. keras/src/applications/efficientnet_v2.py +28 -28
  7. keras/src/applications/inception_resnet_v2.py +7 -7
  8. keras/src/applications/inception_v3.py +5 -5
  9. keras/src/applications/mobilenet_v2.py +13 -20
  10. keras/src/applications/mobilenet_v3.py +15 -15
  11. keras/src/applications/nasnet.py +7 -8
  12. keras/src/applications/resnet.py +32 -32
  13. keras/src/applications/xception.py +10 -10
  14. keras/src/backend/common/dtypes.py +3 -3
  15. keras/src/backend/common/variables.py +3 -1
  16. keras/src/backend/jax/export.py +1 -1
  17. keras/src/backend/jax/trainer.py +1 -1
  18. keras/src/backend/openvino/numpy.py +1 -1
  19. keras/src/backend/tensorflow/trainer.py +19 -1
  20. keras/src/backend/torch/core.py +6 -9
  21. keras/src/backend/torch/trainer.py +1 -1
  22. keras/src/callbacks/backup_and_restore.py +2 -2
  23. keras/src/callbacks/csv_logger.py +1 -1
  24. keras/src/callbacks/model_checkpoint.py +1 -1
  25. keras/src/callbacks/tensorboard.py +6 -6
  26. keras/src/datasets/boston_housing.py +1 -1
  27. keras/src/datasets/california_housing.py +1 -1
  28. keras/src/datasets/cifar10.py +1 -1
  29. keras/src/datasets/cifar100.py +2 -2
  30. keras/src/datasets/imdb.py +2 -2
  31. keras/src/datasets/mnist.py +1 -1
  32. keras/src/datasets/reuters.py +2 -2
  33. keras/src/dtype_policies/dtype_policy.py +1 -1
  34. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  35. keras/src/export/tf2onnx_lib.py +1 -3
  36. keras/src/layers/input_spec.py +6 -6
  37. keras/src/layers/layer.py +1 -1
  38. keras/src/layers/preprocessing/category_encoding.py +3 -3
  39. keras/src/layers/preprocessing/data_layer.py +159 -0
  40. keras/src/layers/preprocessing/discretization.py +3 -3
  41. keras/src/layers/preprocessing/feature_space.py +4 -4
  42. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  43. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  44. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  45. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  46. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  47. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  48. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  49. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  50. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  51. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  52. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  53. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  54. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  55. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  56. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  57. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  61. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  62. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  63. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  67. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  69. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  70. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  71. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  72. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  73. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  74. keras/src/layers/preprocessing/normalization.py +5 -2
  75. keras/src/layers/preprocessing/rescaling.py +3 -3
  76. keras/src/layers/rnn/bidirectional.py +4 -4
  77. keras/src/legacy/backend.py +9 -23
  78. keras/src/legacy/preprocessing/image.py +11 -22
  79. keras/src/legacy/preprocessing/text.py +1 -1
  80. keras/src/models/functional.py +2 -2
  81. keras/src/models/model.py +21 -3
  82. keras/src/ops/function.py +1 -1
  83. keras/src/ops/numpy.py +5 -5
  84. keras/src/ops/operation.py +3 -2
  85. keras/src/optimizers/base_optimizer.py +3 -4
  86. keras/src/quantizers/gptq.py +350 -0
  87. keras/src/quantizers/gptq_config.py +169 -0
  88. keras/src/quantizers/gptq_core.py +335 -0
  89. keras/src/quantizers/gptq_quant.py +133 -0
  90. keras/src/saving/file_editor.py +22 -20
  91. keras/src/saving/object_registration.py +1 -1
  92. keras/src/saving/saving_lib.py +4 -4
  93. keras/src/saving/serialization_lib.py +3 -5
  94. keras/src/trainers/compile_utils.py +1 -1
  95. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  96. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  97. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  98. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  99. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  100. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  101. keras/src/tree/dmtree_impl.py +19 -3
  102. keras/src/tree/optree_impl.py +3 -3
  103. keras/src/tree/tree_api.py +5 -2
  104. keras/src/utils/file_utils.py +13 -5
  105. keras/src/utils/io_utils.py +1 -1
  106. keras/src/utils/model_visualization.py +1 -1
  107. keras/src/utils/progbar.py +5 -5
  108. keras/src/utils/summary_utils.py +4 -4
  109. keras/src/version.py +1 -1
  110. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/METADATA +1 -1
  111. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/RECORD +113 -109
  112. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  113. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/WHEEL +0 -0
  114. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/top_level.txt +0 -0
keras/src/datasets/reuters.py CHANGED
@@ -87,7 +87,7 @@ def load_data(
  )
  path = get_file(
  fname=path,
- origin=origin_folder + "reuters.npz",
+ origin=f"{origin_folder}reuters.npz",
  file_hash=( # noqa: E501
  "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
  ),
@@ -156,7 +156,7 @@ def get_word_index(path="reuters_word_index.json"):
  )
  path = get_file(
  path,
- origin=origin_folder + "reuters_word_index.json",
+ origin=f"{origin_folder}reuters_word_index.json",
  file_hash="4d44cc38712099c9e383dc6e5f11a921",
  )
  with open(path) as f:
keras/src/dtype_policies/dtype_policy.py CHANGED
@@ -3,7 +3,7 @@ from keras.src import ops
  from keras.src.api_export import keras_export
  from keras.src.backend.common import global_state

- QUANTIZATION_MODES = ("int8", "float8", "int4")
+ QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")


  @keras_export(
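
The `"gptq"` entry added to `QUANTIZATION_MODES` above lines up with the new `keras/src/quantizers/gptq*.py` modules in the file list. A minimal sketch of how the mode list is exercised through the public `Model.quantize` API; the `"int8"` call below is the established API, while the way `"gptq"` is configured (presumably via the new `GPTQConfig` in `gptq_config.py`) is an assumption and is left as a comment rather than shown as working code:

```python
import keras

# Small built model so `quantize()` has weights to act on.
model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(8), keras.layers.Dense(2)]
)

# Established modes, validated against QUANTIZATION_MODES.
model.quantize("int8")

# "gptq" is newly accepted by the mode check in this release. The call below is
# hypothetical: the required configuration (e.g. a GPTQConfig with calibration
# data, defined in the new gptq_config.py) is not shown by this hunk.
# model.quantize("gptq", config=...)  # assumed usage, not a documented signature
```
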
keras/src/dtype_policies/dtype_policy_map.py CHANGED
@@ -74,7 +74,7 @@ class DTypePolicyMap(DTypePolicy, MutableMapping):

  @property
  def name(self):
- return "map_" + self.default_policy._name
+ return f"map_{self.default_policy._name}"

  @property
  def default_policy(self):
keras/src/export/tf2onnx_lib.py CHANGED
@@ -157,9 +157,7 @@ def patch_tf2onnx():
  ):
  a = copy.deepcopy(a)
  tensor_name = (
- self.name.strip()
- + "_"
- + str(external_tensor_storage.name_counter)
+ f"{self.name.strip()}_{external_tensor_storage.name_counter}"
  )
  for c in '~"#%&*:<>?/\\{|}':
  tensor_name = tensor_name.replace(c, "_")
keras/src/layers/input_spec.py CHANGED
@@ -94,12 +94,12 @@ class InputSpec:

  def __repr__(self):
  spec = [
- ("dtype=" + str(self.dtype)) if self.dtype else "",
- ("shape=" + str(self.shape)) if self.shape else "",
- ("ndim=" + str(self.ndim)) if self.ndim else "",
- ("max_ndim=" + str(self.max_ndim)) if self.max_ndim else "",
- ("min_ndim=" + str(self.min_ndim)) if self.min_ndim else "",
- ("axes=" + str(self.axes)) if self.axes else "",
+ (f"dtype={str(self.dtype)}") if self.dtype else "",
+ (f"shape={str(self.shape)}") if self.shape else "",
+ (f"ndim={str(self.ndim)}") if self.ndim else "",
+ (f"max_ndim={str(self.max_ndim)}") if self.max_ndim else "",
+ (f"min_ndim={str(self.min_ndim)}") if self.min_ndim else "",
+ (f"axes={str(self.axes)}") if self.axes else "",
  ]
  return f"InputSpec({', '.join(x for x in spec if x)})"

keras/src/layers/layer.py CHANGED
@@ -1337,7 +1337,7 @@ class Layer(BackendLayer, Operation):
  else:
  attr_name = str(attr)
  attr_type = "attribute"
- msg = " " + msg if msg is not None else ""
+ msg = f" {msg}" if msg is not None else ""
  return NotImplementedError(
  f"Layer {self.__class__.__name__} does not have a `{attr_name}` "
  f"{attr_type} implemented.{msg}"
keras/src/layers/preprocessing/category_encoding.py CHANGED
@@ -1,12 +1,12 @@
  from keras.src.api_export import keras_export
  from keras.src.backend import KerasTensor
- from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+ from keras.src.layers.preprocessing.data_layer import DataLayer
  from keras.src.utils import backend_utils
  from keras.src.utils import numerical_utils


  @keras_export("keras.layers.CategoryEncoding")
- class CategoryEncoding(TFDataLayer):
+ class CategoryEncoding(DataLayer):
  """A preprocessing layer which encodes integer features.

  This layer provides options for condensing data into a categorical encoding
@@ -15,7 +15,7 @@ class CategoryEncoding(TFDataLayer):
  inputs. For integer inputs where the total number of tokens is not known,
  use `keras.layers.IntegerLookup` instead.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Examples:
keras/src/layers/preprocessing/data_layer.py ADDED
@@ -0,0 +1,159 @@
+ import keras.src.backend
+ from keras.src import tree
+ from keras.src.layers.layer import Layer
+ from keras.src.random.seed_generator import SeedGenerator
+ from keras.src.utils import backend_utils
+ from keras.src.utils import jax_utils
+ from keras.src.utils import tracking
+
+
+ class DataLayer(Layer):
+ """Layer designed for safe use in `tf.data` or `grain` pipeline.
+
+ This layer overrides the `__call__` method to ensure that the correct
+ backend is used and that computation is performed on the CPU.
+
+ The `call()` method in subclasses should use `self.backend` ops. If
+ randomness is needed, define both `seed` and `generator` in `__init__` and
+ retrieve the running seed using `self._get_seed_generator()`. If the layer
+ has weights in `__init__` or `build()`, use `convert_weight()` to ensure
+ they are in the correct backend.
+
+ **Note:** This layer and its subclasses only support a single input tensor.
+
+ Examples:
+
+ **Custom `DataLayer` subclass:**
+
+ ```python
+ from keras.src.layers.preprocessing.data_layer import DataLayer
+ from keras.src.random import SeedGenerator
+
+
+ class BiasedRandomRGBToHSVLayer(DataLayer):
+ def __init__(self, seed=None, **kwargs):
+ super().__init__(**kwargs)
+ self.probability_bias = ops.convert_to_tensor(0.01)
+ self.seed = seed
+ self.generator = SeedGenerator(seed)
+
+ def call(self, inputs):
+ images_shape = self.backend.shape(inputs)
+ batch_size = 1 if len(images_shape) == 3 else images_shape[0]
+ seed = self._get_seed_generator(self.backend._backend)
+
+ probability = self.backend.random.uniform(
+ shape=(batch_size,),
+ minval=0.0,
+ maxval=1.0,
+ seed=seed,
+ )
+ probability = self.backend.numpy.add(
+ probability, self.convert_weight(self.probability_bias)
+ )
+ hsv_images = self.backend.image.rgb_to_hsv(inputs)
+ return self.backend.numpy.where(
+ probability[:, None, None, None] > 0.5,
+ hsv_images,
+ inputs,
+ )
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+ ```
+
+ **Using as a regular Keras layer:**
+
+ ```python
+ import numpy as np
+
+ x = np.random.uniform(size=(1, 16, 16, 3)).astype("float32")
+ print(BiasedRandomRGBToHSVLayer()(x).shape) # (1, 16, 16, 3)
+ ```
+
+ **Using in a `tf.data` pipeline:**
+
+ ```python
+ import tensorflow as tf
+
+ tf_ds = tf.data.Dataset.from_tensors(x)
+ tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer())
+ print([x.shape for x in tf_ds]) # [(1, 16, 16, 3)]
+ ```
+
+ **Using in a `grain` pipeline:**
+
+ ```python
+ import grain
+
+ grain_ds = grain.MapDataset.source([x])
+ grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer())
+ print([x.shape for x in grain_ds]) # [(1, 16, 16, 3)]
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.backend = backend_utils.DynamicBackend()
+ self._allow_non_tensor_positional_args = True
+
+ def __call__(self, inputs, **kwargs):
+ sample_input = tree.flatten(inputs)[0]
+ if (
+ not isinstance(sample_input, keras.KerasTensor)
+ and backend_utils.in_tf_graph()
+ and not jax_utils.is_in_jax_tracing_scope(sample_input)
+ ):
+ # We're in a TF graph, e.g. a tf.data pipeline.
+ self.backend.set_backend("tensorflow")
+ inputs = tree.map_structure(
+ lambda x: self.backend.convert_to_tensor(
+ x, dtype=self.compute_dtype
+ ),
+ inputs,
+ )
+ switch_convert_input_args = False
+ if self._convert_input_args:
+ self._convert_input_args = False
+ switch_convert_input_args = True
+ try:
+ outputs = super().__call__(inputs, **kwargs)
+ finally:
+ self.backend.reset()
+ if switch_convert_input_args:
+ self._convert_input_args = True
+ return outputs
+ elif (
+ not isinstance(sample_input, keras.KerasTensor)
+ and backend_utils.in_grain_data_pipeline()
+ ):
+ # We're in a Grain data pipeline. Force computation and data
+ # placement to CPU.
+ with keras.src.backend.device_scope("cpu"):
+ return super().__call__(inputs, **kwargs)
+ else:
+ return super().__call__(inputs, **kwargs)
+
+ @tracking.no_automatic_dependency_tracking
+ def _get_seed_generator(self, backend=None):
+ if not hasattr(self, "seed") or not hasattr(self, "generator"):
+ raise ValueError(
+ "The `seed` and `generator` variable must be set in the "
+ "`__init__` method before calling `_get_seed_generator()`."
+ )
+ if backend is None or backend == keras.backend.backend():
+ return self.generator
+ if not hasattr(self, "_backend_generators"):
+ self._backend_generators = {}
+ if backend in self._backend_generators:
+ return self._backend_generators[backend]
+ seed_generator = SeedGenerator(self.seed, backend=self.backend)
+ self._backend_generators[backend] = seed_generator
+ return seed_generator
+
+ def convert_weight(self, weight):
+ """Convert the weight if it is from the a different backend."""
+ if self.backend.name == keras.backend.backend():
+ return weight
+ else:
+ weight = keras.ops.convert_to_numpy(weight)
+ return self.backend.convert_to_tensor(weight)
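
Since `tf_data_layer.py` is removed (`+0 -78` in the file list) and the follow-on hunks only swap the import path and the base-class name, the rename looks like a drop-in substitution for any code that subclassed the old private `TFDataLayer`. A minimal migration sketch under that assumption; `Halver` is a hypothetical layer, and the backend-agnostic `self.backend.numpy` call mirrors the pattern used elsewhere in this diff (e.g. `TFDConcat` in `feature_space.py`):

```python
import numpy as np

# Old import (removed in this wheel):
#   from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
from keras.src.layers.preprocessing.data_layer import DataLayer


class Halver(DataLayer):
    """Hypothetical layer: scales inputs by 0.5 with backend-agnostic ops so it
    stays usable inside tf.data or grain pipelines."""

    def call(self, inputs):
        # `self.backend` is the DynamicBackend set up by DataLayer.__init__.
        return self.backend.numpy.multiply(inputs, 0.5)

    def compute_output_shape(self, input_shape):
        return input_shape


x = np.ones((2, 3), dtype="float32")
print(Halver()(x))  # works eagerly; also usable inside tf.data/grain `.map()`
```
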
keras/src/layers/preprocessing/discretization.py CHANGED
@@ -2,21 +2,21 @@ import numpy as np

  from keras.src import backend
  from keras.src.api_export import keras_export
- from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+ from keras.src.layers.preprocessing.data_layer import DataLayer
  from keras.src.utils import argument_validation
  from keras.src.utils import numerical_utils
  from keras.src.utils.module_utils import tensorflow as tf


  @keras_export("keras.layers.Discretization")
- class Discretization(TFDataLayer):
+ class Discretization(DataLayer):
  """A preprocessing layer which buckets continuous features by ranges.

  This layer will place each element of its input data into one of several
  contiguous ranges and output an integer index indicating which range each
  element was placed in.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Input shape:
keras/src/layers/preprocessing/feature_space.py CHANGED
@@ -3,7 +3,7 @@ from keras.src import layers
  from keras.src import tree
  from keras.src.api_export import keras_export
  from keras.src.layers.layer import Layer
- from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+ from keras.src.layers.preprocessing.data_layer import DataLayer
  from keras.src.saving import saving_lib
  from keras.src.saving import serialization_lib
  from keras.src.saving.keras_saveable import KerasSaveable
@@ -723,7 +723,7 @@ class FeatureSpace(Layer):
  data[name] = tf.expand_dims(x, -1)

  with backend_utils.TFGraphScope():
- # This scope is to make sure that inner TFDataLayers
+ # This scope is to make sure that inner DataLayers
  # will not convert outputs back to backend-native --
  # they should be TF tensors throughout
  preprocessed_data = self._preprocess_features(data)
@@ -808,7 +808,7 @@ class FeatureSpace(Layer):
  return


- class TFDConcat(TFDataLayer):
+ class TFDConcat(DataLayer):
  def __init__(self, axis, **kwargs):
  super().__init__(**kwargs)
  self.axis = axis
@@ -817,6 +817,6 @@ class TFDConcat(TFDataLayer):
  return self.backend.numpy.concatenate(xs, axis=self.axis)


- class TFDIdentity(TFDataLayer):
+ class TFDIdentity(DataLayer):
  def call(self, x):
  return x
keras/src/layers/preprocessing/image_preprocessing/aug_mix.py CHANGED
@@ -43,6 +43,13 @@ class AugMix(BaseImagePreprocessingLayer):
  in num_chains different ways, with each chain consisting of
  chain_depth augmentations.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
+ References:
+ - [AugMix paper](https://arxiv.org/pdf/1912.02781)
+ - [Official Code](https://github.com/google-research/augmix)
+
  Args:
  value_range: the range of values the incoming images will have.
  Represented as a two number tuple written (low, high).
@@ -64,10 +71,6 @@ class AugMix(BaseImagePreprocessingLayer):
  interpolation: The interpolation method to use for resizing operations.
  Options include `"nearest"`, `"bilinear"`. Default is `"bilinear"`.
  seed: Integer. Used to create a random seed.
-
- References:
- - [AugMix paper](https://arxiv.org/pdf/1912.02781)
- - [Official Code](https://github.com/google-research/augmix)
  """

  _USE_BASE_FACTOR = False
keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py CHANGED
@@ -17,6 +17,9 @@ class AutoContrast(BaseImagePreprocessingLayer):

  This layer is active at both training and inference time.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  value_range: Range of values the incoming images will have.
  Represented as a two number tuple written `(low, high)`.
keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py CHANGED
@@ -1,13 +1,13 @@
  import math

  from keras.src.backend import config as backend_config
+ from keras.src.layers.preprocessing.data_layer import DataLayer
  from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import ( # noqa: E501
  densify_bounding_boxes,
  )
- from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer


- class BaseImagePreprocessingLayer(TFDataLayer):
+ class BaseImagePreprocessingLayer(DataLayer):
  _USE_BASE_FACTOR = True
  _FACTOR_BOUNDS = (-1, 1)

keras/src/layers/preprocessing/image_preprocessing/center_crop.py CHANGED
@@ -36,7 +36,7 @@ class CenterCrop(BaseImagePreprocessingLayer):
  If the input height/width is even and the target height/width is odd (or
  inversely), the input image is left-padded by 1 pixel.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Args:
keras/src/layers/preprocessing/image_preprocessing/cut_mix.py CHANGED
@@ -13,6 +13,12 @@ class CutMix(BaseImagePreprocessingLayer):
  between two images in the dataset, while the labels are also mixed
  proportionally to the area of the patches.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
+ References:
+ - [CutMix paper]( https://arxiv.org/abs/1905.04899).
+
  Args:
  factor: A single float or a tuple of two floats between 0 and 1.
  If a tuple of numbers is passed, a `factor` is sampled
@@ -23,9 +29,6 @@ class CutMix(BaseImagePreprocessingLayer):
  in patch sizes, leading to more diverse and larger mixed patches.
  Defaults to 1.
  seed: Integer. Used to create a random seed.
-
- References:
- - [CutMix paper]( https://arxiv.org/abs/1905.04899).
  """

  _USE_BASE_FACTOR = False
keras/src/layers/preprocessing/image_preprocessing/equalization.py CHANGED
@@ -18,7 +18,7 @@ class Equalization(BaseImagePreprocessingLayer):
  equalization independently on each color channel. At inference time,
  the equalization is consistently applied.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Args:
keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py CHANGED
@@ -8,6 +8,9 @@ from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing
  class MaxNumBoundingBoxes(BaseImagePreprocessingLayer):
  """Ensure the maximum number of bounding boxes.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  max_number: Desired output number of bounding boxes.
  padding_value: The padding value of the `boxes` and `labels` in
keras/src/layers/preprocessing/image_preprocessing/mix_up.py CHANGED
@@ -11,6 +11,13 @@ from keras.src.utils import backend_utils
  class MixUp(BaseImagePreprocessingLayer):
  """MixUp implements the MixUp data augmentation technique.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
+ References:
+ - [MixUp paper](https://arxiv.org/abs/1710.09412).
+ - [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103).
+
  Args:
  alpha: Float between 0 and 1. Controls the blending strength.
  Smaller values mean less mixing, while larger values allow
@@ -18,10 +25,6 @@ class MixUp(BaseImagePreprocessingLayer):
  recommended for ImageNet1k classification.
  seed: Integer. Used to create a random seed.

- References:
- - [MixUp paper](https://arxiv.org/abs/1710.09412).
- - [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103).
-
  Example:
  ```python
  (images, labels), _ = keras.datasets.cifar10.load_data()
keras/src/layers/preprocessing/image_preprocessing/rand_augment.py CHANGED
@@ -15,6 +15,9 @@ class RandAugment(BaseImagePreprocessingLayer):
  policy implemented by this layer has been benchmarked extensively and is
  effective on a wide variety of datasets.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  References:
  - [RandAugment](https://arxiv.org/abs/1909.13719)

@@ -29,7 +32,6 @@ class RandAugment(BaseImagePreprocessingLayer):
  interpolation: The interpolation method to use for resizing operations.
  Options include `nearest`, `bilinear`. Default is `bilinear`.
  seed: Integer. Used to create a random seed.
-
  """

  _USE_BASE_FACTOR = False
keras/src/layers/preprocessing/image_preprocessing/random_brightness.py CHANGED
@@ -13,7 +13,7 @@ class RandomBrightness(BaseImagePreprocessingLayer):
  images. At inference time, the output will be identical to the input.
  Call the layer with `training=True` to adjust the brightness of the input.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Args:
keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py CHANGED
@@ -13,6 +13,9 @@ class RandomColorDegeneration(BaseImagePreprocessingLayer):
  color. It then takes a weighted average between original image and the
  degenerated image. This makes colors appear more dull.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A tuple of two floats or a single float.
  `factor` controls the extent to which the
keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py CHANGED
@@ -16,6 +16,9 @@ class RandomColorJitter(BaseImagePreprocessingLayer):
  and hue image processing operation sequentially and randomly on the
  input.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  value_range: the range of values the incoming images will have.
  Represented as a two number tuple written [low, high].
keras/src/layers/preprocessing/image_preprocessing/random_contrast.py CHANGED
@@ -21,7 +21,7 @@ class RandomContrast(BaseImagePreprocessingLayer):
  in integer or floating point dtype.
  By default, the layer will output floats.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Input shape:
keras/src/layers/preprocessing/image_preprocessing/random_crop.py CHANGED
@@ -30,7 +30,7 @@ class RandomCrop(BaseImagePreprocessingLayer):
  of integer or floating point dtype. By default, the layer will output
  floats.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Input shape:
keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py CHANGED
@@ -14,6 +14,9 @@ class RandomElasticTransform(BaseImagePreprocessingLayer):
  distortion is controlled by the `scale` parameter, while the `factor`
  determines the probability of applying the transformation.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A single float or a tuple of two floats.
  `factor` controls the probability of applying the transformation.
keras/src/layers/preprocessing/image_preprocessing/random_erasing.py CHANGED
@@ -13,6 +13,12 @@ class RandomErasing(BaseImagePreprocessingLayer):
  an image are erased (replaced by a constant value or noise)
  during training to improve generalization.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
+ References:
+ - [Random Erasing paper](https://arxiv.org/abs/1708.04896).
+
  Args:
  factor: A single float or a tuple of two floats.
  `factor` controls the probability of applying the transformation.
@@ -35,9 +41,6 @@ class RandomErasing(BaseImagePreprocessingLayer):
  typically either `[0, 1]` or `[0, 255]` depending on how your
  preprocessing pipeline is set up.
  seed: Integer. Used to create a random seed.
-
- References:
- - [Random Erasing paper](https://arxiv.org/abs/1708.04896).
  """

  _USE_BASE_FACTOR = False
keras/src/layers/preprocessing/image_preprocessing/random_flip.py CHANGED
@@ -27,7 +27,7 @@ class RandomFlip(BaseImagePreprocessingLayer):
  of integer or floating point dtype.
  By default, the layer will output floats.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Input shape:
keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py CHANGED
@@ -13,6 +13,9 @@ class RandomGaussianBlur(BaseImagePreprocessingLayer):
  randomly selected degree of blurring, controlled by the `factor` and
  `sigma` arguments.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A single float or a tuple of two floats.
  `factor` controls the extent to which the image hue is impacted.
keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py CHANGED
@@ -19,7 +19,7 @@ class RandomGrayscale(BaseImagePreprocessingLayer):
  image using standard RGB to grayscale conversion coefficients. Images
  that are not selected for conversion remain unchanged.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Args:
keras/src/layers/preprocessing/image_preprocessing/random_hue.py CHANGED
@@ -14,6 +14,9 @@ class RandomHue(BaseImagePreprocessingLayer):
  The image hue is adjusted by converting the image(s) to HSV and rotating the
  hue channel (H) by delta. The image is then converted back to RGB.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A single float or a tuple of two floats.
  `factor` controls the extent to which the
keras/src/layers/preprocessing/image_preprocessing/random_invert.py CHANGED
@@ -14,6 +14,9 @@ class RandomInvert(BaseImagePreprocessingLayer):
  complementary values. Images that are not selected for inversion
  remain unchanged.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A single float or a tuple of two floats.
  `factor` controls the probability of inverting the image colors.
keras/src/layers/preprocessing/image_preprocessing/random_perspective.py CHANGED
@@ -20,6 +20,9 @@ class RandomPerspective(BaseImagePreprocessingLayer):
  corner points, simulating a 3D-like transformation. The amount of distortion
  is controlled by the `factor` and `scale` parameters.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  Args:
  factor: A float or a tuple of two floats.
  Represents the probability of applying the perspective
keras/src/layers/preprocessing/image_preprocessing/random_posterization.py CHANGED
@@ -8,6 +8,9 @@ from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing
  class RandomPosterization(BaseImagePreprocessingLayer):
  """Reduces the number of bits for each color channel.

+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
+ (independently of which backend you're using).
+
  References:
  - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
  - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
keras/src/layers/preprocessing/image_preprocessing/random_rotation.py CHANGED
@@ -23,7 +23,7 @@ class RandomRotation(BaseImagePreprocessingLayer):
  of integer or floating point dtype.
  By default, the layer will output floats.

- **Note:** This layer is safe to use inside a `tf.data` pipeline
+ **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
  (independently of which backend you're using).

  Input shape: