tf-keras-nightly 2.20.0.dev2025051109__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/__init__.py +1 -1
  6. tf_keras/src/engine/base_layer.py +234 -96
  7. tf_keras/src/engine/base_layer_utils.py +17 -5
  8. tf_keras/src/engine/base_layer_v1.py +12 -3
  9. tf_keras/src/engine/data_adapter.py +30 -13
  10. tf_keras/src/engine/functional.py +36 -15
  11. tf_keras/src/engine/input_layer.py +9 -0
  12. tf_keras/src/engine/input_spec.py +11 -1
  13. tf_keras/src/layers/activation/softmax.py +26 -11
  14. tf_keras/src/layers/attention/multi_head_attention.py +8 -1
  15. tf_keras/src/layers/core/tf_op_layer.py +4 -0
  16. tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
  17. tf_keras/src/metrics/confusion_metrics.py +51 -4
  18. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  19. tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
  20. tf_keras/src/saving/legacy/saving_utils.py +14 -2
  21. tf_keras/src/saving/saving_lib.py +1 -1
  22. tf_keras/src/utils/layer_utils.py +45 -3
  23. tf_keras/src/utils/metrics_utils.py +4 -1
  24. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +2 -2
  25. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +27 -49
  26. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
  27. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  28. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  29. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  30. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  31. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  32. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  33. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  34. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  35. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  36. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  37. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  38. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  39. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  40. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  41. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  42. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  43. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  44. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  45. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  46. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  47. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  48. tf_keras/src/tests/keras_doctest.py +0 -159
  49. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
27
27
  from tf_keras.src.engine.training import Model
28
28
 
29
29
 
30
- __version__ = "2.20.0.dev2025051109"
30
+ __version__ = "2.21.0.dev2025123010"
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/projector_config.proto
5
+ # Protobuf Python Version: 6.31.1
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 6,
15
+ 31,
16
+ 1,
17
+ '',
18
+ 'tf_keras/protobuf/projector_config.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
15
26
 
16
27
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
17
28
 
18
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
19
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', globals())
20
- if _descriptor._USE_C_DESCRIPTORS == False:
21
-
22
- DESCRIPTOR._options = None
23
- _SPRITEMETADATA._serialized_start=78
24
- _SPRITEMETADATA._serialized_end=140
25
- _EMBEDDINGINFO._serialized_start=143
26
- _EMBEDDINGINFO._serialized_end=335
27
- _PROJECTORCONFIG._serialized_start=338
28
- _PROJECTORCONFIG._serialized_end=485
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_SPRITEMETADATA']._serialized_start=78
35
+ _globals['_SPRITEMETADATA']._serialized_end=140
36
+ _globals['_EMBEDDINGINFO']._serialized_start=143
37
+ _globals['_EMBEDDINGINFO']._serialized_end=335
38
+ _globals['_PROJECTORCONFIG']._serialized_start=338
39
+ _globals['_PROJECTORCONFIG']._serialized_end=485
29
40
  # @@protoc_insertion_point(module_scope)
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/saved_metadata.proto
5
+ # Protobuf Python Version: 6.31.1
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 6,
15
+ 31,
16
+ 1,
17
+ '',
18
+ 'tf_keras/protobuf/saved_metadata.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
16
27
 
17
28
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
18
29
 
19
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
20
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', globals())
21
- if _descriptor._USE_C_DESCRIPTORS == False:
22
-
23
- DESCRIPTOR._options = None
24
- _SAVEDMETADATA._serialized_start=110
25
- _SAVEDMETADATA._serialized_end=187
26
- _SAVEDOBJECT._serialized_start=190
27
- _SAVEDOBJECT._serialized_end=346
30
+ _globals = globals()
31
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
32
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
33
+ if not _descriptor._USE_C_DESCRIPTORS:
34
+ DESCRIPTOR._loaded_options = None
35
+ _globals['_SAVEDMETADATA']._serialized_start=110
36
+ _globals['_SAVEDMETADATA']._serialized_end=187
37
+ _globals['_SAVEDOBJECT']._serialized_start=190
38
+ _globals['_SAVEDOBJECT']._serialized_end=346
28
39
  # @@protoc_insertion_point(module_scope)
