tf-keras-nightly 2.20.0.dev2025051109__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/engine/base_layer.py +234 -96
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +30 -13
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +2 -2
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +27 -49
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/projector_config.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/projector_config.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_SPRITEMETADATA._serialized_start=78
|
|
24
|
-
_SPRITEMETADATA._serialized_end=140
|
|
25
|
-
_EMBEDDINGINFO._serialized_start=143
|
|
26
|
-
_EMBEDDINGINFO._serialized_end=335
|
|
27
|
-
_PROJECTORCONFIG._serialized_start=338
|
|
28
|
-
_PROJECTORCONFIG._serialized_end=485
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_SPRITEMETADATA']._serialized_start=78
|
|
35
|
+
_globals['_SPRITEMETADATA']._serialized_end=140
|
|
36
|
+
_globals['_EMBEDDINGINFO']._serialized_start=143
|
|
37
|
+
_globals['_EMBEDDINGINFO']._serialized_end=335
|
|
38
|
+
_globals['_PROJECTORCONFIG']._serialized_start=338
|
|
39
|
+
_globals['_PROJECTORCONFIG']._serialized_end=485
|
|
29
40
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/saved_metadata.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/saved_metadata.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
|
|
|
16
27
|
|
|
17
28
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
|
|
18
29
|
|
|
19
|
-
|
|
20
|
-
_builder.
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
DESCRIPTOR.
|
|
24
|
-
_SAVEDMETADATA._serialized_start=110
|
|
25
|
-
_SAVEDMETADATA._serialized_end=187
|
|
26
|
-
_SAVEDOBJECT._serialized_start=190
|
|
27
|
-
_SAVEDOBJECT._serialized_end=346
|
|
30
|
+
_globals = globals()
|
|
31
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
32
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
|
|
33
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
34
|
+
DESCRIPTOR._loaded_options = None
|
|
35
|
+
_globals['_SAVEDMETADATA']._serialized_start=110
|
|
36
|
+
_globals['_SAVEDMETADATA']._serialized_end=187
|
|
37
|
+
_globals['_SAVEDOBJECT']._serialized_start=190
|
|
38
|
+
_globals['_SAVEDOBJECT']._serialized_end=346
|
|
28
39
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/versions.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/versions.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_VERSIONDEF._serialized_start=70
|
|
24
|
-
_VERSIONDEF._serialized_end=145
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_VERSIONDEF']._serialized_start=70
|
|
35
|
+
_globals['_VERSIONDEF']._serialized_end=145
|
|
25
36
|
# @@protoc_insertion_point(module_scope)
|
tf_keras/src/__init__.py
CHANGED
|
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
|
|
|
35
35
|
from tensorflow.python import tf2
|
|
36
36
|
from tensorflow.python.util.tf_export import keras_export
|
|
37
37
|
|
|
38
|
-
__version__ = "2.
|
|
38
|
+
__version__ = "2.21.0"
|
|
39
39
|
|
|
40
40
|
keras_export("keras.__version__").export_constant(__name__, "__version__")
|
|
41
41
|
|
|
@@ -308,6 +308,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
308
308
|
self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
|
|
309
309
|
):
|
|
310
310
|
self._instrument_layer_creation()
|
|
311
|
+
self._called = False
|
|
311
312
|
|
|
312
313
|
# These properties should be set by the user via keyword arguments.
|
|
313
314
|
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
|
@@ -326,6 +327,10 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
326
327
|
# Validate optional keyword arguments.
|
|
327
328
|
generic_utils.validate_kwargs(kwargs, allowed_kwargs)
|
|
328
329
|
|
|
330
|
+
# Track the built-in call-context arguments. These are arguments that
|
|
331
|
+
# are tracked and propagated across the call-stack by default.
|
|
332
|
+
self._call_context_args = {"training"}
|
|
333
|
+
|
|
329
334
|
# Mutable properties
|
|
330
335
|
# Indicates whether the layer's weights are updated during training
|
|
331
336
|
# and whether the layer's updates are run during training.
|
|
@@ -411,6 +416,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
411
416
|
|
|
412
417
|
self._init_call_fn_args()
|
|
413
418
|
|
|
419
|
+
# Track the built-in call-context arguments.
|
|
420
|
+
self._call_spec._update_call_context_arguments(self._call_context_args)
|
|
421
|
+
|
|
414
422
|
# Whether the `call` method can be used to build a TF graph without
|
|
415
423
|
# issues. This attribute has no effect if the model is created using
|
|
416
424
|
# the Functional API. Instead, `model.dynamic` is determined based on
|
|
@@ -1042,6 +1050,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1042
1050
|
# - input_spec compatibility is only checked against `inputs`
|
|
1043
1051
|
# - mixed precision casting (autocast) is only applied to `inputs`,
|
|
1044
1052
|
# not to any other argument.
|
|
1053
|
+
self._called = True
|
|
1045
1054
|
inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
|
|
1046
1055
|
input_list = tf.nest.flatten(inputs)
|
|
1047
1056
|
|
|
@@ -1080,17 +1089,21 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1080
1089
|
if self._expects_mask_arg and mask_is_implicit:
|
|
1081
1090
|
kwargs["mask"] = input_masks
|
|
1082
1091
|
|
|
1083
|
-
#
|
|
1084
|
-
#
|
|
1085
|
-
#
|
|
1086
|
-
# (2) The
|
|
1087
|
-
# (3) The default mode set by
|
|
1088
|
-
#
|
|
1089
|
-
# (4) Any non-None default value for
|
|
1092
|
+
# Call-context arguments for `Layer.call` is set via (in order of
|
|
1093
|
+
# priority):
|
|
1094
|
+
# (1) The argument passed to this `Layer.call`, if it is not None
|
|
1095
|
+
# (2) The argument value of an outer `Layer.call`.
|
|
1096
|
+
# (3) (only for "training") The default mode set by
|
|
1097
|
+
# `tf.keras.backend.set_learning_phase` (if set)
|
|
1098
|
+
# (4) Any non-None default value for the argument specified in the call
|
|
1090
1099
|
# signature
|
|
1091
|
-
# (5) False
|
|
1092
|
-
|
|
1093
|
-
args,
|
|
1100
|
+
# (5) False
|
|
1101
|
+
(
|
|
1102
|
+
args,
|
|
1103
|
+
kwargs,
|
|
1104
|
+
propagated,
|
|
1105
|
+
) = self._get_propagated_call_context_arguments(
|
|
1106
|
+
args, kwargs, call_context, self._call_context_args
|
|
1094
1107
|
)
|
|
1095
1108
|
|
|
1096
1109
|
# Losses are cleared for all sublayers on the outermost `Layer.call`.
|
|
@@ -1104,7 +1117,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1104
1117
|
layer=self,
|
|
1105
1118
|
inputs=inputs,
|
|
1106
1119
|
build_graph=not eager,
|
|
1107
|
-
|
|
1120
|
+
call_context_args=propagated,
|
|
1108
1121
|
):
|
|
1109
1122
|
input_spec.assert_input_compatibility(
|
|
1110
1123
|
self.input_spec, inputs, self.name
|
|
@@ -1152,6 +1165,55 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
1152
1165
|
|
|
1153
1166
|
return outputs
|
|
1154
1167
|
|
|
1168
|
+
def _register_call_context_args(self, *argument_names):
|
|
1169
|
+
"""Registers call-context args for this layer.
|
|
1170
|
+
If this layer declares a `call()` method that accepts
|
|
1171
|
+
one or more of the given args, those args will be
|
|
1172
|
+
automatically injected into the call signature of this
|
|
1173
|
+
layer. This layer will also propagate the args to any
|
|
1174
|
+
nested sublayers that are called from within this layer.
|
|
1175
|
+
If this layer doesn't declare a `call()` method that
|
|
1176
|
+
accepts one or more of the given args, these args will
|
|
1177
|
+
simply be propagated to any nested sublayers without
|
|
1178
|
+
being injected into the call signature of this layer.
|
|
1179
|
+
This is useful for propagating custom arguments
|
|
1180
|
+
from top-level layers/models to sublayers.
|
|
1181
|
+
Example:
|
|
1182
|
+
```
|
|
1183
|
+
class Inner(layers.Layer):
|
|
1184
|
+
def __init__(self):
|
|
1185
|
+
super().__init__()
|
|
1186
|
+
# Register `foo_mode` as a call-context arg
|
|
1187
|
+
self._register_call_context_args("foo_mode")
|
|
1188
|
+
def call(self, x, foo_mode=False):
|
|
1189
|
+
# If foo_mode=True add 1, otherwise add 0
|
|
1190
|
+
add_val = ops.where(foo_mode, 1.0, 0.0)
|
|
1191
|
+
return x + add_val
|
|
1192
|
+
class Outer(layers.Layer):
|
|
1193
|
+
def __init__(self):
|
|
1194
|
+
super().__init__()
|
|
1195
|
+
self.inner = Inner()
|
|
1196
|
+
def call(self, x):
|
|
1197
|
+
# We don't explicitly pass foo_mode here—Base Layer.__call__
|
|
1198
|
+
# should inject it into `self.inner`
|
|
1199
|
+
return self.inner(x)
|
|
1200
|
+
sample_input = np.array([[1.0], [2.0]])
|
|
1201
|
+
# Sequential model
|
|
1202
|
+
seq = models.Sequential([Outer()])
|
|
1203
|
+
# Tell the Sequential model to propagate foo_mode down
|
|
1204
|
+
# the call-stack
|
|
1205
|
+
seq.register_call_context_args("foo_mode")
|
|
1206
|
+
# foo_mode=True -> input + 1
|
|
1207
|
+
out_true = seq(sample_input, foo_mode=True)
|
|
1208
|
+
"""
|
|
1209
|
+
if self._called:
|
|
1210
|
+
raise RuntimeError(
|
|
1211
|
+
"Cannot add call-context args after the layer has been called."
|
|
1212
|
+
)
|
|
1213
|
+
self._call_context_args |= set(argument_names)
|
|
1214
|
+
self._call_spec._update_call_context_arguments(argument_names)
|
|
1215
|
+
self._call_spec._update_call_context_argument_defaults(argument_names)
|
|
1216
|
+
|
|
1155
1217
|
def _get_unnested_name_scope(self):
|
|
1156
1218
|
if _is_name_scope_on_model_declaration_enabled:
|
|
1157
1219
|
with _name_scope_unnester(
|
|
@@ -2535,47 +2597,57 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2535
2597
|
kwargs["mask"] = input_masks
|
|
2536
2598
|
mask_arg_passed_by_framework = True
|
|
2537
2599
|
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2541
|
-
|
|
2542
|
-
|
|
2543
|
-
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
)
|
|
2547
|
-
|
|
2548
|
-
|
|
2549
|
-
|
|
2550
|
-
|
|
2551
|
-
|
|
2552
|
-
|
|
2553
|
-
|
|
2554
|
-
|
|
2555
|
-
|
|
2556
|
-
|
|
2557
|
-
|
|
2558
|
-
|
|
2559
|
-
|
|
2560
|
-
|
|
2600
|
+
propagated = dict()
|
|
2601
|
+
args_passed_by_framework = dict()
|
|
2602
|
+
for context_arg in self._call_context_args:
|
|
2603
|
+
# If `training` argument is None or not explicitly passed,
|
|
2604
|
+
# propagate `training` value from this layer's calling layer.
|
|
2605
|
+
value = None
|
|
2606
|
+
args_passed_by_framework[context_arg] = False
|
|
2607
|
+
# Priority 1: `training` was explicitly passed a non-None value.
|
|
2608
|
+
if self._call_spec.arg_was_passed(context_arg, args, kwargs):
|
|
2609
|
+
value = self._call_spec.get_arg_value(context_arg, args, kwargs)
|
|
2610
|
+
if not self._expects_context_arg(context_arg):
|
|
2611
|
+
kwargs.pop(context_arg)
|
|
2612
|
+
|
|
2613
|
+
if value is None:
|
|
2614
|
+
# Priority 2: `training` was passed to a parent layer.
|
|
2615
|
+
if call_context.get_call_context_arg(context_arg) is not None:
|
|
2616
|
+
value = call_context.get_call_context_arg(context_arg)
|
|
2617
|
+
# Priority 3: `learning_phase()` has been set.
|
|
2618
|
+
elif (
|
|
2619
|
+
context_arg == "training"
|
|
2620
|
+
and backend.global_learning_phase_is_set()
|
|
2621
|
+
):
|
|
2622
|
+
value = backend.learning_phase()
|
|
2623
|
+
# Force the training_value to be bool type which matches to
|
|
2624
|
+
# the contract for layer/model call args.
|
|
2625
|
+
if tf.is_tensor(value):
|
|
2626
|
+
value = tf.cast(value, tf.bool)
|
|
2627
|
+
else:
|
|
2628
|
+
value = bool(value)
|
|
2629
|
+
# Priority 4: trace layer with the default training argument
|
|
2630
|
+
# specified in the `call` signature (or in inference mode if the
|
|
2631
|
+
# `call` signature specifies no non-None default).
|
|
2561
2632
|
else:
|
|
2562
|
-
|
|
2563
|
-
|
|
2564
|
-
|
|
2565
|
-
|
|
2566
|
-
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
|
|
2570
|
-
|
|
2571
|
-
|
|
2572
|
-
|
|
2573
|
-
|
|
2574
|
-
)
|
|
2575
|
-
training_arg_passed_by_framework = True
|
|
2633
|
+
value = self._call_spec.get_context_arg_default(context_arg)
|
|
2634
|
+
# In cases (2), (3), (4) the training argument is passed
|
|
2635
|
+
# automatically by the framework, and will not be hard-coded
|
|
2636
|
+
# into the model.
|
|
2637
|
+
if self._expects_context_arg(context_arg):
|
|
2638
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2639
|
+
context_arg, value, args, kwargs
|
|
2640
|
+
)
|
|
2641
|
+
args_passed_by_framework[context_arg] = True
|
|
2642
|
+
|
|
2643
|
+
if value is not None:
|
|
2644
|
+
propagated[context_arg] = value
|
|
2576
2645
|
|
|
2577
2646
|
with call_context.enter(
|
|
2578
|
-
layer=self,
|
|
2647
|
+
layer=self,
|
|
2648
|
+
inputs=inputs,
|
|
2649
|
+
build_graph=True,
|
|
2650
|
+
call_context_args=propagated,
|
|
2579
2651
|
):
|
|
2580
2652
|
# Check input assumptions set after layer building, e.g. input
|
|
2581
2653
|
# shape.
|
|
@@ -2601,10 +2673,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2601
2673
|
"Tensor or a list of Tensors, not None "
|
|
2602
2674
|
"(layer: " + self.name + ")."
|
|
2603
2675
|
)
|
|
2604
|
-
|
|
2605
|
-
|
|
2606
|
-
|
|
2607
|
-
|
|
2676
|
+
|
|
2677
|
+
for context_arg, is_passed in args_passed_by_framework.items():
|
|
2678
|
+
if is_passed:
|
|
2679
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2680
|
+
context_arg, None, args, kwargs, pop_kwarg_if_none=True
|
|
2681
|
+
)
|
|
2682
|
+
|
|
2608
2683
|
if mask_arg_passed_by_framework:
|
|
2609
2684
|
kwargs.pop("mask")
|
|
2610
2685
|
# Node connectivity does not special-case the first argument.
|
|
@@ -2613,52 +2688,100 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
2613
2688
|
)
|
|
2614
2689
|
return outputs
|
|
2615
2690
|
|
|
2616
|
-
def
|
|
2617
|
-
|
|
2618
|
-
|
|
2619
|
-
|
|
2620
|
-
|
|
2621
|
-
|
|
2622
|
-
|
|
2623
|
-
|
|
2624
|
-
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
|
|
2628
|
-
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
|
|
2632
|
-
|
|
2633
|
-
|
|
2634
|
-
|
|
2635
|
-
|
|
2636
|
-
|
|
2637
|
-
|
|
2638
|
-
|
|
2691
|
+
def _get_propagated_call_context_arguments(
|
|
2692
|
+
self, args, kwargs, call_context, local_call_context_arguments
|
|
2693
|
+
):
|
|
2694
|
+
"""Resolves the values for propagated call context arguments for the
|
|
2695
|
+
current layer.
|
|
2696
|
+
|
|
2697
|
+
Args:
|
|
2698
|
+
args: The arguments passed to the current layer's `call` method.
|
|
2699
|
+
kwargs: The keyword arguments passed to the current layer's `call`
|
|
2700
|
+
method.
|
|
2701
|
+
call_context: The `CallContext` for the current call-stack.
|
|
2702
|
+
local_call_context_arguments: The call-context arguments registered
|
|
2703
|
+
with to the current layer's `Layer.call` method.
|
|
2704
|
+
|
|
2705
|
+
Returns:
|
|
2706
|
+
A tuple of the following:
|
|
2707
|
+
1. Updated args
|
|
2708
|
+
2. Updated kwargs
|
|
2709
|
+
3. A dictionary of the resolved call-context arguments that should
|
|
2710
|
+
be propagated to the next layer in the call-stack.
|
|
2711
|
+
"""
|
|
2712
|
+
propagated_context = dict()
|
|
2713
|
+
relevant_arguments = call_context.call_context_args.keys() | set(
|
|
2714
|
+
local_call_context_arguments
|
|
2715
|
+
)
|
|
2716
|
+
|
|
2717
|
+
for argument in relevant_arguments:
|
|
2718
|
+
authoritative_value = None
|
|
2719
|
+
was_explicitly_passed = self._call_spec.arg_was_passed(
|
|
2720
|
+
argument, args, kwargs
|
|
2721
|
+
)
|
|
2722
|
+
if self._expects_context_arg(argument):
|
|
2723
|
+
# (1) `arg_name` was passed to this `Layer.call`.
|
|
2724
|
+
if was_explicitly_passed:
|
|
2725
|
+
authoritative_value = self._call_spec.get_arg_value(
|
|
2726
|
+
argument, args, kwargs
|
|
2727
|
+
)
|
|
2728
|
+
# If no `arg_name` arg was passed, or `None` was explicitly
|
|
2729
|
+
# passed, the framework will make a decision about the training
|
|
2730
|
+
# mode is.
|
|
2731
|
+
if authoritative_value is None:
|
|
2732
|
+
value_from_context = call_context.get_call_context_arg(
|
|
2733
|
+
argument
|
|
2734
|
+
)
|
|
2735
|
+
# (2) `arg_name` mode is inferred from an outer
|
|
2736
|
+
# `Layer.call`.
|
|
2737
|
+
if value_from_context is not None:
|
|
2738
|
+
authoritative_value = value_from_context
|
|
2739
|
+
# (3) User set `tf.keras.backend.set_learning_phase`.
|
|
2740
|
+
elif (
|
|
2741
|
+
argument == "training"
|
|
2742
|
+
and backend.global_learning_phase_is_set()
|
|
2743
|
+
):
|
|
2744
|
+
authoritative_value = backend.learning_phase()
|
|
2745
|
+
# Ensure value is a `bool` or `tf.bool`.
|
|
2746
|
+
if isinstance(authoritative_value, bool):
|
|
2747
|
+
pass
|
|
2748
|
+
elif tf.is_tensor(authoritative_value):
|
|
2749
|
+
authoritative_value = tf.cast(
|
|
2750
|
+
authoritative_value, tf.bool
|
|
2751
|
+
)
|
|
2752
|
+
else:
|
|
2753
|
+
authoritative_value = bool(authoritative_value)
|
|
2754
|
+
# (4) We default to using `call`'s default value for
|
|
2755
|
+
# `arg_name`, or treating the layer as if it is in inference
|
|
2756
|
+
# if no non-None default is specified in the `call`
|
|
2757
|
+
# signature.
|
|
2639
2758
|
else:
|
|
2640
|
-
|
|
2641
|
-
|
|
2642
|
-
|
|
2643
|
-
# default is specified in the `call` signature.
|
|
2644
|
-
else:
|
|
2645
|
-
training_mode = self._call_spec.default_training_arg
|
|
2759
|
+
authoritative_value = (
|
|
2760
|
+
self._call_spec.get_context_arg_default(argument)
|
|
2761
|
+
)
|
|
2646
2762
|
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
2651
|
-
|
|
2652
|
-
if "training" in kwargs:
|
|
2653
|
-
# `training` was passed to this `Layer` but is not needed for
|
|
2654
|
-
# `Layer.call`. It will set the default mode for inner
|
|
2655
|
-
# `Layer.call`s.
|
|
2656
|
-
training_mode = kwargs.pop("training")
|
|
2763
|
+
# For case (2), (3), (4) `arg_name` arg is passed by
|
|
2764
|
+
# framework.
|
|
2765
|
+
args, kwargs = self._call_spec.set_arg_value(
|
|
2766
|
+
argument, authoritative_value, args, kwargs
|
|
2767
|
+
)
|
|
2657
2768
|
else:
|
|
2658
|
-
|
|
2659
|
-
|
|
2769
|
+
if argument in kwargs:
|
|
2770
|
+
# `arg_name` was passed to this `Layer` but is not needed
|
|
2771
|
+
# for `Layer.call`. It will set the default mode for inner
|
|
2772
|
+
# `Layer.call`s.
|
|
2773
|
+
authoritative_value = kwargs.pop(argument)
|
|
2774
|
+
else:
|
|
2775
|
+
# Grab the current `arg_name` mode from any outer
|
|
2776
|
+
# `Layer.call`.
|
|
2777
|
+
authoritative_value = call_context.get_call_context_arg(
|
|
2778
|
+
argument
|
|
2779
|
+
)
|
|
2780
|
+
|
|
2781
|
+
if authoritative_value is not None:
|
|
2782
|
+
propagated_context[argument] = authoritative_value
|
|
2660
2783
|
|
|
2661
|
-
return args, kwargs,
|
|
2784
|
+
return args, kwargs, propagated_context
|
|
2662
2785
|
|
|
2663
2786
|
def _autographed_call(self):
|
|
2664
2787
|
# Wrapping `call` function in autograph to allow for dynamic control
|
|
@@ -3351,7 +3474,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
3351
3474
|
|
|
3352
3475
|
def _init_call_fn_args(self, expects_training_arg=None):
|
|
3353
3476
|
self._call_spec = layer_utils.CallFunctionSpec(
|
|
3354
|
-
tf_inspect.getfullargspec(self.call)
|
|
3477
|
+
tf_inspect.getfullargspec(self.call),
|
|
3478
|
+
getattr(self, "_call_context_args", set()),
|
|
3355
3479
|
)
|
|
3356
3480
|
if expects_training_arg is not None:
|
|
3357
3481
|
self._call_spec.expects_training_arg = expects_training_arg
|
|
@@ -3361,6 +3485,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
|
|
|
3361
3485
|
"""Whether the call function uses 'training' as a parameter."""
|
|
3362
3486
|
return self._call_spec.expects_training_arg
|
|
3363
3487
|
|
|
3488
|
+
def _expects_context_arg(self, argument_name):
|
|
3489
|
+
return argument_name in self._call_spec.expected_context_args
|
|
3490
|
+
|
|
3364
3491
|
@property
|
|
3365
3492
|
def _expects_mask_arg(self):
|
|
3366
3493
|
return self._call_spec.expects_mask_arg
|
|
@@ -3805,6 +3932,17 @@ class BaseRandomLayer(Layer):
|
|
|
3805
3932
|
super().build(input_shape)
|
|
3806
3933
|
self._random_generator._maybe_init()
|
|
3807
3934
|
|
|
3935
|
+
def get_config(self):
|
|
3936
|
+
base_config = super().get_config()
|
|
3937
|
+
if (
|
|
3938
|
+
self._random_generator._rng_type
|
|
3939
|
+
== backend.RandomGenerator.RNG_LEGACY_STATEFUL
|
|
3940
|
+
):
|
|
3941
|
+
return base_config
|
|
3942
|
+
|
|
3943
|
+
config = {"rng_type": self._random_generator._rng_type}
|
|
3944
|
+
return dict(list(base_config.items()) + list(config.items()))
|
|
3945
|
+
|
|
3808
3946
|
def _trackable_children(self, save_type="checkpoint", **kwargs):
|
|
3809
3947
|
if save_type == "savedmodel":
|
|
3810
3948
|
cache = kwargs["cache"]
|
|
@@ -480,7 +480,8 @@ class CallContext:
|
|
|
480
480
|
layer: The `Layer` whose `call` is currently active.
|
|
481
481
|
inputs: The inputs to the currently active `Layer`.
|
|
482
482
|
build_graph: Whether currently inside a Graph or FuncGraph.
|
|
483
|
-
|
|
483
|
+
call_context_args: The call-context arguments being propagated through the
|
|
484
|
+
the call-stack.
|
|
484
485
|
saving: Whether currently saving to SavedModel.
|
|
485
486
|
frozen: Whether currently executing inside a `Layer` with `trainable` set
|
|
486
487
|
to `False`.
|
|
@@ -495,6 +496,7 @@ class CallContext:
|
|
|
495
496
|
"layer": None,
|
|
496
497
|
"inputs": None,
|
|
497
498
|
"build_graph": False,
|
|
499
|
+
"call_context_args": dict(),
|
|
498
500
|
"training": None,
|
|
499
501
|
"saving": None,
|
|
500
502
|
}
|
|
@@ -502,14 +504,17 @@ class CallContext:
|
|
|
502
504
|
# refactor.
|
|
503
505
|
self._in_keras_graph = False
|
|
504
506
|
|
|
505
|
-
def enter(
|
|
507
|
+
def enter(
|
|
508
|
+
self, layer, inputs, build_graph, call_context_args=dict(), saving=None
|
|
509
|
+
):
|
|
506
510
|
"""Push a Layer and its inputs and state onto the current call context.
|
|
507
511
|
|
|
508
512
|
Args:
|
|
509
513
|
layer: The `Layer` whose `call` is currently active.
|
|
510
514
|
inputs: The inputs to the currently active `Layer`.
|
|
511
515
|
build_graph: Whether currently inside a Graph or FuncGraph.
|
|
512
|
-
|
|
516
|
+
call_context_args: The call-context arguments being propagated through
|
|
517
|
+
the call-stack.
|
|
513
518
|
saving: Whether currently saving to SavedModel.
|
|
514
519
|
|
|
515
520
|
Returns:
|
|
@@ -519,7 +524,7 @@ class CallContext:
|
|
|
519
524
|
"layer": layer,
|
|
520
525
|
"inputs": inputs,
|
|
521
526
|
"build_graph": build_graph,
|
|
522
|
-
"
|
|
527
|
+
"call_context_args": call_context_args,
|
|
523
528
|
"saving": saving,
|
|
524
529
|
}
|
|
525
530
|
return CallContextManager(self, state)
|
|
@@ -538,7 +543,14 @@ class CallContext:
|
|
|
538
543
|
|
|
539
544
|
@property
|
|
540
545
|
def training(self):
|
|
541
|
-
return self.
|
|
546
|
+
return self.call_context_args.get("training", None)
|
|
547
|
+
|
|
548
|
+
@property
|
|
549
|
+
def call_context_args(self):
|
|
550
|
+
return self._state["call_context_args"]
|
|
551
|
+
|
|
552
|
+
def get_call_context_arg(self, arg_name):
|
|
553
|
+
return self.call_context_args.get(arg_name, None)
|
|
542
554
|
|
|
543
555
|
@property
|
|
544
556
|
def saving(self):
|