keras-nightly 3.12.0.dev2025082003__py3-none-any.whl → 3.12.0.dev2025082203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/applications/convnext.py +20 -20
- keras/src/applications/densenet.py +21 -21
- keras/src/applications/efficientnet.py +16 -16
- keras/src/applications/efficientnet_v2.py +28 -28
- keras/src/applications/inception_resnet_v2.py +7 -7
- keras/src/applications/inception_v3.py +5 -5
- keras/src/applications/mobilenet_v2.py +13 -20
- keras/src/applications/mobilenet_v3.py +15 -15
- keras/src/applications/nasnet.py +7 -8
- keras/src/applications/resnet.py +32 -32
- keras/src/applications/xception.py +10 -10
- keras/src/backend/common/dtypes.py +3 -3
- keras/src/backend/common/variables.py +3 -1
- keras/src/backend/jax/export.py +1 -1
- keras/src/backend/jax/trainer.py +1 -1
- keras/src/backend/openvino/numpy.py +1 -1
- keras/src/backend/tensorflow/rnn.py +1 -1
- keras/src/backend/tensorflow/trainer.py +19 -1
- keras/src/backend/torch/core.py +6 -9
- keras/src/backend/torch/trainer.py +1 -1
- keras/src/callbacks/backup_and_restore.py +2 -2
- keras/src/callbacks/csv_logger.py +1 -1
- keras/src/callbacks/model_checkpoint.py +1 -1
- keras/src/callbacks/tensorboard.py +6 -6
- keras/src/datasets/boston_housing.py +1 -1
- keras/src/datasets/california_housing.py +1 -1
- keras/src/datasets/cifar10.py +1 -1
- keras/src/datasets/cifar100.py +2 -2
- keras/src/datasets/imdb.py +2 -2
- keras/src/datasets/mnist.py +1 -1
- keras/src/datasets/reuters.py +2 -2
- keras/src/dtype_policies/dtype_policy.py +1 -1
- keras/src/dtype_policies/dtype_policy_map.py +1 -1
- keras/src/export/tf2onnx_lib.py +1 -3
- keras/src/layers/attention/attention.py +2 -0
- keras/src/layers/core/lambda_layer.py +9 -8
- keras/src/layers/input_spec.py +6 -6
- keras/src/layers/layer.py +1 -1
- keras/src/layers/preprocessing/category_encoding.py +3 -3
- keras/src/layers/preprocessing/data_layer.py +159 -0
- keras/src/layers/preprocessing/discretization.py +3 -3
- keras/src/layers/preprocessing/feature_space.py +4 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
- keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
- keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
- keras/src/layers/preprocessing/normalization.py +5 -2
- keras/src/layers/preprocessing/rescaling.py +3 -3
- keras/src/layers/rnn/bidirectional.py +4 -4
- keras/src/legacy/backend.py +9 -23
- keras/src/legacy/preprocessing/image.py +11 -22
- keras/src/legacy/preprocessing/text.py +1 -1
- keras/src/legacy/saving/legacy_h5_format.py +7 -2
- keras/src/legacy/saving/saving_utils.py +0 -12
- keras/src/legacy/saving/serialization.py +0 -14
- keras/src/models/functional.py +2 -2
- keras/src/models/model.py +21 -3
- keras/src/ops/function.py +1 -1
- keras/src/ops/numpy.py +5 -5
- keras/src/ops/operation.py +3 -2
- keras/src/optimizers/base_optimizer.py +3 -4
- keras/src/quantizers/gptq.py +350 -0
- keras/src/quantizers/gptq_config.py +169 -0
- keras/src/quantizers/gptq_core.py +335 -0
- keras/src/quantizers/gptq_quant.py +133 -0
- keras/src/saving/file_editor.py +22 -20
- keras/src/saving/object_registration.py +1 -1
- keras/src/saving/saving_api.py +4 -1
- keras/src/saving/saving_lib.py +4 -4
- keras/src/saving/serialization_lib.py +9 -11
- keras/src/trainers/compile_utils.py +1 -1
- keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
- keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
- keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
- keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
- keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
- keras/src/tree/dmtree_impl.py +19 -3
- keras/src/tree/optree_impl.py +3 -3
- keras/src/tree/tree_api.py +5 -2
- keras/src/utils/file_utils.py +13 -5
- keras/src/utils/io_utils.py +1 -1
- keras/src/utils/model_visualization.py +1 -1
- keras/src/utils/progbar.py +5 -5
- keras/src/utils/summary_utils.py +4 -4
- keras/src/utils/torch_utils.py +4 -4
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/RECORD +121 -117
- keras/src/layers/preprocessing/tf_data_layer.py +0 -78
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/top_level.txt +0 -0
keras/src/datasets/reuters.py
CHANGED
@@ -87,7 +87,7 @@ def load_data(
|
|
87
87
|
)
|
88
88
|
path = get_file(
|
89
89
|
fname=path,
|
90
|
-
origin=origin_folder
|
90
|
+
origin=f"{origin_folder}reuters.npz",
|
91
91
|
file_hash=( # noqa: E501
|
92
92
|
"d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
|
93
93
|
),
|
@@ -156,7 +156,7 @@ def get_word_index(path="reuters_word_index.json"):
|
|
156
156
|
)
|
157
157
|
path = get_file(
|
158
158
|
path,
|
159
|
-
origin=origin_folder
|
159
|
+
origin=f"{origin_folder}reuters_word_index.json",
|
160
160
|
file_hash="4d44cc38712099c9e383dc6e5f11a921",
|
161
161
|
)
|
162
162
|
with open(path) as f:
|
keras/src/export/tf2onnx_lib.py
CHANGED
@@ -157,9 +157,7 @@ def patch_tf2onnx():
|
|
157
157
|
):
|
158
158
|
a = copy.deepcopy(a)
|
159
159
|
tensor_name = (
|
160
|
-
self.name.strip()
|
161
|
-
+ "_"
|
162
|
-
+ str(external_tensor_storage.name_counter)
|
160
|
+
f"{self.name.strip()}_{external_tensor_storage.name_counter}"
|
163
161
|
)
|
164
162
|
for c in '~"#%&*:<>?/\\{|}':
|
165
163
|
tensor_name = tensor_name.replace(c, "_")
|
@@ -176,6 +176,8 @@ class Attention(Layer):
|
|
176
176
|
# Bias so padding positions do not contribute to attention
|
177
177
|
# distribution. Note 65504. is the max float16 value.
|
178
178
|
max_value = 65504.0 if scores.dtype == "float16" else 1.0e9
|
179
|
+
if len(padding_mask.shape) == 2:
|
180
|
+
padding_mask = ops.expand_dims(padding_mask, axis=-2)
|
179
181
|
scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype)
|
180
182
|
|
181
183
|
weights = ops.softmax(scores, axis=-1)
|
@@ -167,14 +167,15 @@ class Lambda(Layer):
|
|
167
167
|
)
|
168
168
|
|
169
169
|
@staticmethod
|
170
|
-
def _raise_for_lambda_deserialization(
|
170
|
+
def _raise_for_lambda_deserialization(safe_mode):
|
171
171
|
if safe_mode:
|
172
172
|
raise ValueError(
|
173
|
-
|
174
|
-
"
|
175
|
-
"
|
176
|
-
"
|
177
|
-
"
|
173
|
+
"Requested the deserialization of a `Lambda` layer whose "
|
174
|
+
"`function` is a Python lambda. This carries a potential risk "
|
175
|
+
"of arbitrary code execution and thus it is disallowed by "
|
176
|
+
"default. If you trust the source of the artifact, you can "
|
177
|
+
"override this error by passing `safe_mode=False` to the "
|
178
|
+
"loading function, or calling "
|
178
179
|
"`keras.config.enable_unsafe_deserialization()."
|
179
180
|
)
|
180
181
|
|
@@ -187,7 +188,7 @@ class Lambda(Layer):
|
|
187
188
|
and "class_name" in fn_config
|
188
189
|
and fn_config["class_name"] == "__lambda__"
|
189
190
|
):
|
190
|
-
cls._raise_for_lambda_deserialization(
|
191
|
+
cls._raise_for_lambda_deserialization(safe_mode)
|
191
192
|
inner_config = fn_config["config"]
|
192
193
|
fn = python_utils.func_load(
|
193
194
|
inner_config["code"],
|
@@ -206,7 +207,7 @@ class Lambda(Layer):
|
|
206
207
|
and "class_name" in fn_config
|
207
208
|
and fn_config["class_name"] == "__lambda__"
|
208
209
|
):
|
209
|
-
cls._raise_for_lambda_deserialization(
|
210
|
+
cls._raise_for_lambda_deserialization(safe_mode)
|
210
211
|
inner_config = fn_config["config"]
|
211
212
|
fn = python_utils.func_load(
|
212
213
|
inner_config["code"],
|
keras/src/layers/input_spec.py
CHANGED
@@ -94,12 +94,12 @@ class InputSpec:
|
|
94
94
|
|
95
95
|
def __repr__(self):
|
96
96
|
spec = [
|
97
|
-
("dtype=
|
98
|
-
("shape=
|
99
|
-
("ndim=
|
100
|
-
("max_ndim=
|
101
|
-
("min_ndim=
|
102
|
-
("axes=
|
97
|
+
(f"dtype={str(self.dtype)}") if self.dtype else "",
|
98
|
+
(f"shape={str(self.shape)}") if self.shape else "",
|
99
|
+
(f"ndim={str(self.ndim)}") if self.ndim else "",
|
100
|
+
(f"max_ndim={str(self.max_ndim)}") if self.max_ndim else "",
|
101
|
+
(f"min_ndim={str(self.min_ndim)}") if self.min_ndim else "",
|
102
|
+
(f"axes={str(self.axes)}") if self.axes else "",
|
103
103
|
]
|
104
104
|
return f"InputSpec({', '.join(x for x in spec if x)})"
|
105
105
|
|
keras/src/layers/layer.py
CHANGED
@@ -1337,7 +1337,7 @@ class Layer(BackendLayer, Operation):
|
|
1337
1337
|
else:
|
1338
1338
|
attr_name = str(attr)
|
1339
1339
|
attr_type = "attribute"
|
1340
|
-
msg = " "
|
1340
|
+
msg = f" {msg}" if msg is not None else ""
|
1341
1341
|
return NotImplementedError(
|
1342
1342
|
f"Layer {self.__class__.__name__} does not have a `{attr_name}` "
|
1343
1343
|
f"{attr_type} implemented.{msg}"
|
@@ -1,12 +1,12 @@
|
|
1
1
|
from keras.src.api_export import keras_export
|
2
2
|
from keras.src.backend import KerasTensor
|
3
|
-
from keras.src.layers.preprocessing.
|
3
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
4
4
|
from keras.src.utils import backend_utils
|
5
5
|
from keras.src.utils import numerical_utils
|
6
6
|
|
7
7
|
|
8
8
|
@keras_export("keras.layers.CategoryEncoding")
|
9
|
-
class CategoryEncoding(
|
9
|
+
class CategoryEncoding(DataLayer):
|
10
10
|
"""A preprocessing layer which encodes integer features.
|
11
11
|
|
12
12
|
This layer provides options for condensing data into a categorical encoding
|
@@ -15,7 +15,7 @@ class CategoryEncoding(TFDataLayer):
|
|
15
15
|
inputs. For integer inputs where the total number of tokens is not known,
|
16
16
|
use `keras.layers.IntegerLookup` instead.
|
17
17
|
|
18
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
18
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
19
19
|
(independently of which backend you're using).
|
20
20
|
|
21
21
|
Examples:
|
@@ -0,0 +1,159 @@
|
|
1
|
+
import keras.src.backend
|
2
|
+
from keras.src import tree
|
3
|
+
from keras.src.layers.layer import Layer
|
4
|
+
from keras.src.random.seed_generator import SeedGenerator
|
5
|
+
from keras.src.utils import backend_utils
|
6
|
+
from keras.src.utils import jax_utils
|
7
|
+
from keras.src.utils import tracking
|
8
|
+
|
9
|
+
|
10
|
+
class DataLayer(Layer):
|
11
|
+
"""Layer designed for safe use in `tf.data` or `grain` pipeline.
|
12
|
+
|
13
|
+
This layer overrides the `__call__` method to ensure that the correct
|
14
|
+
backend is used and that computation is performed on the CPU.
|
15
|
+
|
16
|
+
The `call()` method in subclasses should use `self.backend` ops. If
|
17
|
+
randomness is needed, define both `seed` and `generator` in `__init__` and
|
18
|
+
retrieve the running seed using `self._get_seed_generator()`. If the layer
|
19
|
+
has weights in `__init__` or `build()`, use `convert_weight()` to ensure
|
20
|
+
they are in the correct backend.
|
21
|
+
|
22
|
+
**Note:** This layer and its subclasses only support a single input tensor.
|
23
|
+
|
24
|
+
Examples:
|
25
|
+
|
26
|
+
**Custom `DataLayer` subclass:**
|
27
|
+
|
28
|
+
```python
|
29
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
30
|
+
from keras.src.random import SeedGenerator
|
31
|
+
|
32
|
+
|
33
|
+
class BiasedRandomRGBToHSVLayer(DataLayer):
|
34
|
+
def __init__(self, seed=None, **kwargs):
|
35
|
+
super().__init__(**kwargs)
|
36
|
+
self.probability_bias = ops.convert_to_tensor(0.01)
|
37
|
+
self.seed = seed
|
38
|
+
self.generator = SeedGenerator(seed)
|
39
|
+
|
40
|
+
def call(self, inputs):
|
41
|
+
images_shape = self.backend.shape(inputs)
|
42
|
+
batch_size = 1 if len(images_shape) == 3 else images_shape[0]
|
43
|
+
seed = self._get_seed_generator(self.backend._backend)
|
44
|
+
|
45
|
+
probability = self.backend.random.uniform(
|
46
|
+
shape=(batch_size,),
|
47
|
+
minval=0.0,
|
48
|
+
maxval=1.0,
|
49
|
+
seed=seed,
|
50
|
+
)
|
51
|
+
probability = self.backend.numpy.add(
|
52
|
+
probability, self.convert_weight(self.probability_bias)
|
53
|
+
)
|
54
|
+
hsv_images = self.backend.image.rgb_to_hsv(inputs)
|
55
|
+
return self.backend.numpy.where(
|
56
|
+
probability[:, None, None, None] > 0.5,
|
57
|
+
hsv_images,
|
58
|
+
inputs,
|
59
|
+
)
|
60
|
+
|
61
|
+
def compute_output_shape(self, input_shape):
|
62
|
+
return input_shape
|
63
|
+
```
|
64
|
+
|
65
|
+
**Using as a regular Keras layer:**
|
66
|
+
|
67
|
+
```python
|
68
|
+
import numpy as np
|
69
|
+
|
70
|
+
x = np.random.uniform(size=(1, 16, 16, 3)).astype("float32")
|
71
|
+
print(BiasedRandomRGBToHSVLayer()(x).shape) # (1, 16, 16, 3)
|
72
|
+
```
|
73
|
+
|
74
|
+
**Using in a `tf.data` pipeline:**
|
75
|
+
|
76
|
+
```python
|
77
|
+
import tensorflow as tf
|
78
|
+
|
79
|
+
tf_ds = tf.data.Dataset.from_tensors(x)
|
80
|
+
tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer())
|
81
|
+
print([x.shape for x in tf_ds]) # [(1, 16, 16, 3)]
|
82
|
+
```
|
83
|
+
|
84
|
+
**Using in a `grain` pipeline:**
|
85
|
+
|
86
|
+
```python
|
87
|
+
import grain
|
88
|
+
|
89
|
+
grain_ds = grain.MapDataset.source([x])
|
90
|
+
grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer())
|
91
|
+
print([x.shape for x in grain_ds]) # [(1, 16, 16, 3)]
|
92
|
+
"""
|
93
|
+
|
94
|
+
def __init__(self, **kwargs):
|
95
|
+
super().__init__(**kwargs)
|
96
|
+
self.backend = backend_utils.DynamicBackend()
|
97
|
+
self._allow_non_tensor_positional_args = True
|
98
|
+
|
99
|
+
def __call__(self, inputs, **kwargs):
|
100
|
+
sample_input = tree.flatten(inputs)[0]
|
101
|
+
if (
|
102
|
+
not isinstance(sample_input, keras.KerasTensor)
|
103
|
+
and backend_utils.in_tf_graph()
|
104
|
+
and not jax_utils.is_in_jax_tracing_scope(sample_input)
|
105
|
+
):
|
106
|
+
# We're in a TF graph, e.g. a tf.data pipeline.
|
107
|
+
self.backend.set_backend("tensorflow")
|
108
|
+
inputs = tree.map_structure(
|
109
|
+
lambda x: self.backend.convert_to_tensor(
|
110
|
+
x, dtype=self.compute_dtype
|
111
|
+
),
|
112
|
+
inputs,
|
113
|
+
)
|
114
|
+
switch_convert_input_args = False
|
115
|
+
if self._convert_input_args:
|
116
|
+
self._convert_input_args = False
|
117
|
+
switch_convert_input_args = True
|
118
|
+
try:
|
119
|
+
outputs = super().__call__(inputs, **kwargs)
|
120
|
+
finally:
|
121
|
+
self.backend.reset()
|
122
|
+
if switch_convert_input_args:
|
123
|
+
self._convert_input_args = True
|
124
|
+
return outputs
|
125
|
+
elif (
|
126
|
+
not isinstance(sample_input, keras.KerasTensor)
|
127
|
+
and backend_utils.in_grain_data_pipeline()
|
128
|
+
):
|
129
|
+
# We're in a Grain data pipeline. Force computation and data
|
130
|
+
# placement to CPU.
|
131
|
+
with keras.src.backend.device_scope("cpu"):
|
132
|
+
return super().__call__(inputs, **kwargs)
|
133
|
+
else:
|
134
|
+
return super().__call__(inputs, **kwargs)
|
135
|
+
|
136
|
+
@tracking.no_automatic_dependency_tracking
|
137
|
+
def _get_seed_generator(self, backend=None):
|
138
|
+
if not hasattr(self, "seed") or not hasattr(self, "generator"):
|
139
|
+
raise ValueError(
|
140
|
+
"The `seed` and `generator` variable must be set in the "
|
141
|
+
"`__init__` method before calling `_get_seed_generator()`."
|
142
|
+
)
|
143
|
+
if backend is None or backend == keras.backend.backend():
|
144
|
+
return self.generator
|
145
|
+
if not hasattr(self, "_backend_generators"):
|
146
|
+
self._backend_generators = {}
|
147
|
+
if backend in self._backend_generators:
|
148
|
+
return self._backend_generators[backend]
|
149
|
+
seed_generator = SeedGenerator(self.seed, backend=self.backend)
|
150
|
+
self._backend_generators[backend] = seed_generator
|
151
|
+
return seed_generator
|
152
|
+
|
153
|
+
def convert_weight(self, weight):
|
154
|
+
"""Convert the weight if it is from the a different backend."""
|
155
|
+
if self.backend.name == keras.backend.backend():
|
156
|
+
return weight
|
157
|
+
else:
|
158
|
+
weight = keras.ops.convert_to_numpy(weight)
|
159
|
+
return self.backend.convert_to_tensor(weight)
|
@@ -2,21 +2,21 @@ import numpy as np
|
|
2
2
|
|
3
3
|
from keras.src import backend
|
4
4
|
from keras.src.api_export import keras_export
|
5
|
-
from keras.src.layers.preprocessing.
|
5
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
6
6
|
from keras.src.utils import argument_validation
|
7
7
|
from keras.src.utils import numerical_utils
|
8
8
|
from keras.src.utils.module_utils import tensorflow as tf
|
9
9
|
|
10
10
|
|
11
11
|
@keras_export("keras.layers.Discretization")
|
12
|
-
class Discretization(
|
12
|
+
class Discretization(DataLayer):
|
13
13
|
"""A preprocessing layer which buckets continuous features by ranges.
|
14
14
|
|
15
15
|
This layer will place each element of its input data into one of several
|
16
16
|
contiguous ranges and output an integer index indicating which range each
|
17
17
|
element was placed in.
|
18
18
|
|
19
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
19
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
20
20
|
(independently of which backend you're using).
|
21
21
|
|
22
22
|
Input shape:
|
@@ -3,7 +3,7 @@ from keras.src import layers
|
|
3
3
|
from keras.src import tree
|
4
4
|
from keras.src.api_export import keras_export
|
5
5
|
from keras.src.layers.layer import Layer
|
6
|
-
from keras.src.layers.preprocessing.
|
6
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
7
7
|
from keras.src.saving import saving_lib
|
8
8
|
from keras.src.saving import serialization_lib
|
9
9
|
from keras.src.saving.keras_saveable import KerasSaveable
|
@@ -723,7 +723,7 @@ class FeatureSpace(Layer):
|
|
723
723
|
data[name] = tf.expand_dims(x, -1)
|
724
724
|
|
725
725
|
with backend_utils.TFGraphScope():
|
726
|
-
# This scope is to make sure that inner
|
726
|
+
# This scope is to make sure that inner DataLayers
|
727
727
|
# will not convert outputs back to backend-native --
|
728
728
|
# they should be TF tensors throughout
|
729
729
|
preprocessed_data = self._preprocess_features(data)
|
@@ -808,7 +808,7 @@ class FeatureSpace(Layer):
|
|
808
808
|
return
|
809
809
|
|
810
810
|
|
811
|
-
class TFDConcat(
|
811
|
+
class TFDConcat(DataLayer):
|
812
812
|
def __init__(self, axis, **kwargs):
|
813
813
|
super().__init__(**kwargs)
|
814
814
|
self.axis = axis
|
@@ -817,6 +817,6 @@ class TFDConcat(TFDataLayer):
|
|
817
817
|
return self.backend.numpy.concatenate(xs, axis=self.axis)
|
818
818
|
|
819
819
|
|
820
|
-
class TFDIdentity(
|
820
|
+
class TFDIdentity(DataLayer):
|
821
821
|
def call(self, x):
|
822
822
|
return x
|
@@ -43,6 +43,13 @@ class AugMix(BaseImagePreprocessingLayer):
|
|
43
43
|
in num_chains different ways, with each chain consisting of
|
44
44
|
chain_depth augmentations.
|
45
45
|
|
46
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
47
|
+
(independently of which backend you're using).
|
48
|
+
|
49
|
+
References:
|
50
|
+
- [AugMix paper](https://arxiv.org/pdf/1912.02781)
|
51
|
+
- [Official Code](https://github.com/google-research/augmix)
|
52
|
+
|
46
53
|
Args:
|
47
54
|
value_range: the range of values the incoming images will have.
|
48
55
|
Represented as a two number tuple written (low, high).
|
@@ -64,10 +71,6 @@ class AugMix(BaseImagePreprocessingLayer):
|
|
64
71
|
interpolation: The interpolation method to use for resizing operations.
|
65
72
|
Options include `"nearest"`, `"bilinear"`. Default is `"bilinear"`.
|
66
73
|
seed: Integer. Used to create a random seed.
|
67
|
-
|
68
|
-
References:
|
69
|
-
- [AugMix paper](https://arxiv.org/pdf/1912.02781)
|
70
|
-
- [Official Code](https://github.com/google-research/augmix)
|
71
74
|
"""
|
72
75
|
|
73
76
|
_USE_BASE_FACTOR = False
|
@@ -17,6 +17,9 @@ class AutoContrast(BaseImagePreprocessingLayer):
|
|
17
17
|
|
18
18
|
This layer is active at both training and inference time.
|
19
19
|
|
20
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
21
|
+
(independently of which backend you're using).
|
22
|
+
|
20
23
|
Args:
|
21
24
|
value_range: Range of values the incoming images will have.
|
22
25
|
Represented as a two number tuple written `(low, high)`.
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import math
|
2
2
|
|
3
3
|
from keras.src.backend import config as backend_config
|
4
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
4
5
|
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import ( # noqa: E501
|
5
6
|
densify_bounding_boxes,
|
6
7
|
)
|
7
|
-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
|
8
8
|
|
9
9
|
|
10
|
-
class BaseImagePreprocessingLayer(
|
10
|
+
class BaseImagePreprocessingLayer(DataLayer):
|
11
11
|
_USE_BASE_FACTOR = True
|
12
12
|
_FACTOR_BOUNDS = (-1, 1)
|
13
13
|
|
@@ -36,7 +36,7 @@ class CenterCrop(BaseImagePreprocessingLayer):
|
|
36
36
|
If the input height/width is even and the target height/width is odd (or
|
37
37
|
inversely), the input image is left-padded by 1 pixel.
|
38
38
|
|
39
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
39
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
40
40
|
(independently of which backend you're using).
|
41
41
|
|
42
42
|
Args:
|
@@ -13,6 +13,12 @@ class CutMix(BaseImagePreprocessingLayer):
|
|
13
13
|
between two images in the dataset, while the labels are also mixed
|
14
14
|
proportionally to the area of the patches.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
19
|
+
References:
|
20
|
+
- [CutMix paper]( https://arxiv.org/abs/1905.04899).
|
21
|
+
|
16
22
|
Args:
|
17
23
|
factor: A single float or a tuple of two floats between 0 and 1.
|
18
24
|
If a tuple of numbers is passed, a `factor` is sampled
|
@@ -23,9 +29,6 @@ class CutMix(BaseImagePreprocessingLayer):
|
|
23
29
|
in patch sizes, leading to more diverse and larger mixed patches.
|
24
30
|
Defaults to 1.
|
25
31
|
seed: Integer. Used to create a random seed.
|
26
|
-
|
27
|
-
References:
|
28
|
-
- [CutMix paper]( https://arxiv.org/abs/1905.04899).
|
29
32
|
"""
|
30
33
|
|
31
34
|
_USE_BASE_FACTOR = False
|
@@ -18,7 +18,7 @@ class Equalization(BaseImagePreprocessingLayer):
|
|
18
18
|
equalization independently on each color channel. At inference time,
|
19
19
|
the equalization is consistently applied.
|
20
20
|
|
21
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
21
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
22
22
|
(independently of which backend you're using).
|
23
23
|
|
24
24
|
Args:
|
@@ -8,6 +8,9 @@ from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing
|
|
8
8
|
class MaxNumBoundingBoxes(BaseImagePreprocessingLayer):
|
9
9
|
"""Ensure the maximum number of bounding boxes.
|
10
10
|
|
11
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
12
|
+
(independently of which backend you're using).
|
13
|
+
|
11
14
|
Args:
|
12
15
|
max_number: Desired output number of bounding boxes.
|
13
16
|
padding_value: The padding value of the `boxes` and `labels` in
|
@@ -11,6 +11,13 @@ from keras.src.utils import backend_utils
|
|
11
11
|
class MixUp(BaseImagePreprocessingLayer):
|
12
12
|
"""MixUp implements the MixUp data augmentation technique.
|
13
13
|
|
14
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
15
|
+
(independently of which backend you're using).
|
16
|
+
|
17
|
+
References:
|
18
|
+
- [MixUp paper](https://arxiv.org/abs/1710.09412).
|
19
|
+
- [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103).
|
20
|
+
|
14
21
|
Args:
|
15
22
|
alpha: Float between 0 and 1. Controls the blending strength.
|
16
23
|
Smaller values mean less mixing, while larger values allow
|
@@ -18,10 +25,6 @@ class MixUp(BaseImagePreprocessingLayer):
|
|
18
25
|
recommended for ImageNet1k classification.
|
19
26
|
seed: Integer. Used to create a random seed.
|
20
27
|
|
21
|
-
References:
|
22
|
-
- [MixUp paper](https://arxiv.org/abs/1710.09412).
|
23
|
-
- [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103).
|
24
|
-
|
25
28
|
Example:
|
26
29
|
```python
|
27
30
|
(images, labels), _ = keras.datasets.cifar10.load_data()
|
@@ -15,6 +15,9 @@ class RandAugment(BaseImagePreprocessingLayer):
|
|
15
15
|
policy implemented by this layer has been benchmarked extensively and is
|
16
16
|
effective on a wide variety of datasets.
|
17
17
|
|
18
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
19
|
+
(independently of which backend you're using).
|
20
|
+
|
18
21
|
References:
|
19
22
|
- [RandAugment](https://arxiv.org/abs/1909.13719)
|
20
23
|
|
@@ -29,7 +32,6 @@ class RandAugment(BaseImagePreprocessingLayer):
|
|
29
32
|
interpolation: The interpolation method to use for resizing operations.
|
30
33
|
Options include `nearest`, `bilinear`. Default is `bilinear`.
|
31
34
|
seed: Integer. Used to create a random seed.
|
32
|
-
|
33
35
|
"""
|
34
36
|
|
35
37
|
_USE_BASE_FACTOR = False
|
@@ -13,7 +13,7 @@ class RandomBrightness(BaseImagePreprocessingLayer):
|
|
13
13
|
images. At inference time, the output will be identical to the input.
|
14
14
|
Call the layer with `training=True` to adjust the brightness of the input.
|
15
15
|
|
16
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
17
|
(independently of which backend you're using).
|
18
18
|
|
19
19
|
Args:
|
@@ -13,6 +13,9 @@ class RandomColorDegeneration(BaseImagePreprocessingLayer):
|
|
13
13
|
color. It then takes a weighted average between original image and the
|
14
14
|
degenerated image. This makes colors appear more dull.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
16
19
|
Args:
|
17
20
|
factor: A tuple of two floats or a single float.
|
18
21
|
`factor` controls the extent to which the
|
@@ -16,6 +16,9 @@ class RandomColorJitter(BaseImagePreprocessingLayer):
|
|
16
16
|
and hue image processing operation sequentially and randomly on the
|
17
17
|
input.
|
18
18
|
|
19
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
20
|
+
(independently of which backend you're using).
|
21
|
+
|
19
22
|
Args:
|
20
23
|
value_range: the range of values the incoming images will have.
|
21
24
|
Represented as a two number tuple written [low, high].
|
@@ -21,7 +21,7 @@ class RandomContrast(BaseImagePreprocessingLayer):
|
|
21
21
|
in integer or floating point dtype.
|
22
22
|
By default, the layer will output floats.
|
23
23
|
|
24
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
24
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
25
25
|
(independently of which backend you're using).
|
26
26
|
|
27
27
|
Input shape:
|
@@ -30,7 +30,7 @@ class RandomCrop(BaseImagePreprocessingLayer):
|
|
30
30
|
of integer or floating point dtype. By default, the layer will output
|
31
31
|
floats.
|
32
32
|
|
33
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
33
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
34
34
|
(independently of which backend you're using).
|
35
35
|
|
36
36
|
Input shape:
|
@@ -14,6 +14,9 @@ class RandomElasticTransform(BaseImagePreprocessingLayer):
|
|
14
14
|
distortion is controlled by the `scale` parameter, while the `factor`
|
15
15
|
determines the probability of applying the transformation.
|
16
16
|
|
17
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
18
|
+
(independently of which backend you're using).
|
19
|
+
|
17
20
|
Args:
|
18
21
|
factor: A single float or a tuple of two floats.
|
19
22
|
`factor` controls the probability of applying the transformation.
|
@@ -13,6 +13,12 @@ class RandomErasing(BaseImagePreprocessingLayer):
|
|
13
13
|
an image are erased (replaced by a constant value or noise)
|
14
14
|
during training to improve generalization.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
19
|
+
References:
|
20
|
+
- [Random Erasing paper](https://arxiv.org/abs/1708.04896).
|
21
|
+
|
16
22
|
Args:
|
17
23
|
factor: A single float or a tuple of two floats.
|
18
24
|
`factor` controls the probability of applying the transformation.
|
@@ -35,9 +41,6 @@ class RandomErasing(BaseImagePreprocessingLayer):
|
|
35
41
|
typically either `[0, 1]` or `[0, 255]` depending on how your
|
36
42
|
preprocessing pipeline is set up.
|
37
43
|
seed: Integer. Used to create a random seed.
|
38
|
-
|
39
|
-
References:
|
40
|
-
- [Random Erasing paper](https://arxiv.org/abs/1708.04896).
|
41
44
|
"""
|
42
45
|
|
43
46
|
_USE_BASE_FACTOR = False
|
@@ -27,7 +27,7 @@ class RandomFlip(BaseImagePreprocessingLayer):
|
|
27
27
|
of integer or floating point dtype.
|
28
28
|
By default, the layer will output floats.
|
29
29
|
|
30
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
30
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
31
31
|
(independently of which backend you're using).
|
32
32
|
|
33
33
|
Input shape:
|
@@ -13,6 +13,9 @@ class RandomGaussianBlur(BaseImagePreprocessingLayer):
|
|
13
13
|
randomly selected degree of blurring, controlled by the `factor` and
|
14
14
|
`sigma` arguments.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
16
19
|
Args:
|
17
20
|
factor: A single float or a tuple of two floats.
|
18
21
|
`factor` controls the extent to which the image hue is impacted.
|
@@ -19,7 +19,7 @@ class RandomGrayscale(BaseImagePreprocessingLayer):
|
|
19
19
|
image using standard RGB to grayscale conversion coefficients. Images
|
20
20
|
that are not selected for conversion remain unchanged.
|
21
21
|
|
22
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
22
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
23
23
|
(independently of which backend you're using).
|
24
24
|
|
25
25
|
Args:
|