@@ -1,11 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: tf_keras/protobuf/versions.proto
5
+ # Protobuf Python Version: 6.31.1
4
6
  """Generated protocol buffer code."""
5
- from google.protobuf.internal import builder as _builder
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 6,
15
+ 31,
16
+ 1,
17
+ '',
18
+ 'tf_keras/protobuf/versions.proto'
19
+ )
9
20
  # @@protoc_insertion_point(imports)
10
21
 
11
22
  _sym_db = _symbol_database.Default()
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
15
26
 
16
27
  DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
17
28
 
18
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
19
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', globals())
20
- if _descriptor._USE_C_DESCRIPTORS == False:
21
-
22
- DESCRIPTOR._options = None
23
- _VERSIONDEF._serialized_start=70
24
- _VERSIONDEF._serialized_end=145
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_VERSIONDEF']._serialized_start=70
35
+ _globals['_VERSIONDEF']._serialized_end=145
25
36
  # @@protoc_insertion_point(module_scope)
tf_keras/src/__init__.py CHANGED
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
35
35
  from tensorflow.python import tf2
36
36
  from tensorflow.python.util.tf_export import keras_export
37
37
 
38
- __version__ = "2.20.0"
38
+ __version__ = "2.21.0"
39
39
 
40
40
  keras_export("keras.__version__").export_constant(__name__, "__version__")
41
41
 
@@ -308,6 +308,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
308
308
  self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
309
309
  ):
310
310
  self._instrument_layer_creation()
311
+ self._called = False
311
312
 
312
313
  # These properties should be set by the user via keyword arguments.
313
314
  # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -326,6 +327,10 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
326
327
  # Validate optional keyword arguments.
327
328
  generic_utils.validate_kwargs(kwargs, allowed_kwargs)
328
329
 
330
+ # Track the built-in call-context arguments. These are arguments that
331
+ # are tracked and propagated across the call-stack by default.
332
+ self._call_context_args = {"training"}
333
+
329
334
  # Mutable properties
330
335
  # Indicates whether the layer's weights are updated during training
331
336
  # and whether the layer's updates are run during training.
@@ -411,6 +416,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
411
416
 
412
417
  self._init_call_fn_args()
413
418
 
419
+ # Track the built-in call-context arguments.
420
+ self._call_spec._update_call_context_arguments(self._call_context_args)
421
+
414
422
  # Whether the `call` method can be used to build a TF graph without
415
423
  # issues. This attribute has no effect if the model is created using
416
424
  # the Functional API. Instead, `model.dynamic` is determined based on
@@ -1042,6 +1050,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1042
1050
  # - input_spec compatibility is only checked against `inputs`
1043
1051
  # - mixed precision casting (autocast) is only applied to `inputs`,
1044
1052
  # not to any other argument.
1053
+ self._called = True
1045
1054
  inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
1046
1055
  input_list = tf.nest.flatten(inputs)
1047
1056
 
@@ -1080,17 +1089,21 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1080
1089
  if self._expects_mask_arg and mask_is_implicit:
1081
1090
  kwargs["mask"] = input_masks
1082
1091
 
1083
- # Training mode for `Layer.call` is set via (in order of priority):
1084
- # (1) The `training` argument passed to this `Layer.call`, if it is not
1085
- # None
1086
- # (2) The training mode of an outer `Layer.call`.
1087
- # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if
1088
- # set)
1089
- # (4) Any non-None default value for `training` specified in the call
1092
+ # Call-context arguments for `Layer.call` are set via (in order of
1093
+ # priority):
1094
+ # (1) The argument passed to this `Layer.call`, if it is not None
1095
+ # (2) The argument value of an outer `Layer.call`.
1096
+ # (3) (only for "training") The default mode set by
1097
+ # `tf.keras.backend.set_learning_phase` (if set)
1098
+ # (4) Any non-None default value for the argument specified in the call
1090
1099
  # signature
