tf-keras-nightly 2.17.0.dev2024031909-py3-none-any.whl → 2.19.0.dev2025011410-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/src/__init__.py +1 -1
  3. tf_keras/src/backend.py +1 -1
  4. tf_keras/src/callbacks.py +24 -7
  5. tf_keras/src/datasets/boston_housing.py +14 -5
  6. tf_keras/src/datasets/cifar10.py +9 -1
  7. tf_keras/src/datasets/cifar100.py +7 -1
  8. tf_keras/src/datasets/fashion_mnist.py +16 -4
  9. tf_keras/src/datasets/imdb.py +8 -0
  10. tf_keras/src/datasets/mnist.py +9 -3
  11. tf_keras/src/datasets/reuters.py +8 -0
  12. tf_keras/src/engine/base_layer.py +10 -4
  13. tf_keras/src/engine/base_layer_v1.py +10 -4
  14. tf_keras/src/engine/node.py +8 -3
  15. tf_keras/src/layers/activation/prelu.py +1 -1
  16. tf_keras/src/layers/attention/base_dense_attention.py +2 -1
  17. tf_keras/src/layers/convolutional/base_conv.py +1 -1
  18. tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
  19. tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
  20. tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
  21. tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
  22. tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
  23. tf_keras/src/layers/core/dense.py +1 -1
  24. tf_keras/src/layers/core/embedding.py +1 -1
  25. tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
  26. tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
  27. tf_keras/src/layers/normalization/batch_normalization.py +1 -1
  28. tf_keras/src/layers/normalization/layer_normalization.py +1 -1
  29. tf_keras/src/layers/normalization/unit_normalization.py +2 -1
  30. tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
  31. tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
  32. tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
  33. tf_keras/src/layers/rnn/base_rnn.py +1 -1
  34. tf_keras/src/layers/rnn/base_wrapper.py +1 -1
  35. tf_keras/src/layers/rnn/bidirectional.py +2 -1
  36. tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
  37. tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
  38. tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
  39. tf_keras/src/layers/rnn/gru.py +35 -47
  40. tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
  41. tf_keras/src/layers/rnn/legacy_cells.py +20 -25
  42. tf_keras/src/layers/rnn/lstm.py +35 -50
  43. tf_keras/src/layers/rnn/simple_rnn.py +0 -1
  44. tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
  45. tf_keras/src/layers/rnn/time_distributed.py +0 -1
  46. tf_keras/src/mixed_precision/autocast_variable.py +12 -6
  47. tf_keras/src/mixed_precision/test_util.py +6 -5
  48. tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
  49. tf_keras/src/optimizers/optimizer.py +18 -9
  50. tf_keras/src/premade_models/linear.py +2 -1
  51. tf_keras/src/saving/legacy/saved_model/json_utils.py +1 -1
  52. tf_keras/src/saving/saving_api.py +165 -127
  53. tf_keras/src/saving/saving_lib.py +1 -11
  54. tf_keras/src/saving/serialization_lib.py +1 -10
  55. tf_keras/src/utils/data_utils.py +1 -1
  56. tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
  57. tf_keras/src/utils/tf_utils.py +2 -2
  58. tf_keras/src/utils/timeseries_dataset.py +13 -5
  59. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA +14 -3
  60. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/RECORD +62 -62
  61. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/WHEEL +1 -1
  62. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
 from tf_keras.src.engine.training import Model


-__version__ = "2.17.0.dev2024031909"
+__version__ = "2.19.0.dev2025011410"
tf_keras/src/__init__.py CHANGED
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
 from tensorflow.python import tf2
 from tensorflow.python.util.tf_export import keras_export

-__version__ = "2.17.0"
+__version__ = "2.19.0"

 keras_export("keras.__version__").export_constant(__name__, "__version__")

tf_keras/src/backend.py CHANGED
@@ -2029,7 +2029,7 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
         if user_specified_seed is not None:
             return user_specified_seed
         elif getattr(_SEED_GENERATOR, "generator", None):
-            return _SEED_GENERATOR.generator.randint(1, 1e9)
+            return _SEED_GENERATOR.generator.randint(1, int(1e9))
         else:
             return random.randint(1, int(1e9))

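The cast matters because `_SEED_GENERATOR.generator` is, in this codebase, a Python `random.Random` instance (set via `keras.utils.set_random_seed`), and `random.randint` no longer accepts float bounds: deprecated in Python 3.10, a TypeError in 3.12. A minimal standalone sketch of the failure mode:

    import random

    gen = random.Random(1337)
    # gen.randint(1, 1e9)  # TypeError on Python 3.12: float bound rejected
    seed = gen.randint(1, int(1e9))  # integer bound works on all versions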
tf_keras/src/callbacks.py CHANGED
@@ -1423,20 +1423,20 @@ class ModelCheckpoint(Callback):
         if mode == "min":
             self.monitor_op = np.less
             if self.best is None:
-                self.best = np.Inf
+                self.best = np.inf
         elif mode == "max":
             self.monitor_op = np.greater
             if self.best is None:
-                self.best = -np.Inf
+                self.best = -np.inf
         else:
             if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                 self.monitor_op = np.greater
                 if self.best is None:
-                    self.best = -np.Inf
+                    self.best = -np.inf
             else:
                 self.monitor_op = np.less
                 if self.best is None:
-                    self.best = np.Inf
+                    self.best = np.inf

         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
@@ -1903,6 +1903,23 @@ class BackupAndRestore(Callback):
                 "only supports empty strategy, "
                 "MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy."
             )
+
+        # Re-initialize the optimizer.
+        if self.model.built:
+            if (
+                self.model.optimizer is not None
+                and callable(getattr(self.model.optimizer, "build", None))
+                and not getattr(self.model.optimizer, "_built", False)
+            ):
+                self.model.optimizer.build(self.model.trainable_variables)
+        else:
+            logging.warning(
+                "To use the BackupAndRestore callback, "
+                "you model must be built before you call `fit()`. "
+                f"Model {self.model} is unbuilt. You can build it "
+                "beforehand by calling it on a batch of data."
+            )
+
         self.model._training_state = worker_training_state.WorkerTrainingState(
             self.model,
             self.backup_dir,
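The new block rebuilds the optimizer's variables before a checkpoint restore and warns if the model has no variables yet. In practice this means the model should be built before `fit()` when using this callback. A hedged usage sketch (assumes the package is importable as `tf_keras`; path and shapes are illustrative):

    import tf_keras as keras

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    model.build((None, 8))  # built model => optimizer state can be restored
    callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
    # model.fit(x, y, epochs=5, callbacks=[callback])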
@@ -2095,7 +2112,7 @@ class EarlyStopping(Callback):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
-        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.best = np.inf if self.monitor_op == np.less else -np.inf
         self.best_weights = None
         self.best_epoch = 0

@@ -3098,10 +3115,10 @@ class ReduceLROnPlateau(Callback):
             self.mode == "auto" and "acc" not in self.monitor
         ):
             self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
-            self.best = np.Inf
+            self.best = np.inf
         else:
             self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
-            self.best = -np.Inf
+            self.best = -np.inf
         self.cooldown_counter = 0
         self.wait = 0

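All of the `np.Inf` → `np.inf` edits in this file are NumPy 2.0 compatibility fixes: the capitalized aliases (`np.Inf`, `np.NaN`, ...) were removed in NumPy 2.0, while the lowercase spellings work on both 1.x and 2.x. A quick check:

    import numpy as np

    best = np.inf  # valid on NumPy 1.x and 2.x
    assert np.less(0.5, best)
    # On NumPy >= 2.0, `np.Inf` raises AttributeError instead.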
tf_keras/src/datasets/boston_housing.py CHANGED
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Boston housing price regression dataset."""

+import os
+
 import numpy as np

 from tf_keras.src.utils.data_utils import get_file
@@ -23,7 +25,9 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export("keras.datasets.boston_housing.load_data")
-def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
+def load_data(
+    path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
+):
     """Loads the Boston Housing dataset.

     This is a dataset taken from the StatLib library which is maintained at
@@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
     [StatLib website](http://lib.stat.cmu.edu/datasets/boston).

     Args:
-        path: path where to cache the dataset locally
-            (relative to `~/.keras/datasets`).
+        path: path where to cache the dataset locally (relative to
+            `~/.keras/datasets`).
         test_split: fraction of the data to reserve as test set.
-        seed: Random seed for shuffling the data
-            before computing the test split.
+        seed: Random seed for shuffling the data before computing the test split.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.

     Returns:
         Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "boston_housing.npz",
         file_hash=(  # noqa: E501
             "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         x = f["x"]
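Every dataset loader in this release gains the same optional `cache_dir` argument with the same semantics: the path is `expanduser()`-ed and created if missing, and `None` keeps the `~/.keras/datasets` default. A hedged usage sketch (the directory name is illustrative):

    import tf_keras as keras

    # Cache under a project-local directory instead of ~/.keras/datasets.
    (x_train, y_train), (x_test, y_test) = (
        keras.datasets.boston_housing.load_data(cache_dir="./data_cache")
    )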
tf_keras/src/datasets/cifar10.py CHANGED
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export("keras.datasets.cifar10.load_data")
-def load_data():
+def load_data(cache_dir=None):
     """Loads the CIFAR10 dataset.

     This is a dataset of 50,000 32x32 color training images and 10,000 test
@@ -49,6 +49,10 @@ def load_data():
     | 8 | ship |
     | 9 | truck |

+    Args:
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
+
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

@@ -78,6 +82,9 @@ def load_data():
     """
     dirname = "cifar-10-batches-py"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         dirname,
         origin=origin,
@@ -85,6 +92,7 @@ def load_data():
         file_hash=(  # noqa: E501
             "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
         ),
+        cache_dir=cache_dir,
     )

     num_train_samples = 50000
tf_keras/src/datasets/cifar100.py CHANGED
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export("keras.datasets.cifar100.load_data")
-def load_data(label_mode="fine"):
+def load_data(label_mode="fine", cache_dir=None):
     """Loads the CIFAR100 dataset.

     This is a dataset of 50,000 32x32 color training images and
@@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
         label_mode: one of "fine", "coarse". If it is "fine" the category labels
             are the fine-grained labels, if it is "coarse" the output labels are the
             coarse-grained superclasses.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.

     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -75,6 +77,9 @@ def load_data(label_mode="fine"):

     dirname = "cifar-100-python"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         dirname,
         origin=origin,
@@ -82,6 +87,7 @@ def load_data(label_mode="fine"):
         file_hash=(  # noqa: E501
             "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
         ),
+        cache_dir=cache_dir,
     )

     fpath = os.path.join(path, "train")
tf_keras/src/datasets/fashion_mnist.py CHANGED
@@ -26,7 +26,7 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export("keras.datasets.fashion_mnist.load_data")
-def load_data():
+def load_data(cache_dir=None):
     """Loads the Fashion-MNIST dataset.

     This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
@@ -48,6 +48,10 @@ def load_data():
     | 8 | Bag |
     | 9 | Ankle boot |

+    Args:
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
+
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

@@ -77,7 +81,6 @@ def load_data():
     The copyright for Fashion-MNIST is held by Zalando SE.
     Fashion-MNIST is licensed under the [MIT license](
     https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
-
     """
     dirname = os.path.join("datasets", "fashion-mnist")
     base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
@@ -87,10 +90,19 @@ def load_data():
         "t10k-labels-idx1-ubyte.gz",
         "t10k-images-idx3-ubyte.gz",
     ]
-
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     paths = []
     for fname in files:
-        paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
+        paths.append(
+            get_file(
+                fname,
+                origin=base + fname,
+                cache_dir=cache_dir,
+                cache_subdir=dirname,
+            )
+        )

     with gzip.open(paths[0], "rb") as lbpath:
         y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
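For Fashion-MNIST the files go through `get_file` with both `cache_dir` and `cache_subdir`, so a custom cache root still keeps the `datasets/fashion-mnist` sub-layout. A hedged sketch of one resulting call (URL and subdir as in the diff; the cache root is illustrative):

    from tf_keras.src.utils.data_utils import get_file

    base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
    # Lands in /tmp/keras-cache/datasets/fashion-mnist/train-labels-idx1-ubyte.gz
    path = get_file(
        "train-labels-idx1-ubyte.gz",
        origin=base + "train-labels-idx1-ubyte.gz",
        cache_dir="/tmp/keras-cache",
        cache_subdir="datasets/fashion-mnist",
    )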
tf_keras/src/datasets/imdb.py CHANGED
@@ -15,6 +15,7 @@
 """IMDB sentiment classification dataset."""

 import json
+import os

 import numpy as np

@@ -36,6 +37,7 @@ def load_data(
     start_char=1,
     oov_char=2,
     index_from=3,
+    cache_dir=None,
     **kwargs,
 ):
     """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
@@ -73,6 +75,8 @@ def load_data(
             Words that were cut out because of the `num_words` or
             `skip_top` limits will be replaced with this character.
         index_from: int. Index actual words with this index and higher.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
         **kwargs: Used for backwards compatibility.

     Returns:
@@ -108,12 +112,16 @@ def load_data(
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "imdb.npz",
         file_hash=(  # noqa: E501
             "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
         ),
+        cache_dir=cache_dir,
    )
     with np.load(path, allow_pickle=True) as f:
         x_train, labels_train = f["x_train"], f["y_train"]
tf_keras/src/datasets/mnist.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """MNIST handwritten digits dataset."""
+import os

 import numpy as np

@@ -23,7 +24,7 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export("keras.datasets.mnist.load_data")
-def load_data(path="mnist.npz"):
+def load_data(path="mnist.npz", cache_dir=None):
     """Loads the MNIST dataset.

     This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
@@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
     [MNIST homepage](http://yann.lecun.com/exdb/mnist/).

     Args:
-        path: path where to cache the dataset locally
-            (relative to `~/.keras/datasets`).
+        path: path where to cache the dataset locally relative to cache_dir.
+        cache_dir: dir location where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.

     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "mnist.npz",
         file_hash=(  # noqa: E501
             "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         x_train, y_train = f["x_train"], f["y_train"]
tf_keras/src/datasets/reuters.py CHANGED
@@ -15,6 +15,7 @@
 """Reuters topic classification dataset."""

 import json
+import os

 import numpy as np

@@ -37,6 +38,7 @@ def load_data(
     start_char=1,
     oov_char=2,
     index_from=3,
+    cache_dir=None,
     **kwargs,
 ):
     """Loads the Reuters newswire classification dataset.
@@ -83,6 +85,8 @@ def load_data(
             Words that were cut out because of the `num_words` or
             `skip_top` limits will be replaced with this character.
         index_from: int. Index actual words with this index and higher.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
         **kwargs: Used for backwards compatibility.

     Returns:
@@ -114,12 +118,16 @@ def load_data(
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "reuters.npz",
         file_hash=(  # noqa: E501
             "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         xs, labels = f["x"], f["y"]
tf_keras/src/engine/base_layer.py CHANGED
@@ -578,7 +578,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
                 Accepted values are constants defined in the class
                 `tf.VariableAggregation`.
             **kwargs: Additional keyword arguments. Accepted values are `getter`,
-                `collections`, `experimental_autocast` and `caching_device`.
+                `collections`, `autocast`, `experimental_autocast` and
+                `caching_device`.

         Returns:
             The variable created.
@@ -594,6 +595,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "collections",
                 "experimental_autocast",
                 "caching_device",
@@ -603,9 +605,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
             ]:
                 raise TypeError("Unknown keyword argument:", kwarg)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
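Callers can now spell the opt-out as `autocast=False`, with `experimental_autocast` kept as a fallback; the new name wins when both are given. A hedged sketch of a custom layer that keeps a weight out of mixed-precision autocasting on this release (layer and weight names are illustrative):

    import tensorflow as tf
    import tf_keras as keras

    class ScaledLayer(keras.layers.Layer):
        def build(self, input_shape):
            # Stays a plain float32 variable, never an AutoCastVariable.
            self.scale = self.add_weight(
                name="scale", shape=(), initializer="ones", autocast=False
            )
            super().build(input_shape)

        def call(self, inputs):
            return inputs * tf.cast(self.scale, inputs.dtype)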
tf_keras/src/engine/base_layer_v1.py CHANGED
@@ -352,7 +352,8 @@ class Layer(base_layer.Layer):
                 Accepted values are constants defined in the class
                 `tf.VariableAggregation`.
             **kwargs: Additional keyword arguments. Accepted values are `getter`,
-                `collections`, `experimental_autocast` and `caching_device`.
+                `collections`, `autocast`, `experimental_autocast` and
+                `caching_device`.

         Returns:
             The created variable. Usually either a `Variable` or
@@ -371,6 +372,7 @@ class Layer(base_layer.Layer):
         # Validate optional keyword arguments.
         for kwarg in kwargs:
             if kwarg not in [
+                "autocast",
                 "getter",
                 "collections",
                 "experimental_autocast",
@@ -380,9 +382,13 @@ class Layer(base_layer.Layer):
         has_custom_getter = "getter" in kwargs
         getter = kwargs.pop("getter", base_layer_utils.make_variable)
         collections_arg = kwargs.pop("collections", None)
-        # 'experimental_autocast' can be set to False by the caller to indicate
-        # an AutoCastVariable should never be created.
-        autocast = kwargs.pop("experimental_autocast", True)
+        # 'autocast' or 'experimental_autocast' can be set to False by the
+        # caller to indicate an AutoCastVariable should never be created.
+        autocast = kwargs.pop("autocast", None)
+        if autocast is None:
+            autocast = kwargs.pop("experimental_autocast", None)
+        if autocast is None:
+            autocast = True
         # See the docstring for tf.Variable about the details for
         # caching_device.
         caching_device = kwargs.pop("caching_device", None)
tf_keras/src/engine/node.py CHANGED
@@ -84,9 +84,10 @@ class Node:
         self.call_args = call_args
         self.call_kwargs = call_kwargs

-        # Cached for performance.
+        # Cached for performance. Put kwargs in order of the call method instead
+        # of using the sorted key order from `tf.nest.flatten`.
         self._flat_arguments = tf.nest.flatten(
-            (self.call_args, self.call_kwargs)
+            (self.call_args, self.call_kwargs.values())
         )
         # Used to avoid expensive `nest` operations in the most common case.
         self._single_positional_tensor_passed = (
@@ -176,9 +177,13 @@ class Node:
         for kt_id, kt_index in self._keras_inputs_ids_and_indices:
             flat_arguments[kt_index] = tensor_dict[kt_id].pop()

+        # Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a
+        # list in the original order.
         args, kwargs = tf.nest.pack_sequence_as(
-            (self.call_args, self.call_kwargs), flat_arguments
+            (self.call_args, self.call_kwargs.values()), flat_arguments
        )
+        # Add the keys to `kwargs` to go from a list to a dict.
+        kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)}
         return args, kwargs

     def serialize(self, make_node_key, node_conversion_map):
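The bug being fixed: `tf.nest.flatten` traverses dicts in sorted-key order rather than the order the arguments were passed, so keyword arguments could be rematerialized against the wrong positions. Flattening the values instead of the dict preserves insertion (call) order, and the keys are zipped back on afterwards. A small standalone illustration of the sorted-key behavior:

    import tensorflow as tf

    kwargs = {"beta": 2, "alpha": 1}  # passed in this order
    print(tf.nest.flatten(kwargs))                 # [1, 2] -- sorted by key
    print(tf.nest.flatten(list(kwargs.values())))  # [2, 1] -- call order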
tf_keras/src/layers/activation/prelu.py CHANGED
@@ -102,7 +102,7 @@ class PReLU(Layer):
                 if i not in self.shared_axes:
                     axes[i] = input_shape[i]
         self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs):
         pos = backend.relu(inputs)
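This `self.built = True` → `super().build(input_shape)` substitution repeats across Dense, Embedding, the convolutions, and the locally connected layers below. The base `Layer.build()` sets the flag too, but it additionally records the shape the layer was built with, which saving and reloading rely on. A hedged sketch of the resulting recommended pattern for custom layers:

    import tf_keras as keras

    class MyDense(keras.layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.units = units

        def build(self, input_shape):
            self.kernel = self.add_weight(
                name="kernel", shape=(input_shape[-1], self.units)
            )
            # Prefer this over `self.built = True`: the base class also
            # records the build input shape for serialization.
            super().build(input_shape)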
tf_keras/src/layers/attention/base_dense_attention.py CHANGED
@@ -86,7 +86,8 @@ class BaseDenseAttention(base_layer.BaseRandomLayer):
         # be purely stateless, with no reference to any variable.
         if self.dropout > 0:
             super().build(input_shape)
-        self.built = True
+        else:
+            base_layer.Layer.build(self, input_shape)

     def _calculate_scores(self, query, key):
         """Calculates attention scores.
tf_keras/src/layers/convolutional/base_conv.py CHANGED
@@ -248,7 +248,7 @@ class Conv(Layer):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_channel}
         )
-        self.built = True
+        super().build(input_shape)

     def convolution_op(self, inputs, kernel):
         if self.padding == "causal":
tf_keras/src/layers/convolutional/base_depthwise_conv.py CHANGED
@@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv

@@ -202,7 +203,8 @@ class DepthwiseConv(Conv):
         self.input_spec = InputSpec(
             min_ndim=self.rank + 2, axes={channel_axis: input_dim}
         )
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         raise NotImplementedError
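The depthwise, separable, and transpose convolutions below all override `build()` and must not run `Conv.build()` (which would create the standard `kernel`), yet they still want the base `Layer.build()` bookkeeping, hence the unbound `Layer.build(self, input_shape)` call. A minimal standalone sketch of skipping one level of an inheritance chain:

    class Base:
        def build(self):
            self.built = True  # generic bookkeeping

    class Middle(Base):
        def build(self):
            self.kernel = "dense kernel"  # state subclasses must not create
            super().build()

    class Child(Middle):
        def build(self):
            self.depthwise_kernel = "depthwise kernel"
            Base.build(self)  # skip Middle.build(), keep Base bookkeeping

    c = Child()
    c.build()
    assert c.built and not hasattr(c, "kernel")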
tf_keras/src/layers/convolutional/base_separable_conv.py CHANGED
@@ -21,6 +21,7 @@ from tf_keras.src import activations
 from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.base_conv import Conv

@@ -203,7 +204,8 @@ class SeparableConv(Conv):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         raise NotImplementedError
tf_keras/src/layers/convolutional/conv1d_transpose.py CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv1d import Conv1D
 from tf_keras.src.utils import conv_utils
@@ -214,7 +215,8 @@ class Conv1DTranspose(Conv1D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv2d_transpose.py CHANGED
@@ -23,6 +23,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv2d import Conv2D
 from tf_keras.src.utils import conv_utils
@@ -240,7 +241,8 @@ class Conv2DTranspose(Conv2D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/convolutional/conv3d_transpose.py CHANGED
@@ -22,6 +22,7 @@ from tf_keras.src import constraints
 from tf_keras.src import initializers
 from tf_keras.src import regularizers
 from tf_keras.src.dtensor import utils
+from tf_keras.src.engine.base_layer import Layer
 from tf_keras.src.engine.input_spec import InputSpec
 from tf_keras.src.layers.convolutional.conv3d import Conv3D
 from tf_keras.src.utils import conv_utils
@@ -247,7 +248,8 @@ class Conv3DTranspose(Conv3D):
             )
         else:
             self.bias = None
-        self.built = True
+        # Call Layer.build() to skip Conv.build() which we override here.
+        Layer.build(self, input_shape)

     def call(self, inputs):
         inputs_shape = tf.shape(inputs)
tf_keras/src/layers/core/dense.py CHANGED
@@ -174,7 +174,7 @@ class Dense(Layer):
             )
         else:
             self.bias = None
-        self.built = True
+        super().build(input_shape)

     def call(self, inputs):
         if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
tf_keras/src/layers/core/embedding.py CHANGED
@@ -185,7 +185,7 @@ class Embedding(Layer):
             constraint=self.embeddings_constraint,
             experimental_autocast=False,
         )
-        self.built = True
+        super().build(input_shape)

     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero:
tf_keras/src/layers/locally_connected/locally_connected1d.py CHANGED
@@ -284,7 +284,7 @@ class LocallyConnected1D(Layer):
             self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
         else:
             self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
-        self.built = True
+        super().build(input_shape)

     @tf_utils.shape_type_conversion
     def compute_output_shape(self, input_shape):