tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +235 -97
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +35 -19
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/engine/sequential.py +29 -12
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/preprocessing/sequence.py +2 -2
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_api.py +18 -5
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
|
@@ -231,7 +231,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
231
231
|
return True
|
|
232
232
|
return False
|
|
233
233
|
|
|
234
|
-
return all(_is_tensor(v) for v in flat_inputs)
|
|
234
|
+
return all(_is_tensor(v) for v in flat_inputs if v is not None)
|
|
235
235
|
|
|
236
236
|
def __init__(
|
|
237
237
|
self,
|
|
@@ -259,7 +259,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
259
259
|
inputs = pack_x_y_sample_weight(x, y, sample_weights)
|
|
260
260
|
|
|
261
261
|
num_samples = set(
|
|
262
|
-
int(i.shape[0]) for i in tf.nest.flatten(inputs)
|
|
262
|
+
int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
|
|
263
263
|
).pop()
|
|
264
264
|
_check_data_cardinality(inputs)
|
|
265
265
|
|
|
@@ -386,7 +386,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
|
|
386
386
|
|
|
387
387
|
def grab_batch(i, data):
|
|
388
388
|
return tf.nest.map_structure(
|
|
389
|
-
lambda d: tf.gather(d, i, axis=0), data
|
|
389
|
+
lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
|
|
390
390
|
)
|
|
391
391
|
|
|
392
392
|
dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
|
|
@@ -459,7 +459,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
459
459
|
if not TensorLikeDataAdapter.can_handle(
|
|
460
460
|
x, y
|
|
461
461
|
) and not CompositeTensorDataAdapter.can_handle(x, y):
|
|
462
|
-
return all(_is_array_like(v) for v in flat_inputs)
|
|
462
|
+
return all(_is_array_like(v) for v in flat_inputs if v is not None)
|
|
463
463
|
else:
|
|
464
464
|
return False
|
|
465
465
|
|
|
@@ -496,7 +496,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
496
496
|
shape[0] = None
|
|
497
497
|
return tuple(shape)
|
|
498
498
|
|
|
499
|
-
flat_dtypes = [inp.dtype for inp in flat_inputs]
|
|
499
|
+
flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
|
|
500
500
|
contiguous = True
|
|
501
501
|
if self._shuffle and self._shuffle != "batch":
|
|
502
502
|
contiguous = False
|
|
@@ -509,15 +509,26 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
|
|
|
509
509
|
# to a Tensor may force it into memory..
|
|
510
510
|
def py_method(ind):
|
|
511
511
|
def slice_array(data):
|
|
512
|
+
if data is None:
|
|
513
|
+
return None
|
|
512
514
|
return training_utils.slice_arrays(
|
|
513
515
|
data, ind.numpy(), contiguous=contiguous
|
|
514
516
|
)
|
|
515
517
|
|
|
516
|
-
return [
|
|
518
|
+
return [
|
|
519
|
+
slice_array(inp) for inp in flat_inputs if inp is not None
|
|
520
|
+
]
|
|
517
521
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
522
|
+
results = tf.py_function(py_method, [indices], flat_dtypes)
|
|
523
|
+
results_it = iter(results)
|
|
524
|
+
flat_out = []
|
|
525
|
+
for original_inp in flat_inputs:
|
|
526
|
+
if original_inp is None:
|
|
527
|
+
flat_out.append(None)
|
|
528
|
+
else:
|
|
529
|
+
v = next(results_it)
|
|
530
|
+
v.set_shape(dynamic_shape_like(original_inp))
|
|
531
|
+
flat_out.append(v)
|
|
521
532
|
return tf.nest.pack_sequence_as(inputs, flat_out)
|
|
522
533
|
|
|
523
534
|
dataset = indices_dataset.map(
|
|
@@ -608,8 +619,10 @@ class CompositeTensorDataAdapter(DataAdapter):
|
|
|
608
619
|
return True
|
|
609
620
|
return _is_composite(v)
|
|
610
621
|
|
|
611
|
-
return any(
|
|
612
|
-
|
|
622
|
+
return any(
|
|
623
|
+
_is_composite(v) for v in flat_inputs if v is not None
|
|
624
|
+
) and all(
|
|
625
|
+
_is_tensor_or_composite(v) for v in flat_inputs if v is not None
|
|
613
626
|
)
|
|
614
627
|
|
|
615
628
|
def __init__(
|
|
@@ -794,8 +807,7 @@ class DatasetAdapter(DataAdapter):
|
|
|
794
807
|
# each epoch.
|
|
795
808
|
return (
|
|
796
809
|
self._user_steps is None
|
|
797
|
-
or
|
|
798
|
-
== self._user_steps
|
|
810
|
+
or self._dataset.cardinality().numpy() == self._user_steps
|
|
799
811
|
)
|
|
800
812
|
|
|
801
813
|
def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards):
|
|
@@ -819,8 +831,8 @@ class DatasetAdapter(DataAdapter):
|
|
|
819
831
|
"specify the number of steps to run."
|
|
820
832
|
)
|
|
821
833
|
else:
|
|
822
|
-
size =
|
|
823
|
-
if size == tf.data.
|
|
834
|
+
size = self._dataset.cardinality().numpy()
|
|
835
|
+
if size == tf.data.INFINITE_CARDINALITY:
|
|
824
836
|
if pss_evaluation_shards:
|
|
825
837
|
raise ValueError(
|
|
826
838
|
"When performing exact evaluation, the dataset "
|
|
@@ -1481,8 +1493,8 @@ class DataHandler:
|
|
|
1481
1493
|
if not isinstance(dataset, tf.data.Dataset):
|
|
1482
1494
|
return None
|
|
1483
1495
|
|
|
1484
|
-
size =
|
|
1485
|
-
if size == tf.data.
|
|
1496
|
+
size = dataset.cardinality()
|
|
1497
|
+
if size == tf.data.INFINITE_CARDINALITY and steps is None:
|
|
1486
1498
|
raise ValueError(
|
|
1487
1499
|
"When passing an infinitely repeating dataset, please specify "
|
|
1488
1500
|
"a `steps_per_epoch` value so that epoch level "
|
|
@@ -1945,14 +1957,18 @@ def single_batch_iterator(
|
|
|
1945
1957
|
|
|
1946
1958
|
|
|
1947
1959
|
def _check_data_cardinality(data):
|
|
1948
|
-
num_samples = set(
|
|
1960
|
+
num_samples = set(
|
|
1961
|
+
int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
|
|
1962
|
+
)
|
|
1949
1963
|
if len(num_samples) > 1:
|
|
1950
1964
|
msg = "Data cardinality is ambiguous:\n"
|
|
1951
1965
|
for label, single_data in zip(["x", "y", "sample_weight"], data):
|
|
1952
1966
|
msg += " {} sizes: {}\n".format(
|
|
1953
1967
|
label,
|
|
1954
1968
|
", ".join(
|
|
1955
|
-
str(i.shape[0])
|
|
1969
|
+
str(i.shape[0])
|
|
1970
|
+
for i in tf.nest.flatten(single_data)
|
|
1971
|
+
if i is not None
|
|
1956
1972
|
),
|
|
1957
1973
|
)
|
|
1958
1974
|
msg += "Make sure all arrays contain the same number of samples."
|
|
@@ -351,25 +351,45 @@ class Functional(training_lib.Model):
|
|
|
351
351
|
if isinstance(self._nested_inputs, dict):
|
|
352
352
|
# Case where `_nested_inputs` is a plain dict of Inputs.
|
|
353
353
|
names = sorted(self._nested_inputs.keys())
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
354
|
+
specs = []
|
|
355
|
+
for name in names:
|
|
356
|
+
layer = self._nested_inputs[name]._keras_history.layer
|
|
357
|
+
optional = (
|
|
358
|
+
layer.optional
|
|
359
|
+
if isinstance(layer, input_layer_module.InputLayer)
|
|
360
|
+
else False
|
|
359
361
|
)
|
|
360
|
-
|
|
361
|
-
|
|
362
|
+
specs.append(
|
|
363
|
+
input_spec.InputSpec(
|
|
364
|
+
shape=shape_with_no_batch_size(
|
|
365
|
+
self._nested_inputs[name]
|
|
366
|
+
),
|
|
367
|
+
allow_last_axis_squeeze=True,
|
|
368
|
+
name=name,
|
|
369
|
+
optional=optional,
|
|
370
|
+
)
|
|
371
|
+
)
|
|
372
|
+
return specs
|
|
362
373
|
else:
|
|
363
374
|
# Single input, or list / tuple of inputs.
|
|
364
375
|
# The data may be passed as a dict keyed by input name.
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
376
|
+
specs = []
|
|
377
|
+
for x in self.inputs:
|
|
378
|
+
layer = x._keras_history.layer
|
|
379
|
+
optional = (
|
|
380
|
+
layer.optional
|
|
381
|
+
if isinstance(layer, input_layer_module.InputLayer)
|
|
382
|
+
else False
|
|
370
383
|
)
|
|
371
|
-
|
|
372
|
-
|
|
384
|
+
specs.append(
|
|
385
|
+
input_spec.InputSpec(
|
|
386
|
+
shape=shape_with_no_batch_size(x),
|
|
387
|
+
allow_last_axis_squeeze=True,
|
|
388
|
+
name=x._keras_history.layer.name,
|
|
389
|
+
optional=optional,
|
|
390
|
+
)
|
|
391
|
+
)
|
|
392
|
+
return specs
|
|
373
393
|
|
|
374
394
|
@input_spec.setter
|
|
375
395
|
def input_spec(self, value):
|
|
@@ -644,7 +664,8 @@ class Functional(training_lib.Model):
|
|
|
644
664
|
else:
|
|
645
665
|
masks = self._flatten_to_reference_inputs(mask)
|
|
646
666
|
for input_t, mask in zip(inputs, masks):
|
|
647
|
-
input_t
|
|
667
|
+
if input_t is not None:
|
|
668
|
+
input_t._keras_mask = mask
|
|
648
669
|
|
|
649
670
|
# Dictionary mapping reference tensors to computed tensors.
|
|
650
671
|
tensor_dict = {}
|
|
@@ -98,6 +98,8 @@ class InputLayer(base_layer.Layer):
|
|
|
98
98
|
`tf.TypeSpec` represents the entire batch. When provided, all other
|
|
99
99
|
args except name must be `None`.
|
|
100
100
|
name: Optional name of the layer (string).
|
|
101
|
+
optional: Boolean, whether the input is optional or not. An optional
|
|
102
|
+
input can accept `None` values.
|
|
101
103
|
"""
|
|
102
104
|
|
|
103
105
|
@traceback_utils.filter_traceback
|
|
@@ -111,6 +113,7 @@ class InputLayer(base_layer.Layer):
|
|
|
111
113
|
name=None,
|
|
112
114
|
ragged=None,
|
|
113
115
|
type_spec=None,
|
|
116
|
+
optional=False,
|
|
114
117
|
**kwargs,
|
|
115
118
|
):
|
|
116
119
|
self._init_input_shape = input_shape
|
|
@@ -180,6 +183,7 @@ class InputLayer(base_layer.Layer):
|
|
|
180
183
|
self.ragged = True if ragged else False
|
|
181
184
|
self.batch_size = batch_size
|
|
182
185
|
self.supports_masking = True
|
|
186
|
+
self.optional = optional
|
|
183
187
|
|
|
184
188
|
if isinstance(input_shape, tf.TensorShape):
|
|
185
189
|
input_shape = tuple(input_shape.as_list())
|
|
@@ -284,6 +288,7 @@ class InputLayer(base_layer.Layer):
|
|
|
284
288
|
"sparse": self.sparse,
|
|
285
289
|
"ragged": self.ragged,
|
|
286
290
|
"name": self.name,
|
|
291
|
+
"optional": self.optional,
|
|
287
292
|
}
|
|
288
293
|
return config
|
|
289
294
|
|
|
@@ -303,6 +308,7 @@ def Input(
|
|
|
303
308
|
tensor=None,
|
|
304
309
|
ragged=None,
|
|
305
310
|
type_spec=None,
|
|
311
|
+
optional=False,
|
|
306
312
|
**kwargs,
|
|
307
313
|
):
|
|
308
314
|
"""`Input()` is used to instantiate a TF-Keras tensor.
|
|
@@ -341,6 +347,8 @@ def Input(
|
|
|
341
347
|
[this guide](https://www.tensorflow.org/guide/ragged_tensor).
|
|
342
348
|
type_spec: A `tf.TypeSpec` object to create the input placeholder from.
|
|
343
349
|
When provided, all other args except name must be None.
|
|
350
|
+
optional: Boolean, whether the input is optional or not. An optional
|
|
351
|
+
input can accept `None` values.
|
|
344
352
|
**kwargs: deprecated arguments support. Supports `batch_shape` and
|
|
345
353
|
`batch_input_shape`.
|
|
346
354
|
|
|
@@ -415,6 +423,7 @@ def Input(
|
|
|
415
423
|
"ragged": ragged,
|
|
416
424
|
"input_tensor": tensor,
|
|
417
425
|
"type_spec": type_spec,
|
|
426
|
+
"optional": optional,
|
|
418
427
|
}
|
|
419
428
|
|
|
420
429
|
batch_input_shape = kwargs.pop(
|
|
@@ -56,6 +56,8 @@ class InputSpec:
|
|
|
56
56
|
as long as the last axis of the spec is 1.
|
|
57
57
|
name: Expected key corresponding to this input when passing data as
|
|
58
58
|
a dictionary.
|
|
59
|
+
optional: Boolean, whether the input is optional or not. An optional input
|
|
60
|
+
can accept `None` values.
|
|
59
61
|
|
|
60
62
|
Example:
|
|
61
63
|
|
|
@@ -82,6 +84,7 @@ class InputSpec:
|
|
|
82
84
|
axes=None,
|
|
83
85
|
allow_last_axis_squeeze=False,
|
|
84
86
|
name=None,
|
|
87
|
+
optional=False,
|
|
85
88
|
):
|
|
86
89
|
self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
|
|
87
90
|
shape = tf.TensorShape(shape)
|
|
@@ -99,6 +102,7 @@ class InputSpec:
|
|
|
99
102
|
self.min_ndim = min_ndim
|
|
100
103
|
self.name = name
|
|
101
104
|
self.allow_last_axis_squeeze = allow_last_axis_squeeze
|
|
105
|
+
self.optional = optional
|
|
102
106
|
try:
|
|
103
107
|
axes = axes or {}
|
|
104
108
|
self.axes = {int(k): axes[k] for k in axes}
|
|
@@ -204,7 +208,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
|
|
204
208
|
inputs = list_inputs
|
|
205
209
|
|
|
206
210
|
inputs = tf.nest.flatten(inputs)
|
|
207
|
-
for x in inputs:
|
|
211
|
+
for _, (x, spec) in enumerate(zip(inputs, input_spec)):
|
|
212
|
+
if spec is None:
|
|
213
|
+
continue
|
|
214
|
+
if x is None and spec.optional:
|
|
215
|
+
continue
|
|
208
216
|
# Having a shape/dtype is the only commonality of the various
|
|
209
217
|
# tensor-like objects that may be passed. The most common kind of
|
|
210
218
|
# invalid type we are guarding for is a Layer instance (Functional API),
|
|
@@ -224,6 +232,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
|
|
224
232
|
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
|
225
233
|
if spec is None:
|
|
226
234
|
continue
|
|
235
|
+
if x is None and spec.optional:
|
|
236
|
+
continue
|
|
227
237
|
|
|
228
238
|
shape = tf.TensorShape(x.shape)
|
|
229
239
|
if shape.rank is None:
|
|
@@ -285,12 +285,16 @@ class Sequential(functional.Functional):
|
|
|
285
285
|
):
|
|
286
286
|
# Determine whether the input shape is novel, i.e. whether the model
|
|
287
287
|
# should be rebuilt.
|
|
288
|
-
input_shape =
|
|
288
|
+
input_shape = tf_utils.convert_shapes(input_shape)
|
|
289
289
|
if self._inferred_input_shape is None:
|
|
290
290
|
new_shape = input_shape
|
|
291
291
|
else:
|
|
292
|
-
new_shape =
|
|
293
|
-
|
|
292
|
+
new_shape = tf.nest.map_structure(
|
|
293
|
+
_relax_input_shape,
|
|
294
|
+
tf_utils.convert_shapes(
|
|
295
|
+
self._inferred_input_shape, to_tuples=False
|
|
296
|
+
),
|
|
297
|
+
tf_utils.convert_shapes(input_shape, to_tuples=False),
|
|
294
298
|
)
|
|
295
299
|
if (
|
|
296
300
|
new_shape is not None
|
|
@@ -299,10 +303,13 @@ class Sequential(functional.Functional):
|
|
|
299
303
|
# A novel shape has been received: we need to rebuild the model.
|
|
300
304
|
# In case we are inside a graph function, we step out of it.
|
|
301
305
|
with tf.init_scope():
|
|
302
|
-
inputs =
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
+
inputs = tf.nest.map_structure(
|
|
307
|
+
lambda s: input_layer.Input(
|
|
308
|
+
batch_shape=tf_utils.convert_shapes(s),
|
|
309
|
+
dtype=input_dtype,
|
|
310
|
+
name=self.layers[0].name + "_input",
|
|
311
|
+
),
|
|
312
|
+
tf_utils.convert_shapes(new_shape, to_tuples=False),
|
|
306
313
|
)
|
|
307
314
|
layer_input = inputs
|
|
308
315
|
created_nodes = set()
|
|
@@ -370,7 +377,7 @@ class Sequential(functional.Functional):
|
|
|
370
377
|
raise ValueError("You must provide an `input_shape` argument.")
|
|
371
378
|
self._build_graph_network_for_inferred_shape(input_shape)
|
|
372
379
|
if not self.built:
|
|
373
|
-
input_shape =
|
|
380
|
+
input_shape = tf_utils.convert_shapes(input_shape)
|
|
374
381
|
self._build_input_shape = input_shape
|
|
375
382
|
super().build(input_shape)
|
|
376
383
|
self.built = True
|
|
@@ -435,7 +442,8 @@ class Sequential(functional.Functional):
|
|
|
435
442
|
def get_config(self):
|
|
436
443
|
layer_configs = []
|
|
437
444
|
serialize_obj_fn = serialization_lib.serialize_keras_object
|
|
438
|
-
|
|
445
|
+
use_legacy_config = getattr(self, "use_legacy_config", False)
|
|
446
|
+
if use_legacy_config:
|
|
439
447
|
serialize_obj_fn = legacy_serialization.serialize_keras_object
|
|
440
448
|
for layer in super().layers:
|
|
441
449
|
# `super().layers` include the InputLayer if available (it is
|
|
@@ -446,7 +454,11 @@ class Sequential(functional.Functional):
|
|
|
446
454
|
config = training.Model.get_config(self)
|
|
447
455
|
config["name"] = self.name
|
|
448
456
|
config["layers"] = copy.deepcopy(layer_configs)
|
|
449
|
-
if
|
|
457
|
+
if (
|
|
458
|
+
use_legacy_config
|
|
459
|
+
and not self._is_graph_network
|
|
460
|
+
and self._build_input_shape
|
|
461
|
+
):
|
|
450
462
|
config["build_input_shape"] = self._build_input_shape
|
|
451
463
|
return config
|
|
452
464
|
|
|
@@ -458,6 +470,7 @@ class Sequential(functional.Functional):
|
|
|
458
470
|
layer_configs = config["layers"]
|
|
459
471
|
else:
|
|
460
472
|
name = None
|
|
473
|
+
build_input_shape = None
|
|
461
474
|
layer_configs = config
|
|
462
475
|
model = cls(name=name)
|
|
463
476
|
for layer_config in layer_configs:
|
|
@@ -519,11 +532,15 @@ def _get_shape_tuple(t):
|
|
|
519
532
|
return None
|
|
520
533
|
|
|
521
534
|
|
|
522
|
-
def
|
|
535
|
+
def _relax_input_shape(shape_1, shape_2):
|
|
523
536
|
if shape_1 is None or shape_2 is None:
|
|
524
537
|
return None
|
|
525
|
-
if
|
|
538
|
+
if shape_1.rank is None or shape_2.rank is None:
|
|
539
|
+
return None
|
|
540
|
+
if shape_1.rank != shape_2.rank:
|
|
526
541
|
return None
|
|
542
|
+
shape_1 = shape_1.as_list()
|
|
543
|
+
shape_2 = shape_2.as_list()
|
|
527
544
|
return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2))
|
|
528
545
|
|
|
529
546
|
|
|
@@ -70,6 +70,8 @@ class Softmax(Layer):
|
|
|
70
70
|
Args:
|
|
71
71
|
axis: Integer, or list of Integers, axis along which the softmax
|
|
72
72
|
normalization is applied.
|
|
73
|
+
robust_masking: Bool, if true will use a more robust implementation when
|
|
74
|
+
dealing with masks.
|
|
73
75
|
Call arguments:
|
|
74
76
|
inputs: The inputs, or logits to the softmax layer.
|
|
75
77
|
mask: A boolean mask of the same shape as `inputs`. The mask
|
|
@@ -80,23 +82,34 @@ class Softmax(Layer):
|
|
|
80
82
|
Softmaxed output with the same shape as `inputs`.
|
|
81
83
|
"""
|
|
82
84
|
|
|
83
|
-
def __init__(self, axis=-1, **kwargs):
|
|
85
|
+
def __init__(self, axis=-1, robust_masking=False, **kwargs):
|
|
84
86
|
super().__init__(**kwargs)
|
|
85
87
|
self.supports_masking = True
|
|
88
|
+
self.robust_masking = robust_masking
|
|
86
89
|
self.axis = axis
|
|
87
90
|
|
|
88
91
|
def call(self, inputs, mask=None):
|
|
89
92
|
if mask is not None:
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
93
|
+
if self.robust_masking:
|
|
94
|
+
# We keep the positions where the mask is True or > 0.5, and set
|
|
95
|
+
# the other (masked) positions to -1e.9.
|
|
96
|
+
if mask.dtype is not tf.bool:
|
|
97
|
+
mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
|
|
98
|
+
inputs = tf.where(
|
|
99
|
+
mask, inputs, _large_compatible_negative(inputs.dtype)
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
# Since mask is 1.0 for positions we want to keep and 0.0 for
|
|
103
|
+
# masked positions, this operation will create a tensor which is
|
|
104
|
+
# 0.0 for positions we want to attend and -1e.9 for masked
|
|
105
|
+
# positions.
|
|
106
|
+
adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
|
|
107
|
+
_large_compatible_negative(inputs.dtype)
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
# Since we are adding it to the raw scores before the softmax,
|
|
111
|
+
# this is effectively the same as removing these entirely.
|
|
112
|
+
inputs += adder
|
|
100
113
|
if isinstance(self.axis, (tuple, list)):
|
|
101
114
|
if len(self.axis) > 1:
|
|
102
115
|
return tf.exp(
|
|
@@ -109,6 +122,8 @@ class Softmax(Layer):
|
|
|
109
122
|
|
|
110
123
|
def get_config(self):
|
|
111
124
|
config = {"axis": self.axis}
|
|
125
|
+
if self.robust_masking:
|
|
126
|
+
config["robust_masking"] = True
|
|
112
127
|
base_config = super().get_config()
|
|
113
128
|
return dict(list(base_config.items()) + list(config.items()))
|
|
114
129
|
|
|
@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
|
|
|
198
198
|
activity_regularizer: Regularizer for dense layer activity.
|
|
199
199
|
kernel_constraint: Constraint for dense layer kernels.
|
|
200
200
|
bias_constraint: Constraint for dense layer kernels.
|
|
201
|
+
softmax_robust_masking: If true will use a more numerically robust
|
|
202
|
+
masking impl.
|
|
201
203
|
|
|
202
204
|
Call arguments:
|
|
203
205
|
query: Query `Tensor` of shape `(B, T, dim)`.
|
|
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
|
|
|
247
249
|
activity_regularizer=None,
|
|
248
250
|
kernel_constraint=None,
|
|
249
251
|
bias_constraint=None,
|
|
252
|
+
softmax_robust_masking=False,
|
|
250
253
|
**kwargs,
|
|
251
254
|
):
|
|
252
255
|
super().__init__(**kwargs)
|
|
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
|
|
|
264
267
|
self._activity_regularizer = regularizers.get(activity_regularizer)
|
|
265
268
|
self._kernel_constraint = constraints.get(kernel_constraint)
|
|
266
269
|
self._bias_constraint = constraints.get(bias_constraint)
|
|
270
|
+
self._softmax_robust_masking = softmax_robust_masking
|
|
267
271
|
if attention_axes is not None and not isinstance(
|
|
268
272
|
attention_axes, collections.abc.Sized
|
|
269
273
|
):
|
|
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
|
|
|
298
302
|
"query_shape": self._query_shape,
|
|
299
303
|
"key_shape": self._key_shape,
|
|
300
304
|
"value_shape": self._value_shape,
|
|
305
|
+
"softmax_robust_masking": self._softmax_robust_masking,
|
|
301
306
|
}
|
|
302
307
|
base_config = super().get_config()
|
|
303
308
|
return dict(list(base_config.items()) + list(config.items()))
|
|
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
|
|
|
476
481
|
)
|
|
477
482
|
)
|
|
478
483
|
self._softmax = activation.Softmax(
|
|
479
|
-
axis=norm_axes,
|
|
484
|
+
axis=norm_axes,
|
|
485
|
+
robust_masking=self._softmax_robust_masking,
|
|
486
|
+
dtype=self._dtype_policy,
|
|
480
487
|
)
|
|
481
488
|
self._dropout_layer = regularization.Dropout(
|
|
482
489
|
rate=self._dropout, dtype=self._dtype_policy
|
|
@@ -259,6 +259,10 @@ class TFOpLambda(Layer):
|
|
|
259
259
|
|
|
260
260
|
self._call_spec.expects_training_arg = False
|
|
261
261
|
self._call_spec.expects_mask_arg = False
|
|
262
|
+
# Clear the call-context arguments for the layer's call method.
|
|
263
|
+
# Otherwise, Keras ends up injecting context arguments into the op-call
|
|
264
|
+
# when the call method accepts kwargs.
|
|
265
|
+
self._call_spec._expected_context_args.clear()
|
|
262
266
|
|
|
263
267
|
def _call_wrapper(self, *args, **kwargs):
|
|
264
268
|
created_variables = []
|
|
@@ -95,7 +95,7 @@ class SpectralNormalization(Wrapper):
|
|
|
95
95
|
|
|
96
96
|
def call(self, inputs, training=False):
|
|
97
97
|
if training:
|
|
98
|
-
self.
|
|
98
|
+
self._update_weights()
|
|
99
99
|
|
|
100
100
|
output = self.layer(inputs)
|
|
101
101
|
return output
|
|
@@ -105,35 +105,42 @@ class SpectralNormalization(Wrapper):
|
|
|
105
105
|
self.layer.compute_output_shape(input_shape).as_list()
|
|
106
106
|
)
|
|
107
107
|
|
|
108
|
+
def _update_weights(self):
|
|
109
|
+
weights = self.kernel
|
|
110
|
+
vector_u = self.vector_u
|
|
111
|
+
|
|
112
|
+
kernel_weights, vector_u = tf.cond(
|
|
113
|
+
tf.reduce_all(tf.equal(weights, 0)),
|
|
114
|
+
lambda: (weights, vector_u),
|
|
115
|
+
lambda: self.normalize_weights(),
|
|
116
|
+
)
|
|
117
|
+
self.kernel.assign(kernel_weights)
|
|
118
|
+
self.vector_u.assign(vector_u)
|
|
119
|
+
|
|
108
120
|
def normalize_weights(self):
|
|
109
121
|
"""Generate spectral normalized weights.
|
|
110
122
|
|
|
111
123
|
This method will update the value of `self.kernel` with the
|
|
112
124
|
spectral normalized value, so that the layer is ready for `call()`.
|
|
113
125
|
"""
|
|
114
|
-
|
|
115
|
-
weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
|
|
126
|
+
# Initialize vector_v to hint the compiler it always exist.
|
|
116
127
|
vector_u = self.vector_u
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
tf.matmul(vector_u, weights, transpose_b=True)
|
|
123
|
-
)
|
|
124
|
-
vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
|
|
125
|
-
vector_u = tf.stop_gradient(vector_u)
|
|
126
|
-
vector_v = tf.stop_gradient(vector_v)
|
|
127
|
-
sigma = tf.matmul(
|
|
128
|
-
tf.matmul(vector_v, weights), vector_u, transpose_b=True
|
|
129
|
-
)
|
|
130
|
-
self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
|
|
131
|
-
self.kernel.assign(
|
|
132
|
-
tf.cast(
|
|
133
|
-
tf.reshape(self.kernel / sigma, self.kernel_shape),
|
|
134
|
-
self.kernel.dtype,
|
|
135
|
-
)
|
|
128
|
+
vector_v = self.vector_u
|
|
129
|
+
weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
|
|
130
|
+
for _ in range(self.power_iterations):
|
|
131
|
+
vector_v = tf.math.l2_normalize(
|
|
132
|
+
tf.matmul(vector_u, weights, transpose_b=True)
|
|
136
133
|
)
|
|
134
|
+
vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
|
|
135
|
+
vector_u = tf.stop_gradient(vector_u)
|
|
136
|
+
vector_v = tf.stop_gradient(vector_v)
|
|
137
|
+
sigma = tf.matmul(
|
|
138
|
+
tf.matmul(vector_v, weights),
|
|
139
|
+
vector_u,
|
|
140
|
+
transpose_b=True,
|
|
141
|
+
)
|
|
142
|
+
weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
|
|
143
|
+
return weights_normalized, vector_u
|
|
137
144
|
|
|
138
145
|
def get_config(self):
|
|
139
146
|
config = {"power_iterations": self.power_iterations}
|
|
@@ -52,9 +52,21 @@ class _RNNCellWrapper(AbstractRNNCell):
|
|
|
52
52
|
super().__init__(*args, **kwargs)
|
|
53
53
|
self.cell = cell
|
|
54
54
|
cell_call_spec = tf_inspect.getfullargspec(cell.call)
|
|
55
|
+
accepts_kwargs = cell_call_spec.varkw is not None
|
|
56
|
+
|
|
55
57
|
self._call_spec.expects_training_arg = (
|
|
56
58
|
"training" in cell_call_spec.args
|
|
57
|
-
) or
|
|
59
|
+
) or accepts_kwargs
|
|
60
|
+
|
|
61
|
+
# Filter _expects_context_arg. An argument is kept if:
|
|
62
|
+
# 1. It's an explicit argument in cell_call_spec.args OR
|
|
63
|
+
# 2. The cell accepts arbitrary keyword arguments (**kwargs),
|
|
64
|
+
# meaning it could potentially handle the context argument.
|
|
65
|
+
self._call_spec._expected_context_args = {
|
|
66
|
+
arg
|
|
67
|
+
for arg in self._call_spec._expected_context_args
|
|
68
|
+
if (arg in cell_call_spec.args) or accepts_kwargs
|
|
69
|
+
}
|
|
58
70
|
|
|
59
71
|
def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
|
|
60
72
|
"""Calls the wrapped cell and performs the wrapping logic.
|