1091
- # (5) False (treating the layer as if it's in inference)
1092
- args, kwargs, training_mode = self._set_training_mode(
1093
- args, kwargs, call_context
1100
+ # (5) False
1101
+ (
1102
+ args,
1103
+ kwargs,
1104
+ propagated,
1105
+ ) = self._get_propagated_call_context_arguments(
1106
+ args, kwargs, call_context, self._call_context_args
1094
1107
  )
1095
1108
 
1096
1109
  # Losses are cleared for all sublayers on the outermost `Layer.call`.
@@ -1104,7 +1117,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1104
1117
  layer=self,
1105
1118
  inputs=inputs,
1106
1119
  build_graph=not eager,
1107
- training=training_mode,
1120
+ call_context_args=propagated,
1108
1121
  ):
1109
1122
  input_spec.assert_input_compatibility(
1110
1123
  self.input_spec, inputs, self.name
@@ -1152,6 +1165,55 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1152
1165
 
1153
1166
  return outputs
1154
1167
 
1168
+ def _register_call_context_args(self, *argument_names):
1169
+ """Registers call-context args for this layer.
1170
+ If this layer declares a `call()` method that accepts
1171
+ one or more of the given args, those args will be
1172
+ automatically injected into the call signature of this
1173
+ layer. This layer will also propagate the args to any
1174
+ nested sublayers that are called from within this layer.
1175
+ If this layer doesn't declare a `call()` method that
1176
+ accepts one or more of the given args, these args will
1177
+ simply be propagated to any nested sublayers without
1178
+ being injected into the call signature of this layer.
1179
+ This is useful for propagating custom arguments
1180
+ from top-level layers/models to sublayers.
1181
+ Example:
1182
+ ```
1183
+ class Inner(layers.Layer):
1184
+ def __init__(self):
1185
+ super().__init__()
1186
+ # Register `foo_mode` as a call-context arg
1187
+ self._register_call_context_args("foo_mode")
1188
+ def call(self, x, foo_mode=False):
1189
+ # If foo_mode=True add 1, otherwise add 0
1190
+ add_val = ops.where(foo_mode, 1.0, 0.0)
1191
+ return x + add_val
1192
+ class Outer(layers.Layer):
1193
+ def __init__(self):
1194
+ super().__init__()
1195
+ self.inner = Inner()
1196
+ def call(self, x):
1197
+ # We don't explicitly pass foo_mode here—Base Layer.__call__
1198
+ # should inject it into `self.inner`
1199
+ return self.inner(x)
1200
+ sample_input = np.array([[1.0], [2.0]])
1201
+ # Sequential model
1202
+ seq = models.Sequential([Outer()])
1203
+ # Tell the Sequential model to propagate foo_mode down
1204
+ # the call-stack
1205
+ seq.register_call_context_args("foo_mode")
1206
+ # foo_mode=True -> input + 1
1207
+ out_true = seq(sample_input, foo_mode=True)
1208
+ """
1209
+ if self._called:
1210
+ raise RuntimeError(
1211
+ "Cannot add call-context args after the layer has been called."
1212
+ )
1213
+ self._call_context_args |= set(argument_names)
1214
+ self._call_spec._update_call_context_arguments(argument_names)
1215
+ self._call_spec._update_call_context_argument_defaults(argument_names)
1216
+
1155
1217
  def _get_unnested_name_scope(self):
1156
1218
  if _is_name_scope_on_model_declaration_enabled:
1157
1219
  with _name_scope_unnester(
@@ -2535,47 +2597,57 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2535
2597
  kwargs["mask"] = input_masks
2536
2598
  mask_arg_passed_by_framework = True
2537
2599
 
2538
- # If `training` argument is None or not explicitly passed,
2539
- # propagate `training` value from this layer's calling layer.
2540
- training_value = None
2541
- training_arg_passed_by_framework = False
2542
- # Priority 1: `training` was explicitly passed a non-None value.
2543
- if self._call_spec.arg_was_passed("training", args, kwargs):
2544
- training_value = self._call_spec.get_arg_value(
2545
- "training", args, kwargs
2546
- )
2547
- if not self._expects_training_arg:
2548
- kwargs.pop("training")
2549
-
2550
- if training_value is None:
2551
- # Priority 2: `training` was passed to a parent layer.
2552
- if call_context.training is not None:
2553
- training_value = call_context.training
2554
- # Priority 3: `learning_phase()` has been set.
2555
- elif backend.global_learning_phase_is_set():
2556
- training_value = backend.learning_phase()
2557
- # Force the training_value to be bool type which matches to the
2558
- # contract for layer/model call args.
2559
- if tf.is_tensor(training_value):
2560
- training_value = tf.cast(training_value, tf.bool)
2600
+ propagated = dict()
2601
+ args_passed_by_framework = dict()
2602
+ for context_arg in self._call_context_args:
2603
+ # If `training` argument is None or not explicitly passed,
2604
+ # propagate `training` value from this layer's calling layer.
2605
+ value = None
2606
+ args_passed_by_framework[context_arg] = False
2607
+ # Priority 1: `training` was explicitly passed a non-None value.
2608
+ if self._call_spec.arg_was_passed(context_arg, args, kwargs):
2609
+ value = self._call_spec.get_arg_value(context_arg, args, kwargs)
2610
+ if not self._expects_context_arg(context_arg):
2611
+ kwargs.pop(context_arg)
2612
+
2613
+ if value is None:
2614
+ # Priority 2: `training` was passed to a parent layer.
2615
+ if call_context.get_call_context_arg(context_arg) is not None:
2616
+ value = call_context.get_call_context_arg(context_arg)
2617
+ # Priority 3: `learning_phase()` has been set.
2618
+ elif (
2619
+ context_arg == "training"
2620
+ and backend.global_learning_phase_is_set()
2621
+ ):
2622
+ value = backend.learning_phase()
2623
+ # Force the training_value to be bool type which matches to
2624
+ # the contract for layer/model call args.
2625
+ if tf.is_tensor(value):
2626
+ value = tf.cast(value, tf.bool)
2627
+ else:
2628
+ value = bool(value)
2629
+ # Priority 4: trace layer with the default training argument
2630
+ # specified in the `call` signature (or in inference mode if the
2631
+ # `call` signature specifies no non-None default).
2561
2632
  else:
2562
- training_value = bool(training_value)
2563
- # Priority 4: trace layer with the default training argument
2564
- # specified in the `call` signature (or in inference mode if the
2565
- # `call` signature specifies no non-None default).
2566
- else:
2567
- training_value = self._call_spec.default_training_arg
2568
- # In cases (2), (3), (4) the training argument is passed
2569
- # automatically by the framework, and will not be hard-coded into
2570
- # the model.
2571
- if self._expects_training_arg:
2572
- args, kwargs = self._call_spec.set_arg_value(
2573
- "training", training_value, args, kwargs
2574
- )
2575
- training_arg_passed_by_framework = True
2633
+ value = self._call_spec.get_context_arg_default(context_arg)
2634
+ # In cases (2), (3), (4) the training argument is passed
2635
+ # automatically by the framework, and will not be hard-coded
2636
+ # into the model.
2637
+ if self._expects_context_arg(context_arg):
2638
+ args, kwargs = self._call_spec.set_arg_value(
2639
+ context_arg, value, args, kwargs
2640
+ )
2641
+ args_passed_by_framework[context_arg] = True
2642
+
2643
+ if value is not None:
2644
+ propagated[context_arg] = value
2576
2645
 
2577
2646
  with call_context.enter(
2578
- layer=self, inputs=inputs, build_graph=True, training=training_value
2647
+ layer=self,
2648
+ inputs=inputs,
2649
+ build_graph=True,
2650
+ call_context_args=propagated,
2579
2651
  ):
2580
2652
  # Check input assumptions set after layer building, e.g. input
2581
2653
  # shape.
@@ -2601,10 +2673,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2601
2673
  "Tensor or a list of Tensors, not None "
2602
2674
  "(layer: " + self.name + ")."
2603
2675
  )
2604
- if training_arg_passed_by_framework:
2605
- args, kwargs = self._call_spec.set_arg_value(
2606
- "training", None, args, kwargs, pop_kwarg_if_none=True
2607
- )
2676
+
2677
+ for context_arg, is_passed in args_passed_by_framework.items():
2678
+ if is_passed:
2679
+ args, kwargs = self._call_spec.set_arg_value(
2680
+ context_arg, None, args, kwargs, pop_kwarg_if_none=True
2681
+ )
2682
+
2608
2683
  if mask_arg_passed_by_framework:
2609
2684
  kwargs.pop("mask")
2610
2685
  # Node connectivity does not special-case the first argument.
@@ -2613,52 +2688,100 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2613
2688
  )
