keras-nightly 3.12.0.dev2025082003__py3-none-any.whl → 3.12.0.dev2025082203__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/applications/convnext.py +20 -20
- keras/src/applications/densenet.py +21 -21
- keras/src/applications/efficientnet.py +16 -16
- keras/src/applications/efficientnet_v2.py +28 -28
- keras/src/applications/inception_resnet_v2.py +7 -7
- keras/src/applications/inception_v3.py +5 -5
- keras/src/applications/mobilenet_v2.py +13 -20
- keras/src/applications/mobilenet_v3.py +15 -15
- keras/src/applications/nasnet.py +7 -8
- keras/src/applications/resnet.py +32 -32
- keras/src/applications/xception.py +10 -10
- keras/src/backend/common/dtypes.py +3 -3
- keras/src/backend/common/variables.py +3 -1
- keras/src/backend/jax/export.py +1 -1
- keras/src/backend/jax/trainer.py +1 -1
- keras/src/backend/openvino/numpy.py +1 -1
- keras/src/backend/tensorflow/rnn.py +1 -1
- keras/src/backend/tensorflow/trainer.py +19 -1
- keras/src/backend/torch/core.py +6 -9
- keras/src/backend/torch/trainer.py +1 -1
- keras/src/callbacks/backup_and_restore.py +2 -2
- keras/src/callbacks/csv_logger.py +1 -1
- keras/src/callbacks/model_checkpoint.py +1 -1
- keras/src/callbacks/tensorboard.py +6 -6
- keras/src/datasets/boston_housing.py +1 -1
- keras/src/datasets/california_housing.py +1 -1
- keras/src/datasets/cifar10.py +1 -1
- keras/src/datasets/cifar100.py +2 -2
- keras/src/datasets/imdb.py +2 -2
- keras/src/datasets/mnist.py +1 -1
- keras/src/datasets/reuters.py +2 -2
- keras/src/dtype_policies/dtype_policy.py +1 -1
- keras/src/dtype_policies/dtype_policy_map.py +1 -1
- keras/src/export/tf2onnx_lib.py +1 -3
- keras/src/layers/attention/attention.py +2 -0
- keras/src/layers/core/lambda_layer.py +9 -8
- keras/src/layers/input_spec.py +6 -6
- keras/src/layers/layer.py +1 -1
- keras/src/layers/preprocessing/category_encoding.py +3 -3
- keras/src/layers/preprocessing/data_layer.py +159 -0
- keras/src/layers/preprocessing/discretization.py +3 -3
- keras/src/layers/preprocessing/feature_space.py +4 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
- keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
- keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
- keras/src/layers/preprocessing/normalization.py +5 -2
- keras/src/layers/preprocessing/rescaling.py +3 -3
- keras/src/layers/rnn/bidirectional.py +4 -4
- keras/src/legacy/backend.py +9 -23
- keras/src/legacy/preprocessing/image.py +11 -22
- keras/src/legacy/preprocessing/text.py +1 -1
- keras/src/legacy/saving/legacy_h5_format.py +7 -2
- keras/src/legacy/saving/saving_utils.py +0 -12
- keras/src/legacy/saving/serialization.py +0 -14
- keras/src/models/functional.py +2 -2
- keras/src/models/model.py +21 -3
- keras/src/ops/function.py +1 -1
- keras/src/ops/numpy.py +5 -5
- keras/src/ops/operation.py +3 -2
- keras/src/optimizers/base_optimizer.py +3 -4
- keras/src/quantizers/gptq.py +350 -0
- keras/src/quantizers/gptq_config.py +169 -0
- keras/src/quantizers/gptq_core.py +335 -0
- keras/src/quantizers/gptq_quant.py +133 -0
- keras/src/saving/file_editor.py +22 -20
- keras/src/saving/object_registration.py +1 -1
- keras/src/saving/saving_api.py +4 -1
- keras/src/saving/saving_lib.py +4 -4
- keras/src/saving/serialization_lib.py +9 -11
- keras/src/trainers/compile_utils.py +1 -1
- keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
- keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
- keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
- keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
- keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
- keras/src/tree/dmtree_impl.py +19 -3
- keras/src/tree/optree_impl.py +3 -3
- keras/src/tree/tree_api.py +5 -2
- keras/src/utils/file_utils.py +13 -5
- keras/src/utils/io_utils.py +1 -1
- keras/src/utils/model_visualization.py +1 -1
- keras/src/utils/progbar.py +5 -5
- keras/src/utils/summary_utils.py +4 -4
- keras/src/utils/torch_utils.py +4 -4
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/RECORD +121 -117
- keras/src/layers/preprocessing/tf_data_layer.py +0 -78
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025082003.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,9 @@ class RandomHue(BaseImagePreprocessingLayer):
|
|
14
14
|
The image hue is adjusted by converting the image(s) to HSV and rotating the
|
15
15
|
hue channel (H) by delta. The image is then converted back to RGB.
|
16
16
|
|
17
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
18
|
+
(independently of which backend you're using).
|
19
|
+
|
17
20
|
Args:
|
18
21
|
factor: A single float or a tuple of two floats.
|
19
22
|
`factor` controls the extent to which the
|
@@ -14,6 +14,9 @@ class RandomInvert(BaseImagePreprocessingLayer):
|
|
14
14
|
complementary values. Images that are not selected for inversion
|
15
15
|
remain unchanged.
|
16
16
|
|
17
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
18
|
+
(independently of which backend you're using).
|
19
|
+
|
17
20
|
Args:
|
18
21
|
factor: A single float or a tuple of two floats.
|
19
22
|
`factor` controls the probability of inverting the image colors.
|
@@ -20,6 +20,9 @@ class RandomPerspective(BaseImagePreprocessingLayer):
|
|
20
20
|
corner points, simulating a 3D-like transformation. The amount of distortion
|
21
21
|
is controlled by the `factor` and `scale` parameters.
|
22
22
|
|
23
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
24
|
+
(independently of which backend you're using).
|
25
|
+
|
23
26
|
Args:
|
24
27
|
factor: A float or a tuple of two floats.
|
25
28
|
Represents the probability of applying the perspective
|
@@ -8,6 +8,9 @@ from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing
|
|
8
8
|
class RandomPosterization(BaseImagePreprocessingLayer):
|
9
9
|
"""Reduces the number of bits for each color channel.
|
10
10
|
|
11
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
12
|
+
(independently of which backend you're using).
|
13
|
+
|
11
14
|
References:
|
12
15
|
- [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
|
13
16
|
- [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
|
@@ -23,7 +23,7 @@ class RandomRotation(BaseImagePreprocessingLayer):
|
|
23
23
|
of integer or floating point dtype.
|
24
24
|
By default, the layer will output floats.
|
25
25
|
|
26
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
26
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
27
27
|
(independently of which backend you're using).
|
28
28
|
|
29
29
|
Input shape:
|
@@ -13,6 +13,9 @@ class RandomSaturation(BaseImagePreprocessingLayer):
|
|
13
13
|
This layer will randomly increase/reduce the saturation for the input RGB
|
14
14
|
images.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
16
19
|
Args:
|
17
20
|
factor: A tuple of two floats or a single float.
|
18
21
|
`factor` controls the extent to which the image saturation
|
@@ -13,6 +13,9 @@ class RandomSharpness(BaseImagePreprocessingLayer):
|
|
13
13
|
original image and the processed image. This operation adjusts the clarity
|
14
14
|
of the edges in an image, ranging from blurred to enhanced sharpness.
|
15
15
|
|
16
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
17
|
+
(independently of which backend you're using).
|
18
|
+
|
16
19
|
Args:
|
17
20
|
factor: A tuple of two floats or a single float.
|
18
21
|
`factor` controls the extent to which the image sharpness
|
@@ -23,6 +23,9 @@ class RandomShear(BaseImagePreprocessingLayer):
|
|
23
23
|
regions created during the transformation are filled according to the
|
24
24
|
`fill_mode` and `fill_value` parameters.
|
25
25
|
|
26
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
27
|
+
(independently of which backend you're using).
|
28
|
+
|
26
29
|
Args:
|
27
30
|
x_factor: A tuple of two floats. For each augmented image, a value
|
28
31
|
is sampled from the provided range. If a float is passed, the
|
@@ -23,6 +23,9 @@ class RandomTranslation(BaseImagePreprocessingLayer):
|
|
23
23
|
of integer or floating point dtype. By default, the layer will output
|
24
24
|
floats.
|
25
25
|
|
26
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
27
|
+
(independently of which backend you're using).
|
28
|
+
|
26
29
|
Input shape:
|
27
30
|
3D (unbatched) or 4D (batched) tensor with shape:
|
28
31
|
`(..., height, width, channels)`, in `"channels_last"` format,
|
@@ -34,9 +37,6 @@ class RandomTranslation(BaseImagePreprocessingLayer):
|
|
34
37
|
or `(..., channels, target_height, target_width)`,
|
35
38
|
in `"channels_first"` format.
|
36
39
|
|
37
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
38
|
-
(independently of which backend you're using).
|
39
|
-
|
40
40
|
Args:
|
41
41
|
height_factor: a float represented as fraction of value, or a tuple of
|
42
42
|
size 2 representing lower and upper bound for shifting vertically. A
|
@@ -24,6 +24,9 @@ class RandomZoom(BaseImagePreprocessingLayer):
|
|
24
24
|
of integer or floating point dtype.
|
25
25
|
By default, the layer will output floats.
|
26
26
|
|
27
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
28
|
+
(independently of which backend you're using).
|
29
|
+
|
27
30
|
Input shape:
|
28
31
|
3D (unbatched) or 4D (batched) tensor with shape:
|
29
32
|
`(..., height, width, channels)`, in `"channels_last"` format,
|
@@ -35,9 +38,6 @@ class RandomZoom(BaseImagePreprocessingLayer):
|
|
35
38
|
or `(..., channels, target_height, target_width)`,
|
36
39
|
in `"channels_first"` format.
|
37
40
|
|
38
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
39
|
-
(independently of which backend you're using).
|
40
|
-
|
41
41
|
Args:
|
42
42
|
height_factor: a float represented as fraction of value, or a tuple of
|
43
43
|
size 2 representing lower and upper bound for zooming vertically.
|
@@ -21,6 +21,9 @@ class Resizing(BaseImagePreprocessingLayer):
|
|
21
21
|
format. Input pixel values can be of any range
|
22
22
|
(e.g. `[0., 1.)` or `[0, 255]`).
|
23
23
|
|
24
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
25
|
+
(independently of which backend you're using).
|
26
|
+
|
24
27
|
Input shape:
|
25
28
|
3D (unbatched) or 4D (batched) tensor with shape:
|
26
29
|
`(..., height, width, channels)`, in `"channels_last"` format,
|
@@ -32,9 +35,6 @@ class Resizing(BaseImagePreprocessingLayer):
|
|
32
35
|
or `(..., channels, target_height, target_width)`,
|
33
36
|
in `"channels_first"` format.
|
34
37
|
|
35
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
36
|
-
(independently of which backend you're using).
|
37
|
-
|
38
38
|
Args:
|
39
39
|
height: Integer, the height of the output shape.
|
40
40
|
width: Integer, the width of the output shape.
|
@@ -15,6 +15,9 @@ class Solarization(BaseImagePreprocessingLayer):
|
|
15
15
|
to all values. When created with specified `threshold` the layer only
|
16
16
|
augments pixels that are above the `threshold` value.
|
17
17
|
|
18
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
19
|
+
(independently of which backend you're using).
|
20
|
+
|
18
21
|
Args:
|
19
22
|
addition_factor: (Optional) A tuple of two floats or a single float,
|
20
23
|
between 0 and 1.
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from keras.src.api_export import keras_export
|
2
|
-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
|
2
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
3
3
|
|
4
4
|
# mel spectrum constants.
|
5
5
|
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
|
@@ -7,7 +7,7 @@ _MEL_HIGH_FREQUENCY_Q = 1127.0
|
|
7
7
|
|
8
8
|
|
9
9
|
@keras_export("keras.layers.MelSpectrogram")
|
10
|
-
class MelSpectrogram(TFDataLayer):
|
10
|
+
class MelSpectrogram(DataLayer):
|
11
11
|
"""A preprocessing layer to convert raw audio signals to Mel spectrograms.
|
12
12
|
|
13
13
|
This layer takes `float32`/`float64` single or batched audio signal as
|
@@ -24,10 +24,37 @@ class MelSpectrogram(TFDataLayer):
|
|
24
24
|
speech and music processing tasks like speech recognition, speaker
|
25
25
|
identification, and music genre classification.
|
26
26
|
|
27
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
28
|
+
(independently of which backend you're using).
|
29
|
+
|
27
30
|
References:
|
28
31
|
- [Spectrogram](https://en.wikipedia.org/wiki/Spectrogram),
|
29
32
|
- [Mel scale](https://en.wikipedia.org/wiki/Mel_scale).
|
30
33
|
|
34
|
+
Args:
|
35
|
+
fft_length: Integer, size of the FFT window.
|
36
|
+
sequence_stride: Integer, number of samples between successive STFT
|
37
|
+
columns.
|
38
|
+
sequence_length: Integer, size of the window used for applying
|
39
|
+
`window` to each audio frame. If `None`, defaults to `fft_length`.
|
40
|
+
window: String, name of the window function to use. Available values
|
41
|
+
are `"hann"` and `"hamming"`. If `window` is a tensor, it will be
|
42
|
+
used directly as the window and its length must be
|
43
|
+
`sequence_length`. If `window` is `None`, no windowing is
|
44
|
+
used. Defaults to `"hann"`.
|
45
|
+
sampling_rate: Integer, sample rate of the input signal.
|
46
|
+
num_mel_bins: Integer, number of mel bins to generate.
|
47
|
+
min_freq: Float, minimum frequency of the mel bins.
|
48
|
+
max_freq: Float, maximum frequency of the mel bins.
|
49
|
+
If `None`, defaults to `sampling_rate / 2`.
|
50
|
+
power_to_db: If True, convert the power spectrogram to decibels.
|
51
|
+
top_db: Float, minimum negative cut-off `max(10 * log10(S)) - top_db`.
|
52
|
+
mag_exp: Float, exponent for the magnitude spectrogram.
|
53
|
+
1 for magnitude, 2 for power, etc. Default is 2.
|
54
|
+
ref_power: Float, the power is scaled relative to it
|
55
|
+
`10 * log10(S / ref_power)`.
|
56
|
+
min_power: Float, minimum value for power and `ref_power`.
|
57
|
+
|
31
58
|
Examples:
|
32
59
|
|
33
60
|
**Unbatched audio signal**
|
@@ -55,29 +82,6 @@ class MelSpectrogram(TFDataLayer):
|
|
55
82
|
2D (unbatched) or 3D (batched) tensor with
|
56
83
|
shape:`(..., num_mel_bins, time)`.
|
57
84
|
|
58
|
-
Args:
|
59
|
-
fft_length: Integer, size of the FFT window.
|
60
|
-
sequence_stride: Integer, number of samples between successive STFT
|
61
|
-
columns.
|
62
|
-
sequence_length: Integer, size of the window used for applying
|
63
|
-
`window` to each audio frame. If `None`, defaults to `fft_length`.
|
64
|
-
window: String, name of the window function to use. Available values
|
65
|
-
are `"hann"` and `"hamming"`. If `window` is a tensor, it will be
|
66
|
-
used directly as the window and its length must be
|
67
|
-
`sequence_length`. If `window` is `None`, no windowing is
|
68
|
-
used. Defaults to `"hann"`.
|
69
|
-
sampling_rate: Integer, sample rate of the input signal.
|
70
|
-
num_mel_bins: Integer, number of mel bins to generate.
|
71
|
-
min_freq: Float, minimum frequency of the mel bins.
|
72
|
-
max_freq: Float, maximum frequency of the mel bins.
|
73
|
-
If `None`, defaults to `sampling_rate / 2`.
|
74
|
-
power_to_db: If True, convert the power spectrogram to decibels.
|
75
|
-
top_db: Float, minimum negative cut-off `max(10 * log10(S)) - top_db`.
|
76
|
-
mag_exp: Float, exponent for the magnitude spectrogram.
|
77
|
-
1 for magnitude, 2 for power, etc. Default is 2.
|
78
|
-
ref_power: Float, the power is scaled relative to it
|
79
|
-
`10 * log10(S / ref_power)`.
|
80
|
-
min_power: Float, minimum value for power and `ref_power`.
|
81
85
|
"""
|
82
86
|
|
83
87
|
def __init__(
|
@@ -5,12 +5,12 @@ import numpy as np
|
|
5
5
|
from keras.src import backend
|
6
6
|
from keras.src import ops
|
7
7
|
from keras.src.api_export import keras_export
|
8
|
-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
|
8
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
9
9
|
from keras.src.utils.module_utils import tensorflow as tf
|
10
10
|
|
11
11
|
|
12
12
|
@keras_export("keras.layers.Normalization")
|
13
|
-
class Normalization(TFDataLayer):
|
13
|
+
class Normalization(DataLayer):
|
14
14
|
"""A preprocessing layer that normalizes continuous features.
|
15
15
|
|
16
16
|
This layer will shift and scale inputs into a distribution centered around
|
@@ -23,6 +23,9 @@ class Normalization(TFDataLayer):
|
|
23
23
|
variance of the data and store them as the layer's weights. `adapt()` should
|
24
24
|
be called before `fit()`, `evaluate()`, or `predict()`.
|
25
25
|
|
26
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
27
|
+
(independently of which backend you're using).
|
28
|
+
|
26
29
|
Args:
|
27
30
|
axis: Integer, tuple of integers, or None. The axis or axes that should
|
28
31
|
have a separate mean and variance for each index in the shape.
|
@@ -1,11 +1,11 @@
|
|
1
1
|
from keras.src import backend
|
2
2
|
from keras.src.api_export import keras_export
|
3
|
-
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
|
3
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
4
4
|
from keras.src.saving import serialization_lib
|
5
5
|
|
6
6
|
|
7
7
|
@keras_export("keras.layers.Rescaling")
|
8
|
-
class Rescaling(TFDataLayer):
|
8
|
+
class Rescaling(DataLayer):
|
9
9
|
"""A preprocessing layer which rescales input values to a new range.
|
10
10
|
|
11
11
|
This layer rescales every value of an input (often an image) by multiplying
|
@@ -23,7 +23,7 @@ class Rescaling(TFDataLayer):
|
|
23
23
|
of integer or floating point dtype, and by default the layer will output
|
24
24
|
floats.
|
25
25
|
|
26
|
-
**Note:** This layer is safe to use inside a `tf.data` pipeline
|
26
|
+
**Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
|
27
27
|
(independently of which backend you're using).
|
28
28
|
|
29
29
|
Args:
|
@@ -109,16 +109,16 @@ class Bidirectional(Layer):
|
|
109
109
|
# Recreate the forward layer from the original layer config, so that it
|
110
110
|
# will not carry over any state from the layer.
|
111
111
|
config = serialization_lib.serialize_keras_object(layer)
|
112
|
-
config["config"]["name"] =
|
113
|
-
layer.name,
|
112
|
+
config["config"]["name"] = (
|
113
|
+
f"forward_{utils.removeprefix(layer.name, 'forward_')}"
|
114
114
|
)
|
115
115
|
self.forward_layer = serialization_lib.deserialize_keras_object(config)
|
116
116
|
|
117
117
|
if backward_layer is None:
|
118
118
|
config = serialization_lib.serialize_keras_object(layer)
|
119
119
|
config["config"]["go_backwards"] = True
|
120
|
-
config["config"]["name"] =
|
121
|
-
layer.name,
|
120
|
+
config["config"]["name"] = (
|
121
|
+
f"backward_{utils.removeprefix(layer.name, 'backward_')}"
|
122
122
|
)
|
123
123
|
self.backward_layer = serialization_lib.deserialize_keras_object(
|
124
124
|
config
|
keras/src/legacy/backend.py
CHANGED
@@ -68,11 +68,7 @@ def batch_dot(x, y, axes=None):
|
|
68
68
|
raise ValueError(
|
69
69
|
"Cannot do batch_dot on inputs "
|
70
70
|
"with rank < 2. "
|
71
|
-
"Received inputs with tf.shapes "
|
72
|
-
+ str(x_shape)
|
73
|
-
+ " and "
|
74
|
-
+ str(y_shape)
|
75
|
-
+ "."
|
71
|
+
f"Received inputs with tf.shapes {x_shape} and {y_shape}."
|
76
72
|
)
|
77
73
|
|
78
74
|
x_batch_size = x_shape[0]
|
@@ -84,10 +80,7 @@ def batch_dot(x, y, axes=None):
|
|
84
80
|
"Cannot do batch_dot on inputs "
|
85
81
|
"with different batch sizes. "
|
86
82
|
"Received inputs with tf.shapes "
|
87
|
-
|
88
|
-
+ " and "
|
89
|
-
+ str(y_shape)
|
90
|
-
+ "."
|
83
|
+
f"{x_shape} and {y_shape}."
|
91
84
|
)
|
92
85
|
if isinstance(axes, int):
|
93
86
|
axes = [axes, axes]
|
@@ -101,9 +94,8 @@ def batch_dot(x, y, axes=None):
|
|
101
94
|
if py_any(isinstance(a, (list, tuple)) for a in axes):
|
102
95
|
raise ValueError(
|
103
96
|
"Multiple target dimensions are not supported. "
|
104
|
-
|
105
|
-
|
106
|
-
+ str(axes)
|
97
|
+
"Expected: None, int, (int, int), "
|
98
|
+
f"Provided: {axes}"
|
107
99
|
)
|
108
100
|
|
109
101
|
# if tuple, convert to list.
|
@@ -130,12 +122,8 @@ def batch_dot(x, y, axes=None):
|
|
130
122
|
if d1 is not None and d2 is not None and d1 != d2:
|
131
123
|
raise ValueError(
|
132
124
|
"Cannot do batch_dot on inputs with tf.shapes "
|
133
|
-
|
134
|
-
|
135
|
-
+ str(y_shape)
|
136
|
-
+ " with axes="
|
137
|
-
+ str(axes)
|
138
|
-
+ ". x.shape[%d] != y.shape[%d] (%d != %d)."
|
125
|
+
f"{x_shape} and {y_shape} with axes={axes}. "
|
126
|
+
"x.shape[%d] != y.shape[%d] (%d != %d)."
|
139
127
|
% (axes[0], axes[1], d1, d2)
|
140
128
|
)
|
141
129
|
|
@@ -1129,7 +1117,7 @@ def pool2d(
|
|
1129
1117
|
x, pool_size, strides, padding=padding, data_format=tf_data_format
|
1130
1118
|
)
|
1131
1119
|
else:
|
1132
|
-
raise ValueError("Invalid pooling mode:
|
1120
|
+
raise ValueError(f"Invalid pooling mode: {str(pool_mode)}")
|
1133
1121
|
|
1134
1122
|
if data_format == "channels_first" and tf_data_format == "NHWC":
|
1135
1123
|
x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
|
@@ -1169,7 +1157,7 @@ def pool3d(
|
|
1169
1157
|
x, pool_size, strides, padding=padding, data_format=tf_data_format
|
1170
1158
|
)
|
1171
1159
|
else:
|
1172
|
-
raise ValueError("Invalid pooling mode:
|
1160
|
+
raise ValueError(f"Invalid pooling mode: {str(pool_mode)}")
|
1173
1161
|
|
1174
1162
|
if data_format == "channels_first" and tf_data_format == "NDHWC":
|
1175
1163
|
x = tf.transpose(x, (0, 4, 1, 2, 3))
|
@@ -2150,9 +2138,7 @@ def switch(condition, then_expression, else_expression):
|
|
2150
2138
|
"Rank of `condition` should be less than or"
|
2151
2139
|
" equal to rank of `then_expression` and "
|
2152
2140
|
"`else_expression`. ndim(condition)="
|
2153
|
-
|
2154
|
-
+ ", ndim(then_expression)="
|
2155
|
-
+ str(expr_ndim)
|
2141
|
+
f"{cond_ndim}, ndim(then_expression)={expr_ndim}"
|
2156
2142
|
)
|
2157
2143
|
if cond_ndim > 1:
|
2158
2144
|
ndim_diff = expr_ndim - cond_ndim
|
@@ -617,17 +617,12 @@ class NumpyArrayIterator(Iterator):
|
|
617
617
|
channels_axis = 3 if data_format == "channels_last" else 1
|
618
618
|
if self.x.shape[channels_axis] not in {1, 3, 4}:
|
619
619
|
warnings.warn(
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
+ ". However, it was passed an array with shape "
|
627
|
-
+ str(self.x.shape)
|
628
|
-
+ " ("
|
629
|
-
+ str(self.x.shape[channels_axis])
|
630
|
-
+ " channels)."
|
620
|
+
f"NumpyArrayIterator is set to use the data format convention"
|
621
|
+
f' "{data_format}" (channels on axis {channels_axis})'
|
622
|
+
", i.e. expected either 1, 3, or 4 channels "
|
623
|
+
f"on axis {channels_axis}. "
|
624
|
+
f"However, it was passed an array with shape {self.x.shape}"
|
625
|
+
f" ({self.x.shape[channels_axis]} channels)."
|
631
626
|
)
|
632
627
|
if y is not None:
|
633
628
|
self.y = np.asarray(y)
|
@@ -1494,17 +1489,11 @@ class ImageDataGenerator:
|
|
1494
1489
|
if x.shape[self.channel_axis] not in {1, 3, 4}:
|
1495
1490
|
warnings.warn(
|
1496
1491
|
"Expected input to be images (as Numpy array) "
|
1497
|
-
'following the data format convention "'
|
1498
|
-
|
1499
|
-
|
1500
|
-
|
1501
|
-
|
1502
|
-
+ str(self.channel_axis)
|
1503
|
-
+ ". However, it was passed an array with shape "
|
1504
|
-
+ str(x.shape)
|
1505
|
-
+ " ("
|
1506
|
-
+ str(x.shape[self.channel_axis])
|
1507
|
-
+ " channels)."
|
1492
|
+
f'following the data format convention "{self.data_format}'
|
1493
|
+
f'" (channels on axis {self.channel_axis})'
|
1494
|
+
", i.e. expected either 1, 3 or 4 channels on axis "
|
1495
|
+
f"{self.channel_axis}. However, it was passed an array with"
|
1496
|
+
f" shape {x.shape} ({x.shape[self.channel_axis]} channels)."
|
1508
1497
|
)
|
1509
1498
|
|
1510
1499
|
if seed is not None:
|
@@ -102,7 +102,7 @@ class Tokenizer:
|
|
102
102
|
num_words = kwargs.pop("nb_words")
|
103
103
|
document_count = kwargs.pop("document_count", 0)
|
104
104
|
if kwargs:
|
105
|
-
raise TypeError("Unrecognized keyword arguments:
|
105
|
+
raise TypeError(f"Unrecognized keyword arguments: {str(kwargs)}")
|
106
106
|
|
107
107
|
self.word_counts = collections.OrderedDict()
|
108
108
|
self.word_docs = collections.defaultdict(int)
|
@@ -11,6 +11,7 @@ from keras.src.legacy.saving import json_utils
|
|
11
11
|
from keras.src.legacy.saving import saving_options
|
12
12
|
from keras.src.legacy.saving import saving_utils
|
13
13
|
from keras.src.saving import object_registration
|
14
|
+
from keras.src.saving import serialization_lib
|
14
15
|
from keras.src.utils import io_utils
|
15
16
|
|
16
17
|
try:
|
@@ -72,7 +73,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
|
72
73
|
f.close()
|
73
74
|
|
74
75
|
|
75
|
-
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
|
76
|
+
def load_model_from_hdf5(
|
77
|
+
filepath, custom_objects=None, compile=True, safe_mode=True
|
78
|
+
):
|
76
79
|
"""Loads a model saved via `save_model_to_hdf5`.
|
77
80
|
|
78
81
|
Args:
|
@@ -128,7 +131,9 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
|
|
128
131
|
model_config = model_config.decode("utf-8")
|
129
132
|
model_config = json_utils.decode(model_config)
|
130
133
|
|
131
|
-
|
134
|
+
legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)
|
135
|
+
safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)
|
136
|
+
with legacy_scope, safe_mode_scope:
|
132
137
|
model = saving_utils.model_from_config(
|
133
138
|
model_config, custom_objects=custom_objects
|
134
139
|
)
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import json
|
2
1
|
import threading
|
3
2
|
|
4
3
|
from absl import logging
|
@@ -81,10 +80,6 @@ def model_from_config(config, custom_objects=None):
|
|
81
80
|
function_dict["config"]["closure"] = function_config[2]
|
82
81
|
config["config"]["function"] = function_dict
|
83
82
|
|
84
|
-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
|
85
|
-
# Replace keras refs with keras
|
86
|
-
config = _find_replace_nested_dict(config, "keras.", "keras.")
|
87
|
-
|
88
83
|
return serialization.deserialize_keras_object(
|
89
84
|
config,
|
90
85
|
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
|
@@ -231,13 +226,6 @@ def _deserialize_metric(metric_config):
|
|
231
226
|
return metrics_module.deserialize(metric_config)
|
232
227
|
|
233
228
|
|
234
|
-
def _find_replace_nested_dict(config, find, replace):
|
235
|
-
dict_str = json.dumps(config)
|
236
|
-
dict_str = dict_str.replace(find, replace)
|
237
|
-
config = json.loads(dict_str)
|
238
|
-
return config
|
239
|
-
|
240
|
-
|
241
229
|
def _resolve_compile_arguments_compat(obj, obj_config, module):
|
242
230
|
"""Resolves backwards compatibility issues with training config arguments.
|
243
231
|
|
@@ -2,7 +2,6 @@
|
|
2
2
|
|
3
3
|
import contextlib
|
4
4
|
import inspect
|
5
|
-
import json
|
6
5
|
import threading
|
7
6
|
import weakref
|
8
7
|
|
@@ -485,12 +484,6 @@ def deserialize_keras_object(
|
|
485
484
|
arg_spec = inspect.getfullargspec(cls.from_config)
|
486
485
|
custom_objects = custom_objects or {}
|
487
486
|
|
488
|
-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
|
489
|
-
# Replace keras refs with keras
|
490
|
-
cls_config = _find_replace_nested_dict(
|
491
|
-
cls_config, "keras.", "keras."
|
492
|
-
)
|
493
|
-
|
494
487
|
if "custom_objects" in arg_spec.args:
|
495
488
|
deserialized_obj = cls.from_config(
|
496
489
|
cls_config,
|
@@ -565,10 +558,3 @@ def validate_config(config):
|
|
565
558
|
def is_default(method):
|
566
559
|
"""Check if a method is decorated with the `default` wrapper."""
|
567
560
|
return getattr(method, "_is_default", False)
|
568
|
-
|
569
|
-
|
570
|
-
def _find_replace_nested_dict(config, find, replace):
|
571
|
-
dict_str = json.dumps(config)
|
572
|
-
dict_str = dict_str.replace(find, replace)
|
573
|
-
config = json.loads(dict_str)
|
574
|
-
return config
|
keras/src/models/functional.py
CHANGED
@@ -773,7 +773,7 @@ def is_input_keras_tensor(x):
|
|
773
773
|
|
774
774
|
def clone_single_keras_tensor(x):
|
775
775
|
return backend.KerasTensor(
|
776
|
-
shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=x.name
|
776
|
+
shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=f"{x.name}_clone"
|
777
777
|
)
|
778
778
|
|
779
779
|
|
@@ -836,7 +836,7 @@ def clone_graph_nodes(inputs, outputs):
|
|
836
836
|
batch_shape=kt_input.shape,
|
837
837
|
dtype=kt_input.dtype,
|
838
838
|
sparse=kt_input.sparse,
|
839
|
-
name=kt_input.name
|
839
|
+
name=f"{kt_input.name}CLONE",
|
840
840
|
)
|
841
841
|
cloned_inputs.append(cloned_input)
|
842
842
|
kt_id_mapping[id(kt_input)] = cloned_input
|
keras/src/models/model.py
CHANGED
@@ -8,6 +8,7 @@ from keras.src import utils
|
|
8
8
|
from keras.src.api_export import keras_export
|
9
9
|
from keras.src.layers.layer import Layer
|
10
10
|
from keras.src.models.variable_mapping import map_saveable_variables
|
11
|
+
from keras.src.quantizers.gptq_config import GPTQConfig
|
11
12
|
from keras.src.saving import saving_api
|
12
13
|
from keras.src.trainers import trainer as base_trainer
|
13
14
|
from keras.src.utils import summary_utils
|
@@ -420,7 +421,7 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
420
421
|
**kwargs,
|
421
422
|
)
|
422
423
|
|
423
|
-
def quantize(self, mode, **kwargs):
|
424
|
+
def quantize(self, mode, config=None, **kwargs):
|
424
425
|
"""Quantize the weights of the model.
|
425
426
|
|
426
427
|
Note that the model must be built first before calling this method.
|
@@ -433,6 +434,23 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
433
434
|
"""
|
434
435
|
from keras.src.dtype_policies import QUANTIZATION_MODES
|
435
436
|
|
437
|
+
if mode == "gptq":
|
438
|
+
if not isinstance(config, GPTQConfig):
|
439
|
+
raise ValueError(
|
440
|
+
"The `config` argument must be of type "
|
441
|
+
"`keras.quantizers.GPTQConfig`."
|
442
|
+
)
|
443
|
+
# The config object's own quantize method drives the process
|
444
|
+
config.quantize(self)
|
445
|
+
return
|
446
|
+
|
447
|
+
# For all other modes, verify that a config object was not passed.
|
448
|
+
if config is not None:
|
449
|
+
raise ValueError(
|
450
|
+
f"The `config` argument is only supported for 'gptq' mode, "
|
451
|
+
f"but received mode='{mode}'."
|
452
|
+
)
|
453
|
+
|
436
454
|
type_check = kwargs.pop("type_check", True)
|
437
455
|
if kwargs:
|
438
456
|
raise ValueError(
|
@@ -854,9 +872,9 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
854
872
|
def _flatten(current_dict, prefix=""):
|
855
873
|
for key, value in current_dict.items():
|
856
874
|
if isinstance(value, dict):
|
857
|
-
_flatten(value, prefix
|
875
|
+
_flatten(value, f"{prefix}{key}/")
|
858
876
|
else:
|
859
|
-
flat_dict[prefix
|
877
|
+
flat_dict[f"{prefix}{key}"] = value
|
860
878
|
|
861
879
|
_flatten(nested_dict)
|
862
880
|
return flat_dict
|