tf-keras-nightly 2.20.0.dev2025051109__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/engine/base_layer.py +234 -96
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +30 -13
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +2 -2
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +27 -49
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
|
@@ -132,6 +132,7 @@ class Layer(base_layer.Layer):
|
|
|
132
132
|
self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
|
|
133
133
|
):
|
|
134
134
|
self._instrument_layer_creation()
|
|
135
|
+
self._called = False
|
|
135
136
|
|
|
136
137
|
# These properties should be set by the user via keyword arguments.
|
|
137
138
|
# note that 'dtype', 'input_shape' and 'batch_input_shape'
|
|
@@ -165,6 +166,8 @@ class Layer(base_layer.Layer):
|
|
|
165
166
|
self._input_spec = None
|
|
166
167
|
self.supports_masking = False
|
|
167
168
|
|
|
169
|
+
self._call_context_args = {"training"}
|
|
170
|
+
|
|
168
171
|
self._init_set_name(name)
|
|
169
172
|
self._activity_regularizer = regularizers.get(
|
|
170
173
|
kwargs.pop("activity_regularizer", None)
|
|
@@ -705,6 +708,7 @@ class Layer(base_layer.Layer):
|
|
|
705
708
|
RuntimeError: if `super().__init__()` was not called in the
|
|
706
709
|
constructor.
|
|
707
710
|
"""
|
|
711
|
+
self._called = True
|
|
708
712
|
self._assert_built_as_v1()
|
|
709
713
|
|
|
710
714
|
if not hasattr(self, "_thread_local"):
|
|
@@ -803,7 +807,12 @@ class Layer(base_layer.Layer):
|
|
|
803
807
|
if build_graph and base_layer_utils.needs_keras_history(inputs):
|
|
804
808
|
base_layer_utils.create_keras_history(inputs)
|
|
805
809
|
|
|
806
|
-
with call_context.enter(
|
|
810
|
+
with call_context.enter(
|
|
811
|
+
self,
|
|
812
|
+
inputs,
|
|
813
|
+
build_graph,
|
|
814
|
+
call_context_args={"training": training_value},
|
|
815
|
+
):
|
|
807
816
|
# Check input assumptions set after layer building, e.g. input
|
|
808
817
|
# shape.
|
|
809
818
|
if build_graph:
|
|
@@ -2177,8 +2186,8 @@ class Layer(base_layer.Layer):
|
|
|
2177
2186
|
else:
|
|
2178
2187
|
self._set_dtype_policy(policy.Policy(dtype))
|
|
2179
2188
|
input_shapes = None
|
|
2180
|
-
if
|
|
2181
|
-
input_shapes =
|
|
2189
|
+
if any(hasattr(x, "shape") for x in input_list):
|
|
2190
|
+
input_shapes = tf_utils.get_shapes(inputs)
|
|
2182
2191
|
# Only call `build` if the user has manually overridden the build
|
|
2183
2192
|
# method.
|
|
2184
2193
|
if not hasattr(self.build, "_is_default"):
|
|
@@ -231,7 +231,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
231
231
|
return True
|
|
232
232
|
return False
|
|
233
233
|
|
|
234
|
-
return all(_is_tensor(v) for v in flat_inputs)
|
|
234
|
+
return all(_is_tensor(v) for v in flat_inputs if v is not None)
|
|
235
235
|
|
|
236
236
|
def __init__(
|
|
237
237
|
self,
|
|
@@ -259,7 +259,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
259
259
|
inputs = pack_x_y_sample_weight(x, y, sample_weights)
|
|
260
260
|
|
|
261
261
|
num_samples = set(
|
|
262
|
-
int(i.shape[0]) for i in tf.nest.flatten(inputs)
|
|
262
|
+
int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
|
|
263
263
|
).pop()
|
|
264
264
|
_check_data_cardinality(inputs)
|
|
265
265
|
|
|
@@ -386,7 +386,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
386
386
|
|
|
387
387
|
def grab_batch(i, data):
|
|
388
388
|
return tf.nest.map_structure(
|
|
389
|
-
lambda d: tf.gather(d, i, axis=0), data
|
|
389
|
+
lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
|
|
390
390
|
)
|
|
391
391
|
|
|
392
392
|
dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
|
|
@@ -459,7 +459,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
459
459
|
if not TensorLikeDataAdapter.can_handle(
|
|
460
460
|
x, y
|
|
461
461
|
) and not CompositeTensorDataAdapter.can_handle(x, y):
|
|
462
|
-
return all(_is_array_like(v) for v in flat_inputs)
|
|
462
|
+
return all(_is_array_like(v) for v in flat_inputs if v is not None)
|
|
463
463
|
else:
|
|
464
464
|
return False
|
|
465
465
|
|
|
@@ -496,7 +496,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
496
496
|
shape[0] = None
|
|
497
497
|
return tuple(shape)
|
|
498
498
|
|
|
499
|
-
flat_dtypes = [inp.dtype for inp in flat_inputs]
|
|
499
|
+
flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
|
|
500
500
|
contiguous = True
|
|
501
501
|
if self._shuffle and self._shuffle != "batch":
|
|
502
502
|
contiguous = False
|
|
@@ -509,15 +509,26 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
509
509
|
# to a Tensor may force it into memory..
|
|
510
510
|
def py_method(ind):
|
|
511
511
|
def slice_array(data):
|
|
512
|
+
if data is None:
|
|
513
|
+
return None
|
|
512
514
|
return training_utils.slice_arrays(
|
|
513
515
|
data, ind.numpy(), contiguous=contiguous
|
|
514
516
|
)
|
|
515
517
|
|
|
516
|
-
return [
|
|
518
|
+
return [
|
|
519
|
+
slice_array(inp) for inp in flat_inputs if inp is not None
|
|
520
|
+
]
|
|
517
521
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
522
|
+
results = tf.py_function(py_method, [indices], flat_dtypes)
|
|
523
|
+
results_it = iter(results)
|
|
524
|
+
flat_out = []
|
|
525
|
+
for original_inp in flat_inputs:
|
|
526
|
+
if original_inp is None:
|
|
527
|
+
flat_out.append(None)
|
|
528
|
+
else:
|
|
529
|
+
v = next(results_it)
|
|
530
|
+
v.set_shape(dynamic_shape_like(original_inp))
|
|
531
|
+
flat_out.append(v)
|
|
521
532
|
return tf.nest.pack_sequence_as(inputs, flat_out)
|
|
522
533
|
|
|
523
534
|
dataset = indices_dataset.map(
|
|
@@ -608,8 +619,10 @@ class CompositeTensorDataAdapter(DataAdapter):
|
|
|
608
619
|
return True
|
|
609
620
|
return _is_composite(v)
|
|
610
621
|
|
|
611
|
-
return any(
|
|
612
|
-
|
|
622
|
+
return any(
|
|
623
|
+
_is_composite(v) for v in flat_inputs if v is not None
|
|
624
|
+
) and all(
|
|
625
|
+
_is_tensor_or_composite(v) for v in flat_inputs if v is not None
|
|
613
626
|
)
|
|
614
627
|
|
|
615
628
|
def __init__(
|
|
@@ -1944,14 +1957,18 @@ def single_batch_iterator(
|
|
|
1944
1957
|
|
|
1945
1958
|
|
|
1946
1959
|
def _check_data_cardinality(data):
|
|
1947
|
-
num_samples = set(
|
|
1960
|
+
num_samples = set(
|
|
1961
|
+
int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
|
|
1962
|
+
)
|
|
1948
1963
|
if len(num_samples) > 1:
|
|
1949
1964
|
msg = "Data cardinality is ambiguous:\n"
|
|
1950
1965
|
for label, single_data in zip(["x", "y", "sample_weight"], data):
|
|
1951
1966
|
msg += " {} sizes: {}\n".format(
|
|
1952
1967
|
label,
|
|
1953
1968
|
", ".join(
|
|
1954
|
-
str(i.shape[0])
|
|
1969
|
+
str(i.shape[0])
|
|
1970
|
+
for i in tf.nest.flatten(single_data)
|
|
1971
|
+
if i is not None
|
|
1955
1972
|
),
|
|
1956
1973
|
)
|
|
1957
1974
|
msg += "Make sure all arrays contain the same number of samples."
|
|
@@ -351,25 +351,45 @@ class Functional(training_lib.Model):
|
|
|
351
351
|
if isinstance(self._nested_inputs, dict):
|
|
352
352
|
# Case where `_nested_inputs` is a plain dict of Inputs.
|
|
353
353
|
names = sorted(self._nested_inputs.keys())
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
354
|
+
specs = []
|
|
355
|
+
for name in names:
|
|
356
|
+
layer = self._nested_inputs[name]._keras_history.layer
|
|
357
|
+
optional = (
|
|
358
|
+
layer.optional
|
|
359
|
+
if isinstance(layer, input_layer_module.InputLayer)
|
|
360
|
+
else False
|
|
359
361
|
)
|
|
360
|
-
|
|
361
|
-
|
|
362
|
+
specs.append(
|
|
363
|
+
input_spec.InputSpec(
|
|
364
|
+
shape=shape_with_no_batch_size(
|
|
365
|
+
self._nested_inputs[name]
|
|
366
|
+
),
|
|
367
|
+
allow_last_axis_squeeze=True,
|
|
368
|
+
name=name,
|
|
369
|
+
optional=optional,
|
|
370
|
+
)
|
|
371
|
+
)
|
|
372
|
+
return specs
|
|
362
373
|
else:
|
|
363
374
|
# Single input, or list / tuple of inputs.
|
|
364
375
|
# The data may be passed as a dict keyed by input name.
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
376
|
+
specs = []
|
|
377
|
+
for x in self.inputs:
|
|
378
|
+
layer = x._keras_history.layer
|
|
379
|
+
optional = (
|
|
380
|
+
layer.optional
|
|
381
|
+
if isinstance(layer, input_layer_module.InputLayer)
|
|
382
|
+
else False
|
|
370
383
|
)
|
|
371
|
-
|
|
372
|
-
|
|
384
|
+
specs.append(
|
|
385
|
+
input_spec.InputSpec(
|
|
386
|
+
shape=shape_with_no_batch_size(x),
|
|
387
|
+
allow_last_axis_squeeze=True,
|
|
388
|
+
name=x._keras_history.layer.name,
|
|
389
|
+
optional=optional,
|
|
390
|
+
)
|
|
391
|
+
)
|
|
392
|
+
return specs
|
|
373
393
|
|
|
374
394
|
@input_spec.setter
|
|
375
395
|
def input_spec(self, value):
|
|
@@ -644,7 +664,8 @@ class Functional(training_lib.Model):
|
|
|
644
664
|
else:
|
|
645
665
|
masks = self._flatten_to_reference_inputs(mask)
|
|
646
666
|
for input_t, mask in zip(inputs, masks):
|
|
647
|
-
input_t
|
|
667
|
+
if input_t is not None:
|
|
668
|
+
input_t._keras_mask = mask
|
|
648
669
|
|
|
649
670
|
# Dictionary mapping reference tensors to computed tensors.
|
|
650
671
|
tensor_dict = {}
|
|
@@ -98,6 +98,8 @@ class InputLayer(base_layer.Layer):
|
|
|
98
98
|
`tf.TypeSpec` represents the entire batch. When provided, all other
|
|
99
99
|
args except name must be `None`.
|
|
100
100
|
name: Optional name of the layer (string).
|
|
101
|
+
optional: Boolean, whether the input is optional or not. An optional
|
|
102
|
+
input can accept `None` values.
|
|
101
103
|
"""
|
|
102
104
|
|
|
103
105
|
@traceback_utils.filter_traceback
|
|
@@ -111,6 +113,7 @@ class InputLayer(base_layer.Layer):
|
|
|
111
113
|
name=None,
|
|
112
114
|
ragged=None,
|
|
113
115
|
type_spec=None,
|
|
116
|
+
optional=False,
|
|
114
117
|
**kwargs,
|
|
115
118
|
):
|
|
116
119
|
self._init_input_shape = input_shape
|
|
@@ -180,6 +183,7 @@ class InputLayer(base_layer.Layer):
|
|
|
180
183
|
self.ragged = True if ragged else False
|
|
181
184
|
self.batch_size = batch_size
|
|
182
185
|
self.supports_masking = True
|
|
186
|
+
self.optional = optional
|
|
183
187
|
|
|
184
188
|
if isinstance(input_shape, tf.TensorShape):
|
|
185
189
|
input_shape = tuple(input_shape.as_list())
|
|
@@ -284,6 +288,7 @@ class InputLayer(base_layer.Layer):
|
|
|
284
288
|
"sparse": self.sparse,
|
|
285
289
|
"ragged": self.ragged,
|
|
286
290
|
"name": self.name,
|
|
291
|
+
"optional": self.optional,
|
|
287
292
|
}
|
|
288
293
|
return config
|
|
289
294
|
|
|
@@ -303,6 +308,7 @@ def Input(
|
|
|
303
308
|
tensor=None,
|
|
304
309
|
ragged=None,
|
|
305
310
|
type_spec=None,
|
|
311
|
+
optional=False,
|
|
306
312
|
**kwargs,
|
|
307
313
|
):
|
|
308
314
|
"""`Input()` is used to instantiate a TF-Keras tensor.
|
|
@@ -341,6 +347,8 @@ def Input(
|
|
|
341
347
|
[this guide](https://www.tensorflow.org/guide/ragged_tensor).
|
|
342
348
|
type_spec: A `tf.TypeSpec` object to create the input placeholder from.
|
|
343
349
|
When provided, all other args except name must be None.
|
|
350
|
+
optional: Boolean, whether the input is optional or not. An optional
|
|
351
|
+
input can accept `None` values.
|
|
344
352
|
**kwargs: deprecated arguments support. Supports `batch_shape` and
|
|
345
353
|
`batch_input_shape`.
|
|
346
354
|
|
|
@@ -415,6 +423,7 @@ def Input(
|
|
|
415
423
|
"ragged": ragged,
|
|
416
424
|
"input_tensor": tensor,
|
|
417
425
|
"type_spec": type_spec,
|
|
426
|
+
"optional": optional,
|
|
418
427
|
}
|
|
419
428
|
|
|
420
429
|
batch_input_shape = kwargs.pop(
|
|
@@ -56,6 +56,8 @@ class InputSpec:
|
|
|
56
56
|
as long as the last axis of the spec is 1.
|
|
57
57
|
name: Expected key corresponding to this input when passing data as
|
|
58
58
|
a dictionary.
|
|
59
|
+
optional: Boolean, whether the input is optional or not. An optional input
|
|
60
|
+
can accept `None` values.
|
|
59
61
|
|
|
60
62
|
Example:
|
|
61
63
|
|
|
@@ -82,6 +84,7 @@ class InputSpec:
|
|
|
82
84
|
axes=None,
|
|
83
85
|
allow_last_axis_squeeze=False,
|
|
84
86
|
name=None,
|
|
87
|
+
optional=False,
|
|
85
88
|
):
|
|
86
89
|
self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
|
|
87
90
|
shape = tf.TensorShape(shape)
|
|
@@ -99,6 +102,7 @@ class InputSpec:
|
|
|
99
102
|
self.min_ndim = min_ndim
|
|
100
103
|
self.name = name
|
|
101
104
|
self.allow_last_axis_squeeze = allow_last_axis_squeeze
|
|
105
|
+
self.optional = optional
|
|
102
106
|
try:
|
|
103
107
|
axes = axes or {}
|
|
104
108
|
self.axes = {int(k): axes[k] for k in axes}
|
|
@@ -204,7 +208,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
|
|
204
208
|
inputs = list_inputs
|
|
205
209
|
|
|
206
210
|
inputs = tf.nest.flatten(inputs)
|
|
207
|
-
for x in inputs:
|
|
211
|
+
for _, (x, spec) in enumerate(zip(inputs, input_spec)):
|
|
212
|
+
if spec is None:
|
|
213
|
+
continue
|
|
214
|
+
if x is None and spec.optional:
|
|
215
|
+
continue
|
|
208
216
|
# Having a shape/dtype is the only commonality of the various
|
|
209
217
|
# tensor-like objects that may be passed. The most common kind of
|
|
210
218
|
# invalid type we are guarding for is a Layer instance (Functional API),
|
|
@@ -224,6 +232,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
|
|
224
232
|
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
|
225
233
|
if spec is None:
|
|
226
234
|
continue
|
|
235
|
+
if x is None and spec.optional:
|
|
236
|
+
continue
|
|
227
237
|
|
|
228
238
|
shape = tf.TensorShape(x.shape)
|
|
229
239
|
if shape.rank is None:
|
|
@@ -70,6 +70,8 @@ class Softmax(Layer):
|
|
|
70
70
|
Args:
|
|
71
71
|
axis: Integer, or list of Integers, axis along which the softmax
|
|
72
72
|
normalization is applied.
|
|
73
|
+
robust_masking: Bool, if true will use a more robust implementation when
|
|
74
|
+
dealing with masks.
|
|
73
75
|
Call arguments:
|
|
74
76
|
inputs: The inputs, or logits to the softmax layer.
|
|
75
77
|
mask: A boolean mask of the same shape as `inputs`. The mask
|
|
@@ -80,23 +82,34 @@ class Softmax(Layer):
|
|
|
80
82
|
Softmaxed output with the same shape as `inputs`.
|
|
81
83
|
"""
|
|
82
84
|
|
|
83
|
-
def __init__(self, axis=-1, **kwargs):
|
|
85
|
+
def __init__(self, axis=-1, robust_masking=False, **kwargs):
|
|
84
86
|
super().__init__(**kwargs)
|
|
85
87
|
self.supports_masking = True
|
|
88
|
+
self.robust_masking = robust_masking
|
|
86
89
|
self.axis = axis
|
|
87
90
|
|
|
88
91
|
def call(self, inputs, mask=None):
|
|
89
92
|
if mask is not None:
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
93
|
+
if self.robust_masking:
|
|
94
|
+
# We keep the positions where the mask is True or > 0.5, and set
|
|
95
|
+
# the other (masked) positions to -1e9.
|
|
96
|
+
if mask.dtype is not tf.bool:
|
|
97
|
+
mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
|
|
98
|
+
inputs = tf.where(
|
|
99
|
+
mask, inputs, _large_compatible_negative(inputs.dtype)
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
# Since mask is 1.0 for positions we want to keep and 0.0 for
|
|
103
|
+
# masked positions, this operation will create a tensor which is
|
|
104
|
+
# 0.0 for positions we want to attend and -1e9 for masked
|
|
105
|
+
# positions.
|
|
106
|
+
adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
|
|
107
|
+
_large_compatible_negative(inputs.dtype)
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
# Since we are adding it to the raw scores before the softmax,
|
|
111
|
+
# this is effectively the same as removing these entirely.
|
|
112
|
+
inputs += adder
|
|
100
113
|
if isinstance(self.axis, (tuple, list)):
|
|
101
114
|
if len(self.axis) > 1:
|
|
102
115
|
return tf.exp(
|
|
@@ -109,6 +122,8 @@ class Softmax(Layer):
|
|
|
109
122
|
|
|
110
123
|
def get_config(self):
|
|
111
124
|
config = {"axis": self.axis}
|
|
125
|
+
if self.robust_masking:
|
|
126
|
+
config["robust_masking"] = True
|
|
112
127
|
base_config = super().get_config()
|
|
113
128
|
return dict(list(base_config.items()) + list(config.items()))
|
|
114
129
|
|
|
@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
|
|
|
198
198
|
activity_regularizer: Regularizer for dense layer activity.
|
|
199
199
|
kernel_constraint: Constraint for dense layer kernels.
|
|
200
200
|
bias_constraint: Constraint for dense layer biases.
|
|
201
|
+
softmax_robust_masking: If true will use a more numerically robust
|
|
202
|
+
masking impl.
|
|
201
203
|
|
|
202
204
|
Call arguments:
|
|
203
205
|
query: Query `Tensor` of shape `(B, T, dim)`.
|
|
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
|
|
|
247
249
|
activity_regularizer=None,
|
|
248
250
|
kernel_constraint=None,
|
|
249
251
|
bias_constraint=None,
|
|
252
|
+
softmax_robust_masking=False,
|
|
250
253
|
**kwargs,
|
|
251
254
|
):
|
|
252
255
|
super().__init__(**kwargs)
|
|
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
|
|
|
264
267
|
self._activity_regularizer = regularizers.get(activity_regularizer)
|
|
265
268
|
self._kernel_constraint = constraints.get(kernel_constraint)
|
|
266
269
|
self._bias_constraint = constraints.get(bias_constraint)
|
|
270
|
+
self._softmax_robust_masking = softmax_robust_masking
|
|
267
271
|
if attention_axes is not None and not isinstance(
|
|
268
272
|
attention_axes, collections.abc.Sized
|
|
269
273
|
):
|
|
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
|
|
|
298
302
|
"query_shape": self._query_shape,
|
|
299
303
|
"key_shape": self._key_shape,
|
|
300
304
|
"value_shape": self._value_shape,
|
|
305
|
+
"softmax_robust_masking": self._softmax_robust_masking,
|
|
301
306
|
}
|
|
302
307
|
base_config = super().get_config()
|
|
303
308
|
return dict(list(base_config.items()) + list(config.items()))
|
|
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
|
|
|
476
481
|
)
|
|
477
482
|
)
|
|
478
483
|
self._softmax = activation.Softmax(
|
|
479
|
-
axis=norm_axes,
|
|
484
|
+
axis=norm_axes,
|
|
485
|
+
robust_masking=self._softmax_robust_masking,
|
|
486
|
+
dtype=self._dtype_policy,
|
|
480
487
|
)
|
|
481
488
|
self._dropout_layer = regularization.Dropout(
|
|
482
489
|
rate=self._dropout, dtype=self._dtype_policy
|
|
@@ -259,6 +259,10 @@ class TFOpLambda(Layer):
|
|
|
259
259
|
|
|
260
260
|
self._call_spec.expects_training_arg = False
|
|
261
261
|
self._call_spec.expects_mask_arg = False
|
|
262
|
+
# Clear the call-context arguments for the layer's call method.
|
|
263
|
+
# Otherwise, Keras ends up injecting context arguments into the op-call
|
|
264
|
+
# when the call method accepts kwargs.
|
|
265
|
+
self._call_spec._expected_context_args.clear()
|
|
262
266
|
|
|
263
267
|
def _call_wrapper(self, *args, **kwargs):
|
|
264
268
|
created_variables = []
|
|
@@ -52,9 +52,21 @@ class _RNNCellWrapper(AbstractRNNCell):
|
|
|
52
52
|
super().__init__(*args, **kwargs)
|
|
53
53
|
self.cell = cell
|
|
54
54
|
cell_call_spec = tf_inspect.getfullargspec(cell.call)
|
|
55
|
+
accepts_kwargs = cell_call_spec.varkw is not None
|
|
56
|
+
|
|
55
57
|
self._call_spec.expects_training_arg = (
|
|
56
58
|
"training" in cell_call_spec.args
|
|
57
|
-
) or
|
|
59
|
+
) or accepts_kwargs
|
|
60
|
+
|
|
61
|
+
# Filter _expected_context_args. An argument is kept if:
|
|
62
|
+
# 1. It's an explicit argument in cell_call_spec.args OR
|
|
63
|
+
# 2. The cell accepts arbitrary keyword arguments (**kwargs),
|
|
64
|
+
# meaning it could potentially handle the context argument.
|
|
65
|
+
self._call_spec._expected_context_args = {
|
|
66
|
+
arg
|
|
67
|
+
for arg in self._call_spec._expected_context_args
|
|
68
|
+
if (arg in cell_call_spec.args) or accepts_kwargs
|
|
69
|
+
}
|
|
58
70
|
|
|
59
71
|
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
|
|
60
72
|
"""Calls the wrapped cell and performs the wrapping logic.
|
|
@@ -1471,9 +1471,10 @@ class AUC(base_metric.Metric):
|
|
|
1471
1471
|
# label_weights should be of length equal to the number of
|
|
1472
1472
|
# labels.
|
|
1473
1473
|
shapes.append((self.label_weights, ("L",)))
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1474
|
+
|
|
1475
|
+
tf.debugging.assert_shapes(
|
|
1476
|
+
shapes, message="Number of labels is not consistent."
|
|
1477
|
+
)
|
|
1477
1478
|
|
|
1478
1479
|
# Only forward label_weights to update_confusion_matrix_variables when
|
|
1479
1480
|
# multi_label is False. Otherwise the averaging of individual label AUCs
|
|
@@ -1611,13 +1612,59 @@ class AUC(base_metric.Metric):
|
|
|
1611
1612
|
)
|
|
1612
1613
|
x = fp_rate
|
|
1613
1614
|
y = recall
|
|
1614
|
-
|
|
1615
|
+
elif self.curve == metrics_utils.AUCCurve.PR:
|
|
1615
1616
|
precision = tf.math.divide_no_nan(
|
|
1616
1617
|
self.true_positives,
|
|
1617
1618
|
tf.math.add(self.true_positives, self.false_positives),
|
|
1618
1619
|
)
|
|
1619
1620
|
x = recall
|
|
1620
1621
|
y = precision
|
|
1622
|
+
else: # curve == 'PR_GAIN'.
|
|
1623
|
+
# Due to the hyperbolic transform, this formula is less robust than
|
|
1624
|
+
# ROC or PR values. In particular
|
|
1625
|
+
# 1) Both measures diverge when there are no negative examples;
|
|
1626
|
+
# 2) Both measures diverge when there are no true positives;
|
|
1627
|
+
# 3) Recall gain becomes negative when the recall is lower than the
|
|
1628
|
+
# label average (i.e. when more negative examples are classified
|
|
1629
|
+
# positive than real positives).
|
|
1630
|
+
#
|
|
1631
|
+
# We ignore case 1 as it is easily communicated. For case 2 we set
|
|
1632
|
+
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
|
|
1633
|
+
# recall_gain to 0. These fixes will result in an overestimation of
|
|
1634
|
+
# the AUC for estimators that are anti-correlated with the label
|
|
1635
|
+
# (at some thresholds).
|
|
1636
|
+
#
|
|
1637
|
+
# The scaling factor $\frac{P}{N}$ that is used to form both
|
|
1638
|
+
# gain values.
|
|
1639
|
+
scaling_factor = tf.math.divide_no_nan(
|
|
1640
|
+
tf.math.add(self.true_positives, self.false_negatives),
|
|
1641
|
+
tf.math.add(self.false_positives, self.true_negatives),
|
|
1642
|
+
)
|
|
1643
|
+
|
|
1644
|
+
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1645
|
+
self.false_negatives, self.true_positives
|
|
1646
|
+
)
|
|
1647
|
+
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
|
|
1648
|
+
self.false_positives, self.true_positives
|
|
1649
|
+
)
|
|
1650
|
+
# Handle case 2.
|
|
1651
|
+
recall_gain = tf.where(
|
|
1652
|
+
tf.equal(self.true_positives, 0.0),
|
|
1653
|
+
tf.zeros_like(recall_gain),
|
|
1654
|
+
recall_gain,
|
|
1655
|
+
)
|
|
1656
|
+
precision_gain = tf.where(
|
|
1657
|
+
tf.equal(self.true_positives, 0.0),
|
|
1658
|
+
tf.ones_like(precision_gain),
|
|
1659
|
+
precision_gain,
|
|
1660
|
+
)
|
|
1661
|
+
# Handle case 3.
|
|
1662
|
+
recall_gain = tf.math.maximum(
|
|
1663
|
+
recall_gain, tf.zeros_like(recall_gain)
|
|
1664
|
+
)
|
|
1665
|
+
|
|
1666
|
+
x = recall_gain
|
|
1667
|
+
y = precision_gain
|
|
1621
1668
|
|
|
1622
1669
|
# Find the rectangle heights based on `summation_method`.
|
|
1623
1670
|
if (
|
|
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
|
|
|
72
72
|
if self.num_batch_splits is not None:
|
|
73
73
|
x_split = tf.split(x, self.num_batch_splits)
|
|
74
74
|
y_split = tf.split(y, self.num_batch_splits)
|
|
75
|
+
# Split the sample weight if it is provided.
|
|
76
|
+
if sample_weight is not None:
|
|
77
|
+
sample_weight_split = tf.split(
|
|
78
|
+
sample_weight, self.num_batch_splits
|
|
79
|
+
)
|
|
80
|
+
else:
|
|
81
|
+
sample_weight_split = [None] * self.num_batch_splits
|
|
75
82
|
else:
|
|
76
83
|
x_split = [x]
|
|
77
84
|
y_split = [y]
|
|
85
|
+
sample_weight_split = [sample_weight]
|
|
78
86
|
|
|
79
87
|
gradients_all_batches = []
|
|
80
88
|
pred_all_batches = []
|
|
81
|
-
for x_batch, y_batch in zip(
|
|
89
|
+
for x_batch, y_batch, sample_weight_batch in zip(
|
|
90
|
+
x_split, y_split, sample_weight_split
|
|
91
|
+
):
|
|
82
92
|
epsilon_w_cache = []
|
|
83
93
|
with tf.GradientTape() as tape:
|
|
84
|
-
pred = self
|
|
85
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
94
|
+
pred = self(x_batch)
|
|
95
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
86
96
|
pred_all_batches.append(pred)
|
|
87
97
|
trainable_variables = self.model.trainable_variables
|
|
88
98
|
gradients = tape.gradient(loss, trainable_variables)
|
|
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
|
|
|
98
108
|
epsilon_w_cache.append(epsilon_w)
|
|
99
109
|
|
|
100
110
|
with tf.GradientTape() as tape:
|
|
101
|
-
pred = self(x_batch)
|
|
102
|
-
loss = self.compiled_loss(y_batch, pred)
|
|
111
|
+
pred = self(x_batch, training=True)
|
|
112
|
+
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
|
|
103
113
|
gradients = tape.gradient(loss, trainable_variables)
|
|
104
114
|
if len(gradients_all_batches) == 0:
|
|
105
115
|
for gradient in gradients:
|
|
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
127
137
|
self.compiled_metrics.update_state(y, pred, sample_weight)
|
|
128
138
|
return {m.name: m.result() for m in self.metrics}
|
|
129
139
|
|
|
130
|
-
def call(self, inputs):
|
|
140
|
+
def call(self, inputs, **kwargs):
|
|
131
141
|
"""Forward pass of SAM.
|
|
132
142
|
|
|
133
143
|
SAM delegates the forward pass call to the wrapped model.
|
|
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
|
|
|
138
148
|
Returns:
|
|
139
149
|
A Tensor, the outputs of the wrapped model for given `inputs`.
|
|
140
150
|
"""
|
|
141
|
-
return self.model(inputs)
|
|
151
|
+
return self.model(inputs, **kwargs)
|
|
142
152
|
|
|
143
153
|
def get_config(self):
|
|
144
154
|
config = super().get_config()
|