2614
2689
  return outputs
2615
2690
 
2616
- def _set_training_mode(self, args, kwargs, call_context):
2617
- training_mode = None
2618
- if self._expects_training_arg:
2619
- # (1) `training` was passed to this `Layer.call`.
2620
- if self._call_spec.arg_was_passed("training", args, kwargs):
2621
- training_mode = self._call_spec.get_arg_value(
2622
- "training", args, kwargs
2623
- )
2624
- # If no `training` arg was passed, or `None` was explicitly passed,
2625
- # the framework will make a decision about the training mode is.
2626
- if training_mode is None:
2627
- call_ctx_training = call_context.training
2628
- # (2) `training` mode is inferred from an outer `Layer.call`.
2629
- if call_ctx_training is not None:
2630
- training_mode = call_ctx_training
2631
- # (3) User set `tf.keras.backend.set_learning_phase`.
2632
- elif backend.global_learning_phase_is_set():
2633
- training_mode = backend.learning_phase()
2634
- # Ensure value is a `bool` or `tf.bool`.
2635
- if isinstance(training_mode, bool):
2636
- pass
2637
- elif tf.is_tensor(training_mode):
2638
- training_mode = tf.cast(training_mode, tf.bool)
2691
+ def _get_propagated_call_context_arguments(
2692
+ self, args, kwargs, call_context, local_call_context_arguments
2693
+ ):
2694
+ """Resolves the values for propagated call context arguments for the
2695
+ current layer.
2696
+
2697
+ Args:
2698
+ args: The arguments passed to the current layer's `call` method.
2699
+ kwargs: The keyword arguments passed to the current layer's `call`
2700
+ method.
2701
+ call_context: The `CallContext` for the current call-stack.
2702
+ local_call_context_arguments: The call-context arguments registered
2703
+ with the current layer's `Layer.call` method.
2704
+
2705
+ Returns:
2706
+ A tuple of the following:
2707
+ 1. Updated args
2708
+ 2. Updated kwargs
2709
+ 3. A dictionary of the resolved call-context arguments that should
2710
+ be propagated to the next layer in the call-stack.
2711
+ """
2712
+ propagated_context = dict()
2713
+ relevant_arguments = call_context.call_context_args.keys() | set(
2714
+ local_call_context_arguments
2715
+ )
2716
+
2717
+ for argument in relevant_arguments:
2718
+ authoritative_value = None
2719
+ was_explicitly_passed = self._call_spec.arg_was_passed(
2720
+ argument, args, kwargs
2721
+ )
2722
+ if self._expects_context_arg(argument):
2723
+ # (1) `arg_name` was passed to this `Layer.call`.
2724
+ if was_explicitly_passed:
2725
+ authoritative_value = self._call_spec.get_arg_value(
2726
+ argument, args, kwargs
2727
+ )
2728
+ # If no `arg_name` arg was passed, or `None` was explicitly
2729
+ # passed, the framework will make a decision about the training
2730
+ # mode.
2731
+ if authoritative_value is None:
2732
+ value_from_context = call_context.get_call_context_arg(
2733
+ argument
2734
+ )
2735
+ # (2) `arg_name` mode is inferred from an outer
2736
+ # `Layer.call`.
2737
+ if value_from_context is not None:
2738
+ authoritative_value = value_from_context
2739
+ # (3) User set `tf.keras.backend.set_learning_phase`.
2740
+ elif (
2741
+ argument == "training"
2742
+ and backend.global_learning_phase_is_set()
2743
+ ):
2744
+ authoritative_value = backend.learning_phase()
2745
+ # Ensure value is a `bool` or `tf.bool`.
2746
+ if isinstance(authoritative_value, bool):
2747
+ pass
2748
+ elif tf.is_tensor(authoritative_value):
2749
+ authoritative_value = tf.cast(
2750
+ authoritative_value, tf.bool
2751
+ )
2752
+ else:
2753
+ authoritative_value = bool(authoritative_value)
2754
+ # (4) We default to using `call`'s default value for
2755
+ # `arg_name`, or treating the layer as if it is in inference
2756
+ # if no non-None default is specified in the `call`
2757
+ # signature.
2639
2758
  else:
