tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +235 -97
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +35 -19
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/engine/sequential.py +29 -12
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/preprocessing/sequence.py +2 -2
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_api.py +18 -5
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
|
@@ -308,6 +308,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
308
308
|
self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
|
|
309
309
|
):
|
|
310
310
|
self._instrument_layer_creation()
|
|
311
|
+
self._called = False
|
|
311
312
|
|
|
312
313
|
# These properties should be set by the user via keyword arguments.
|
|
313
314
|
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
|
@@ -326,6 +327,10 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
326
327
|
# Validate optional keyword arguments.
|
|
327
328
|
generic_utils.validate_kwargs(kwargs, allowed_kwargs)
|
|
328
329
|
|
|
330
|
+
# Track the built-in call-context arguments. These are arguments that
|
|
331
|
+
# are tracked and propagated across the call-stack by default.
|
|
332
|
+
self._call_context_args = {"training"}
|
|
333
|
+
|
|
329
334
|
# Mutable properties
|
|
330
335
|
# Indicates whether the layer's weights are updated during training
|
|
331
336
|
# and whether the layer's updates are run during training.
|
|
@@ -411,6 +416,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
411
416
|
|
|
412
417
|
self._init_call_fn_args()
|
|
413
418
|
|
|
419
|
+
# Track the built-in call-context arguments.
|
|
420
|
+
self._call_spec._update_call_context_arguments(self._call_context_args)
|
|
421
|
+
|
|
414
422
|
# Whether the `call` method can be used to build a TF graph without
|
|
415
423
|
# issues. This attribute has no effect if the model is created using
|
|
416
424
|
# the Functional API. Instead, `model.dynamic` is determined based on
|
|
@@ -1042,6 +1050,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1042
1050
|
# - input_spec compatibility is only checked against `inputs`
|
|
1043
1051
|
# - mixed precision casting (autocast) is only applied to `inputs`,
|
|
1044
1052
|
# not to any other argument.
|
|
1053
|
+
self._called = True
|
|
1045
1054
|
inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
|
|
1046
1055
|
input_list = tf.nest.flatten(inputs)
|
|
1047
1056
|
|
|
@@ -1080,17 +1089,21 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1080
1089
|
if self._expects_mask_arg and mask_is_implicit:
|
|
1081
1090
|
kwargs["mask"] = input_masks
|
|
1082
1091
|
|
|
1083
|
-
#
|
|
1084
|
-
#
|
|
1085
|
-
#
|
|
1086
|
-
# (2) The
|
|
1087
|
-
# (3) The default mode set by
|
|
1088
|
-
#
|
|
1089
|
-
# (4) Any non-None default value for
|
|
1092
|
+
# Call-context arguments for `Layer.call` is set via (in order of
|
|
1093
|
+
# priority):
|
|
1094
|
+
# (1) The argument passed to this `Layer.call`, if it is not None
|
|
1095
|
+
# (2) The argument value of an outer `Layer.call`.
|
|
1096
|
+
# (3) (only for "training") The default mode set by
|
|
1097
|
+
# `tf.keras.backend.set_learning_phase` (if set)
|
|
1098
|
+
# (4) Any non-None default value for the argument specified in the call
|
|
1090
1099
|
# signature
|
|
1091
|
-
# (5) False
|
|
1092
|
-
|
|
1093
|
-
args,
|
|
1100
|
+
# (5) False
|
|
1101
|
+
(
|
|
1102
|
+
args,
|
|
1103
|
+
kwargs,
|
|
1104
|
+
propagated,
|
|
1105
|
+
) = self._get_propagated_call_context_arguments(
|
|
1106
|
+
args, kwargs, call_context, self._call_context_args
|
|
1094
1107
|
)
|
|
1095
1108
|
|
|
1096
1109
|
# Losses are cleared for all sublayers on the outermost `Layer.call`.
|
|
@@ -1104,7 +1117,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1104
1117
|
layer=self,
|
|
1105
1118
|
inputs=inputs,
|
|
1106
1119
|
build_graph=not eager,
|
|
1107
|
-
|
|
1120
|
+
call_context_args=propagated,
|
|
1108
1121
|
):
|
|
1109
1122
|
input_spec.assert_input_compatibility(
|
|
1110
1123
|
self.input_spec, inputs, self.name
|
|
@@ -1152,6 +1165,55 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1152
1165
|
|
|
1153
1166
|
return outputs
|
|
1154
1167
|
|
|
1168
|
+
def _register_call_context_args(self, *argument_names):
|
|
1169
|
+
"""Registers call-context args for this layer.
|
|
1170
|
+
If this layer declares a `call()` method that accepts
|
|
1171
|
+
one or more of the given args, those args will be
|
|
1172
|
+
automatically injected into the call signature of this
|
|
1173
|
+
layer. This layer will also propagate the args to any
|
|
1174
|
+
nested sublayers that are called from within this layer.
|
|
1175
|
+
If this layer doesn't declare a `call()` method that
|
|
1176
|
+
accepts one or more of the given args, these args will
|
|
1177
|
+
simply be propagated to any nested sublayers without
|
|
1178
|
+
being injected into the call signature of this layer.
|
|
1179
|
+
This is useful for propagating custom arguments
|
|
1180
|
+
from top-level layers/models to sublayers.
|
|
1181
|
+
Example:
|
|
1182
|
+
```
|
|
1183
|
+
class Inner(layers.Layer):
|
|
1184
|
+
def __init__(self):
|
|
1185
|
+
super().__init__()
|
|
1186
|
+
# Register `foo_mode` as a call-context arg
|
|
1187
|
+
self._register_call_context_args("foo_mode")
|
|
1188
|
+
def call(self, x, foo_mode=False):
|
|
1189
|
+
# If foo_mode=True add 1, otherwise add 0
|
|
1190
|
+
add_val = ops.where(foo_mode, 1.0, 0.0)
|
|
1191
|
+
return x + add_val
|
|
1192
|
+
class Outer(layers.Layer):
|
|
1193
|
+
def __init__(self):
|
|
1194
|
+
super().__init__()
|
|
1195
|
+
self.inner = Inner()
|
|
1196
|
+
def call(self, x):
|
|
1197
|
+
# We don't explicitly pass foo_mode here—Base Layer.__call__
|
|
1198
|
+
# should inject it into `self.inner`
|
|
1199
|
+
return self.inner(x)
|
|
1200
|
+
sample_input = np.array([[1.0], [2.0]])
|
|
1201
|
+
# Sequential model
|
|
1202
|
+
seq = models.Sequential([Outer()])
|
|
1203
|
+
# Tell the Sequential model to propagate foo_mode down
|
|
1204
|
+
# the call-stack
|
|
1205
|
+
seq.register_call_context_args("foo_mode")
|
|
1206
|
+
# foo_mode=True -> input + 1
|
|
1207
|
+
out_true = seq(sample_input, foo_mode=True)
|
|
1208
|
+
"""
|
|
1209
|
+
if self._called:
|
|
1210
|
+
raise RuntimeError(
|
|
1211
|
+
"Cannot add call-context args after the layer has been called."
|
|
1212
|
+
)
|
|
1213
|
+
self._call_context_args |= set(argument_names)
|
|
1214
|
+
self._call_spec._update_call_context_arguments(argument_names)
|
|
1215
|
+
self._call_spec._update_call_context_argument_defaults(argument_names)
|
|
1216
|
+
|
|
1155
1217
|
def _get_unnested_name_scope(self):
|
|
1156
1218
|
if _is_name_scope_on_model_declaration_enabled:
|
|
1157
1219
|
with _name_scope_unnester(
|
|
@@ -2321,7 +2383,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2321
2383
|
"""
|
|
2322
2384
|
input_shape = config["input_shape"]
|
|
2323
2385
|
if input_shape is not None:
|
|
2324
|
-
self.build(input_shape)
|
|
2386
|
+
self.build(tf_utils.convert_shapes(input_shape, to_tuples=False))
|
|
2325
2387
|
|
|
2326
2388
|
############################################################################
|
|
2327
2389
|
# Methods & attributes below are all private and only used by the framework.
|
|
@@ -2535,47 +2597,57 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2535
2597
|
kwargs["mask"] = input_masks
|
|
2536
2598
|
mask_arg_passed_by_framework = True
|
|
2537
2599
|
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2541
|
-
|
|
2542
|
-
|
|
2543
|
-
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
)
|
|
2547
|
-
|
|
2548
|
-
|
|
2549
|
-
|
|
2550
|
-
|
|
2551
|
-
|
|
2552
|
-
|
|
2553
|
-
|
|
2554
|
-
|
|
2555
|
-
|
|
2556
|
-
|
|
2557
|
-
|
|
2558
|
-
|
|
2559
|
-
|
|
2560
|
-
|
|
2600
|
+
propagated = dict()
|
|
2601
|
+
args_passed_by_framework = dict()
|
|
2602
|
+
for context_arg in self._call_context_args:
|
|
2603
|
+
# If `training` argument is None or not explicitly passed,
|
|
2604
|
+
# propagate `training` value from this layer's calling layer.
|
|
2605
|
+
value = None
|
|
2606
|
+
args_passed_by_framework[context_arg] = False
|
|
2607
|
+
# Priority 1: `training` was explicitly passed a non-None value.
|
|
2608
|
+
if self._call_spec.arg_was_passed(context_arg, args, kwargs):
|
|
2609
|
+
value = self._call_spec.get_arg_value(context_arg, args, kwargs)
|
|
2610
|
+
if not self._expects_context_arg(context_arg):
|
|
2611
|
+
kwargs.pop(context_arg)
|
|
2612
|
+
|
|
2613
|
+
if value is None:
|
|
2614
|
+
# Priority 2: `training` was passed to a parent layer.
|
|
2615
|
+
if call_context.get_call_context_arg(context_arg) is not None:
|
|
2616
|
+
value = call_context.get_call_context_arg(context_arg)
|
|
2617
|
+
# Priority 3: `learning_phase()` has been set.
|
|
2618
|
+
elif (
|
|
2619
|
+
context_arg == "training"
|
|
2620
|
+
and backend.global_learning_phase_is_set()
|
|
2621
|
+
):
|
|
2622
|
+
value = backend.learning_phase()
|
|
2623
|
+
# Force the training_value to be bool type which matches to
|
|
2624
|
+
# the contract for layer/model call args.
|
|
2625
|
+
if tf.is_tensor(value):
|
|
2626
|
+
value = tf.cast(value, tf.bool)
|
|
2627
|
+
else:
|
|
2628
|
+
value = bool(value)
|
|
2629
|
+
# Priority 4: trace layer with the default training argument
|
|
2630
|
+
# specified in the `call` signature (or in inference mode if the
|
|
2631
|
+
# `call` signature specifies no non-None default).
|
|
2561
2632
|
else:
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
|
|
2565
|
-
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
|
|
2570
|
-
|
|
2571
|
-
|
|
2572
|
-
|
|
2573
|
-
|
|
2574
|
-
)
|
|
2575
|
-
training_arg_passed_by_framework = True
|
|
2633
|
+
value = self._call_spec.get_context_arg_default(context_arg)
|
|
2634
|
+
# In cases (2), (3), (4) the training argument is passed
|
|
2635
|
+
# automatically by the framework, and will not be hard-coded
|
|
2636
|
+
# into the model.
|
|
2637
|
+
if self._expects_context_arg(context_arg):
|
|
2638
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2639
|
+
context_arg, value, args, kwargs
|
|
2640
|
+
)
|
|
2641
|
+
args_passed_by_framework[context_arg] = True
|
|
2642
|
+
|
|
2643
|
+
if value is not None:
|
|
2644
|
+
propagated[context_arg] = value
|
|
2576
2645
|
|
|
2577
2646
|
with call_context.enter(
|
|
2578
|
-
layer=self,
|
|
2647
|
+
layer=self,
|
|
2648
|
+
inputs=inputs,
|
|
2649
|
+
build_graph=True,
|
|
2650
|
+
call_context_args=propagated,
|
|
2579
2651
|
):
|
|
2580
2652
|
# Check input assumptions set after layer building, e.g. input
|
|
2581
2653
|
# shape.
|
|
@@ -2601,10 +2673,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2601
2673
|
"Tensor or a list of Tensors, not None "
|
|
2602
2674
|
"(layer: " + self.name + ")."
|
|
2603
2675
|
)
|
|
2604
|
-
|
|
2605
|
-
|
|
2606
|
-
|
|
2607
|
-
|
|
2676
|
+
|
|
2677
|
+
for context_arg, is_passed in args_passed_by_framework.items():
|
|
2678
|
+
if is_passed:
|
|
2679
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2680
|
+
context_arg, None, args, kwargs, pop_kwarg_if_none=True
|
|
2681
|
+
)
|
|
2682
|
+
|
|
2608
2683
|
if mask_arg_passed_by_framework:
|
|
2609
2684
|
kwargs.pop("mask")
|
|
2610
2685
|
# Node connectivity does not special-case the first argument.
|
|
@@ -2613,52 +2688,100 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2613
2688
|
)
|
|
2614
2689
|
return outputs
|
|
2615
2690
|
|
|
2616
|
-
def
|
|
2617
|
-
|
|
2618
|
-
|
|
2619
|
-
|
|
2620
|
-
|
|
2621
|
-
|
|
2622
|
-
|
|
2623
|
-
|
|
2624
|
-
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
|
|
2628
|
-
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2634
|
-
|
|
2635
|
-
|
|
2636
|
-
|
|
2637
|
-
|
|
2638
|
-
|
|
2691
|
+
def _get_propagated_call_context_arguments(
|
|
2692
|
+
self, args, kwargs, call_context, local_call_context_arguments
|
|
2693
|
+
):
|
|
2694
|
+
"""Resolves the values for propagated call context arguments for the
|
|
2695
|
+
current layer.
|
|
2696
|
+
|
|
2697
|
+
Args:
|
|
2698
|
+
args: The arguments passed to the current layer's `call` method.
|
|
2699
|
+
kwargs: The keyword arguments passed to the current layer's `call`
|
|
2700
|
+
method.
|
|
2701
|
+
call_context: The `CallContext` for the current call-stack.
|
|
2702
|
+
local_call_context_arguments: The call-context arguments registered
|
|
2703
|
+
with to the current layer's `Layer.call` method.
|
|
2704
|
+
|
|
2705
|
+
Returns:
|
|
2706
|
+
A tuple of the following:
|
|
2707
|
+
1. Updated args
|
|
2708
|
+
2. Updated kwargs
|
|
2709
|
+
3. A dictionary of the resolved call-context arguments that should
|
|
2710
|
+
be propagated to the next layer in the call-stack.
|
|
2711
|
+
"""
|
|
2712
|
+
propagated_context = dict()
|
|
2713
|
+
relevant_arguments = call_context.call_context_args.keys() | set(
|
|
2714
|
+
local_call_context_arguments
|
|
2715
|
+
)
|
|
2716
|
+
|
|
2717
|
+
for argument in relevant_arguments:
|
|
2718
|
+
authoritative_value = None
|
|
2719
|
+
was_explicitly_passed = self._call_spec.arg_was_passed(
|
|
2720
|
+
argument, args, kwargs
|
|
2721
|
+
)
|
|
2722
|
+
if self._expects_context_arg(argument):
|
|
2723
|
+
# (1) `arg_name` was passed to this `Layer.call`.
|
|
2724
|
+
if was_explicitly_passed:
|
|
2725
|
+
authoritative_value = self._call_spec.get_arg_value(
|
|
2726
|
+
argument, args, kwargs
|
|
2727
|
+
)
|
|
2728
|
+
# If no `arg_name` arg was passed, or `None` was explicitly
|
|
2729
|
+
# passed, the framework will make a decision about the training
|
|
2730
|
+
# mode is.
|
|
2731
|
+
if authoritative_value is None:
|
|
2732
|
+
value_from_context = call_context.get_call_context_arg(
|
|
2733
|
+
argument
|
|
2734
|
+
)
|
|
2735
|
+
# (2) `arg_name` mode is inferred from an outer
|
|
2736
|
+
# `Layer.call`.
|
|
2737
|
+
if value_from_context is not None:
|
|
2738
|
+
authoritative_value = value_from_context
|
|
2739
|
+
# (3) User set `tf.keras.backend.set_learning_phase`.
|
|
2740
|
+
elif (
|
|
2741
|
+
argument == "training"
|
|
2742
|
+
and backend.global_learning_phase_is_set()
|
|
2743
|
+
):
|
|
2744
|
+
authoritative_value = backend.learning_phase()
|
|
2745
|
+
# Ensure value is a `bool` or `tf.bool`.
|
|
2746
|
+
if isinstance(authoritative_value, bool):
|
|
2747
|
+
pass
|
|
2748
|
+
elif tf.is_tensor(authoritative_value):
|
|
2749
|
+
authoritative_value = tf.cast(
|
|
2750
|
+
authoritative_value, tf.bool
|
|
2751
|
+
)
|
|
2752
|
+
else:
|
|
2753
|
+
authoritative_value = bool(authoritative_value)
|
|
2754
|
+
# (4) We default to using `call`'s default value for
|
|
2755
|
+
# `arg_name`, or treating the layer as if it is in inference
|
|
2756
|
+
# if no non-None default is specified in the `call`
|
|
2757
|
+
# signature.
|
|
2639
2758
|
else:
|
|
2640
|
-
|
|
2641
|
-
|
|
2642
|
-
|
|
2643
|
-
# default is specified in the `call` signature.
|
|
2644
|
-
else:
|
|
2645
|
-
training_mode = self._call_spec.default_training_arg
|
|
2759
|
+
authoritative_value = (
|
|
2760
|
+
self._call_spec.get_context_arg_default(argument)
|
|
2761
|
+
)
|
|
2646
2762
|
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
2651
|
-
|
|
2652
|
-
if "training" in kwargs:
|
|
2653
|
-
# `training` was passed to this `Layer` but is not needed for
|
|
2654
|
-
# `Layer.call`. It will set the default mode for inner
|
|
2655
|
-
# `Layer.call`s.
|
|
2656
|
-
training_mode = kwargs.pop("training")
|
|
2763
|
+
# For case (2), (3), (4) `arg_name` arg is passed by
|
|
2764
|
+
# framework.
|
|
2765
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2766
|
+
argument, authoritative_value, args, kwargs
|
|
2767
|
+
)
|
|
2657
2768
|
else:
|
|
2658
|
-
|
|
2659
|
-
|
|
2769
|
+
if argument in kwargs:
|
|
2770
|
+
# `arg_name` was passed to this `Layer` but is not needed
|
|
2771
|
+
# for `Layer.call`. It will set the default mode for inner
|
|
2772
|
+
# `Layer.call`s.
|
|
2773
|
+
authoritative_value = kwargs.pop(argument)
|
|
2774
|
+
else:
|
|
2775
|
+
# Grab the current `arg_name` mode from any outer
|
|
2776
|
+
# `Layer.call`.
|
|
2777
|
+
authoritative_value = call_context.get_call_context_arg(
|
|
2778
|
+
argument
|
|
2779
|
+
)
|
|
2780
|
+
|
|
2781
|
+
if authoritative_value is not None:
|
|
2782
|
+
propagated_context[argument] = authoritative_value
|
|
2660
2783
|
|
|
2661
|
-
return args, kwargs,
|
|
2784
|
+
return args, kwargs, propagated_context
|
|
2662
2785
|
|
|
2663
2786
|
def _autographed_call(self):
|
|
2664
2787
|
# Wrapping `call` function in autograph to allow for dynamic control
|
|
@@ -3351,7 +3474,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
3351
3474
|
|
|
3352
3475
|
def _init_call_fn_args(self, expects_training_arg=None):
|
|
3353
3476
|
self._call_spec = layer_utils.CallFunctionSpec(
|
|
3354
|
-
tf_inspect.getfullargspec(self.call)
|
|
3477
|
+
tf_inspect.getfullargspec(self.call),
|
|
3478
|
+
getattr(self, "_call_context_args", set()),
|
|
3355
3479
|
)
|
|
3356
3480
|
if expects_training_arg is not None:
|
|
3357
3481
|
self._call_spec.expects_training_arg = expects_training_arg
|
|
@@ -3361,6 +3485,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
3361
3485
|
"""Whether the call function uses 'training' as a parameter."""
|
|
3362
3486
|
return self._call_spec.expects_training_arg
|
|
3363
3487
|
|
|
3488
|
+
def _expects_context_arg(self, argument_name):
|
|
3489
|
+
return argument_name in self._call_spec.expected_context_args
|
|
3490
|
+
|
|
3364
3491
|
@property
|
|
3365
3492
|
def _expects_mask_arg(self):
|
|
3366
3493
|
return self._call_spec.expects_mask_arg
|
|
@@ -3805,6 +3932,17 @@ class BaseRandomLayer(Layer):
|
|
|
3805
3932
|
super().build(input_shape)
|
|
3806
3933
|
self._random_generator._maybe_init()
|
|
3807
3934
|
|
|
3935
|
+
def get_config(self):
|
|
3936
|
+
base_config = super().get_config()
|
|
3937
|
+
if (
|
|
3938
|
+
self._random_generator._rng_type
|
|
3939
|
+
== backend.RandomGenerator.RNG_LEGACY_STATEFUL
|
|
3940
|
+
):
|
|
3941
|
+
return base_config
|
|
3942
|
+
|
|
3943
|
+
config = {"rng_type": self._random_generator._rng_type}
|
|
3944
|
+
return dict(list(base_config.items()) + list(config.items()))
|
|
3945
|
+
|
|
3808
3946
|
def _trackable_children(self, save_type="checkpoint", **kwargs):
|
|
3809
3947
|
if save_type == "savedmodel":
|
|
3810
3948
|
cache = kwargs["cache"]
|
|
@@ -480,7 +480,8 @@ class CallContext:
|
|
|
480
480
|
layer: The `Layer` whose `call` is currently active.
|
|
481
481
|
inputs: The inputs to the currently active `Layer`.
|
|
482
482
|
build_graph: Whether currently inside a Graph or FuncGraph.
|
|
483
|
-
|
|
483
|
+
call_context_args: The call-context arguments being propagated through the
|
|
484
|
+
the call-stack.
|
|
484
485
|
saving: Whether currently saving to SavedModel.
|
|
485
486
|
frozen: Whether currently executing inside a `Layer` with `trainable` set
|
|
486
487
|
to `False`.
|
|
@@ -495,6 +496,7 @@ class CallContext:
|
|
|
495
496
|
"layer": None,
|
|
496
497
|
"inputs": None,
|
|
497
498
|
"build_graph": False,
|
|
499
|
+
"call_context_args": dict(),
|
|
498
500
|
"training": None,
|
|
499
501
|
"saving": None,
|
|
500
502
|
}
|
|
@@ -502,14 +504,17 @@ class CallContext:
|
|
|
502
504
|
# refactor.
|
|
503
505
|
self._in_keras_graph = False
|
|
504
506
|
|
|
505
|
-
def enter(
|
|
507
|
+
def enter(
|
|
508
|
+
self, layer, inputs, build_graph, call_context_args=dict(), saving=None
|
|
509
|
+
):
|
|
506
510
|
"""Push a Layer and its inputs and state onto the current call context.
|
|
507
511
|
|
|
508
512
|
Args:
|
|
509
513
|
layer: The `Layer` whose `call` is currently active.
|
|
510
514
|
inputs: The inputs to the currently active `Layer`.
|
|
511
515
|
build_graph: Whether currently inside a Graph or FuncGraph.
|
|
512
|
-
|
|
516
|
+
call_context_args: The call-context arguments being propagated through
|
|
517
|
+
the call-stack.
|
|
513
518
|
saving: Whether currently saving to SavedModel.
|
|
514
519
|
|
|
515
520
|
Returns:
|
|
@@ -519,7 +524,7 @@ class CallContext:
|
|
|
519
524
|
"layer": layer,
|
|
520
525
|
"inputs": inputs,
|
|
521
526
|
"build_graph": build_graph,
|
|
522
|
-
"
|
|
527
|
+
"call_context_args": call_context_args,
|
|
523
528
|
"saving": saving,
|
|
524
529
|
}
|
|
525
530
|
return CallContextManager(self, state)
|
|
@@ -538,7 +543,14 @@ class CallContext:
|
|
|
538
543
|
|
|
539
544
|
@property
|
|
540
545
|
def training(self):
|
|
541
|
-
return self.
|
|
546
|
+
return self.call_context_args.get("training", None)
|
|
547
|
+
|
|
548
|
+
@property
|
|
549
|
+
def call_context_args(self):
|
|
550
|
+
return self._state["call_context_args"]
|
|
551
|
+
|
|
552
|
+
def get_call_context_arg(self, arg_name):
|
|
553
|
+
return self.call_context_args.get(arg_name, None)
|
|
542
554
|
|
|
543
555
|
@property
|
|
544
556
|
def saving(self):
|
|
@@ -132,6 +132,7 @@ class Layer(base_layer.Layer):
|
|
|
132
132
|
self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
|
|
133
133
|
):
|
|
134
134
|
self._instrument_layer_creation()
|
|
135
|
+
self._called = False
|
|
135
136
|
|
|
136
137
|
# These properties should be set by the user via keyword arguments.
|
|
137
138
|
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
|
@@ -165,6 +166,8 @@ class Layer(base_layer.Layer):
|
|
|
165
166
|
self._input_spec = None
|
|
166
167
|
self.supports_masking = False
|
|
167
168
|
|
|
169
|
+
self._call_context_args = {"training"}
|
|
170
|
+
|
|
168
171
|
self._init_set_name(name)
|
|
169
172
|
self._activity_regularizer = regularizers.get(
|
|
170
173
|
kwargs.pop("activity_regularizer", None)
|
|
@@ -705,6 +708,7 @@ class Layer(base_layer.Layer):
|
|
|
705
708
|
RuntimeError: if `super().__init__()` was not called in the
|
|
706
709
|
constructor.
|
|
707
710
|
"""
|
|
711
|
+
self._called = True
|
|
708
712
|
self._assert_built_as_v1()
|
|
709
713
|
|
|
710
714
|
if not hasattr(self, "_thread_local"):
|
|
@@ -803,7 +807,12 @@ class Layer(base_layer.Layer):
|
|
|
803
807
|
if build_graph and base_layer_utils.needs_keras_history(inputs):
|
|
804
808
|
base_layer_utils.create_keras_history(inputs)
|
|
805
809
|
|
|
806
|
-
with call_context.enter(
|
|
810
|
+
with call_context.enter(
|
|
811
|
+
self,
|
|
812
|
+
inputs,
|
|
813
|
+
build_graph,
|
|
814
|
+
call_context_args={"training": training_value},
|
|
815
|
+
):
|
|
807
816
|
# Check input assumptions set after layer building, e.g. input
|
|
808
817
|
# shape.
|
|
809
818
|
if build_graph:
|
|
@@ -2177,8 +2186,8 @@ class Layer(base_layer.Layer):
|
|
|
2177
2186
|
else:
|
|
2178
2187
|
self._set_dtype_policy(policy.Policy(dtype))
|
|
2179
2188
|
input_shapes = None
|
|
2180
|
-
if
|
|
2181
|
-
input_shapes =
|
|
2189
|
+
if any(hasattr(x, "shape") for x in input_list):
|
|
2190
|
+
input_shapes = tf_utils.get_shapes(inputs)
|
|
2182
2191
|
# Only call `build` if the user has manually overridden the build
|
|
2183
2192
|
# method.
|
|
2184
2193
|
if not hasattr(self.build, "_is_default"):
|