2640
- training_mode = bool(training_mode)
2641
- # (4) We default to using `call`'s default value for `training`,
2642
- # or treating the layer as if it is in inference if no non-None
2643
- # default is specified in the `call` signature.
2644
- else:
2645
- training_mode = self._call_spec.default_training_arg
2759
+ authoritative_value = (
2760
+ self._call_spec.get_context_arg_default(argument)
2761
+ )
2646
2762
 
2647
- # For case (2), (3), (4) `training` arg is passed by framework.
2648
- args, kwargs = self._call_spec.set_arg_value(
2649
- "training", training_mode, args, kwargs
2650
- )
2651
- else:
2652
- if "training" in kwargs:
2653
- # `training` was passed to this `Layer` but is not needed for
2654
- # `Layer.call`. It will set the default mode for inner
2655
- # `Layer.call`s.
2656
- training_mode = kwargs.pop("training")
2763
+ # For case (2), (3), (4) `arg_name` arg is passed by
2764
+ # framework.
2765
+ args, kwargs = self._call_spec.set_arg_value(
2766
+ argument, authoritative_value, args, kwargs
2767
+ )
2657
2768
  else:
2658
- # Grab the current `training` mode from any outer `Layer.call`.
2659
- training_mode = call_context.training
2769
+ if argument in kwargs:
2770
+ # `arg_name` was passed to this `Layer` but is not needed
2771
+ # for `Layer.call`. It will set the default mode for inner
2772
+ # `Layer.call`s.
2773
+ authoritative_value = kwargs.pop(argument)
2774
+ else:
2775
+ # Grab the current `arg_name` mode from any outer
2776
+ # `Layer.call`.
2777
+ authoritative_value = call_context.get_call_context_arg(
2778
+ argument
2779
+ )
2780
+
2781
+ if authoritative_value is not None:
2782
+ propagated_context[argument] = authoritative_value
2660
2783
 
2661
- return args, kwargs, training_mode
2784
+ return args, kwargs, propagated_context
2662
2785
 
2663
2786
  def _autographed_call(self):
2664
2787
  # Wrapping `call` function in autograph to allow for dynamic control
@@ -3351,7 +3474,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
3351
3474
 
3352
3475
  def _init_call_fn_args(self, expects_training_arg=None):
3353
3476
  self._call_spec = layer_utils.CallFunctionSpec(
3354
- tf_inspect.getfullargspec(self.call)
3477
+ tf_inspect.getfullargspec(self.call),
3478
+ getattr(self, "_call_context_args", set()),
3355
3479
  )
3356
3480
  if expects_training_arg is not None:
3357
3481
  self._call_spec.expects_training_arg = expects_training_arg
@@ -3361,6 +3485,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
3361
3485
  """Whether the call function uses 'training' as a parameter."""
3362
3486
  return self._call_spec.expects_training_arg
3363
3487
 
3488
+ def _expects_context_arg(self, argument_name):
3489
+ return argument_name in self._call_spec.expected_context_args
3490
+
3364
3491
  @property
3365
3492
  def _expects_mask_arg(self):
3366
3493
  return self._call_spec.expects_mask_arg
@@ -3805,6 +3932,17 @@ class BaseRandomLayer(Layer):
3805
3932
  super().build(input_shape)
3806
3933
  self._random_generator._maybe_init()
3807
3934
 
3935
+ def get_config(self):
3936
+ base_config = super().get_config()
3937
+ if (
3938
+ self._random_generator._rng_type
3939
+ == backend.RandomGenerator.RNG_LEGACY_STATEFUL
3940
+ ):
3941
+ return base_config
3942
+
3943
+ config = {"rng_type": self._random_generator._rng_type}
3944
+ return dict(list(base_config.items()) + list(config.items()))
3945
+
3808
3946
  def _trackable_children(self, save_type="checkpoint", **kwargs):
3809
3947
  if save_type == "savedmodel":
3810
3948
  cache = kwargs["cache"]
@@ -480,7 +480,8 @@ class CallContext:
480
480
  layer: The `Layer` whose `call` is currently active.
481
481
  inputs: The inputs to the currently active `Layer`.
482
482
  build_graph: Whether currently inside a Graph or FuncGraph.
483
- training: Whether currently executing in training or inference mode.
483
+ call_context_args: The call-context arguments being propagated through the
484
+ call-stack.
484
485
  saving: Whether currently saving to SavedModel.
485
486
  frozen: Whether currently executing inside a `Layer` with `trainable` set
486
487
  to `False`.
@@ -495,6 +496,7 @@ class CallContext:
495
496
  "layer": None,
496
497
  "inputs": None,
497
498
  "build_graph": False,
499
+ "call_context_args": dict(),
498
500
  "training": None,
499
501
  "saving": None,
500
502
  }
@@ -502,14 +504,17 @@ class CallContext:
502
504
  # refactor.
503
505
  self._in_keras_graph = False
504
506
 
505
- def enter(self, layer, inputs, build_graph, training, saving=None):
507
+ def enter(
508
+ self, layer, inputs, build_graph, call_context_args=dict(), saving=None
509
+ ):
506
510
  """Push a Layer and its inputs and state onto the current call context.
507
511
 
508
512
  Args:
509
513
  layer: The `Layer` whose `call` is currently active.
510
514
  inputs: The inputs to the currently active `Layer`.
511
515
  build_graph: Whether currently inside a Graph or FuncGraph.
512
- training: Whether currently executing in training or inference mode.
516
+ call_context_args: The call-context arguments being propagated through
517
+ the call-stack.
513
518
  saving: Whether currently saving to SavedModel.
514
519
 
515
520
  Returns:
@@ -519,7 +524,7 @@ class CallContext:
519
524
  "layer": layer,
520
525
  "inputs": inputs,
521
526
  "build_graph": build_graph,
522
- "training": training,
527
+ "call_context_args": call_context_args,
523
528
  "saving": saving,
524
529
  }
525
530
  return CallContextManager(self, state)
@@ -538,7 +543,14 @@ class CallContext:
538
543
 
539
544
  @property
540
545
  def training(self):
541
- return self._state["training"]
546
+ return self.call_context_args.get("training", None)
547
+
548
+ @property
549
+ def call_context_args(self):
550
+ return self._state["call_context_args"]
551
+
552
+ def get_call_context_arg(self, arg_name):
553
+ return self.call_context_args.get(arg_name, None)
542
554
 
543
555
  @property
544
556
  def saving(self):