tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/__init__.py +1 -1
  6. tf_keras/src/backend.py +1 -1
  7. tf_keras/src/datasets/boston_housing.py +14 -5
  8. tf_keras/src/datasets/cifar10.py +9 -1
  9. tf_keras/src/datasets/cifar100.py +7 -1
  10. tf_keras/src/datasets/fashion_mnist.py +16 -4
  11. tf_keras/src/datasets/imdb.py +8 -0
  12. tf_keras/src/datasets/mnist.py +9 -3
  13. tf_keras/src/datasets/reuters.py +8 -0
  14. tf_keras/src/engine/base_layer.py +235 -97
  15. tf_keras/src/engine/base_layer_utils.py +17 -5
  16. tf_keras/src/engine/base_layer_v1.py +12 -3
  17. tf_keras/src/engine/data_adapter.py +35 -19
  18. tf_keras/src/engine/functional.py +36 -15
  19. tf_keras/src/engine/input_layer.py +9 -0
  20. tf_keras/src/engine/input_spec.py +11 -1
  21. tf_keras/src/engine/sequential.py +29 -12
  22. tf_keras/src/layers/activation/softmax.py +26 -11
  23. tf_keras/src/layers/attention/multi_head_attention.py +8 -1
  24. tf_keras/src/layers/core/tf_op_layer.py +4 -0
  25. tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
  26. tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
  27. tf_keras/src/metrics/confusion_metrics.py +51 -4
  28. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  29. tf_keras/src/preprocessing/sequence.py +2 -2
  30. tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
  31. tf_keras/src/saving/legacy/saving_utils.py +14 -2
  32. tf_keras/src/saving/saving_api.py +18 -5
  33. tf_keras/src/saving/saving_lib.py +1 -1
  34. tf_keras/src/utils/layer_utils.py +45 -3
  35. tf_keras/src/utils/metrics_utils.py +4 -1
  36. tf_keras/src/utils/tf_utils.py +2 -2
  37. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
  38. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
  39. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
  40. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  41. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  42. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  43. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  44. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  45. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  46. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  47. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  48. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  49. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  50. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  51. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  52. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  53. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  54. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  55. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  56. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  57. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  58. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  59. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  60. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  61. tf_keras/src/tests/keras_doctest.py +0 -159
  62. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0

tf_keras/src/engine/data_adapter.py

@@ -231,7 +231,7 @@ class TensorLikeDataAdapter(DataAdapter):
                 return True
             return False

-        return all(_is_tensor(v) for v in flat_inputs)
+        return all(_is_tensor(v) for v in flat_inputs if v is not None)

     def __init__(
         self,
@@ -259,7 +259,7 @@ class TensorLikeDataAdapter(DataAdapter):
         inputs = pack_x_y_sample_weight(x, y, sample_weights)

         num_samples = set(
-            int(i.shape[0]) for i in tf.nest.flatten(inputs)
+            int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
         ).pop()
         _check_data_cardinality(inputs)

@@ -386,7 +386,7 @@ class TensorLikeDataAdapter(DataAdapter):

         def grab_batch(i, data):
             return tf.nest.map_structure(
-                lambda d: tf.gather(d, i, axis=0), data
+                lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
             )

         dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +459,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
         if not TensorLikeDataAdapter.can_handle(
             x, y
         ) and not CompositeTensorDataAdapter.can_handle(x, y):
-            return all(_is_array_like(v) for v in flat_inputs)
+            return all(_is_array_like(v) for v in flat_inputs if v is not None)
         else:
             return False

@@ -496,7 +496,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
             shape[0] = None
             return tuple(shape)

-        flat_dtypes = [inp.dtype for inp in flat_inputs]
+        flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
         contiguous = True
         if self._shuffle and self._shuffle != "batch":
             contiguous = False
@@ -509,15 +509,26 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
             # to a Tensor may force it into memory..
             def py_method(ind):
                 def slice_array(data):
+                    if data is None:
+                        return None
                     return training_utils.slice_arrays(
                         data, ind.numpy(), contiguous=contiguous
                     )

-                return [slice_array(inp) for inp in flat_inputs]
+                return [
+                    slice_array(inp) for inp in flat_inputs if inp is not None
+                ]

-            flat_out = tf.py_function(py_method, [indices], flat_dtypes)
-            for v, original_inp in zip(flat_out, flat_inputs):
-                v.set_shape(dynamic_shape_like(original_inp))
+            results = tf.py_function(py_method, [indices], flat_dtypes)
+            results_it = iter(results)
+            flat_out = []
+            for original_inp in flat_inputs:
+                if original_inp is None:
+                    flat_out.append(None)
+                else:
+                    v = next(results_it)
+                    v.set_shape(dynamic_shape_like(original_inp))
+                    flat_out.append(v)
             return tf.nest.pack_sequence_as(inputs, flat_out)

         dataset = indices_dataset.map(
@@ -608,8 +619,10 @@ class CompositeTensorDataAdapter(DataAdapter):
                 return True
             return _is_composite(v)

-        return any(_is_composite(v) for v in flat_inputs) and all(
-            _is_tensor_or_composite(v) for v in flat_inputs
+        return any(
+            _is_composite(v) for v in flat_inputs if v is not None
+        ) and all(
+            _is_tensor_or_composite(v) for v in flat_inputs if v is not None
         )

     def __init__(
@@ -794,8 +807,7 @@ class DatasetAdapter(DataAdapter):
         # each epoch.
         return (
             self._user_steps is None
-            or tf.data.experimental.cardinality(self._dataset).numpy()
-            == self._user_steps
+            or self._dataset.cardinality().numpy() == self._user_steps
         )

     def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards):
@@ -819,8 +831,8 @@ class DatasetAdapter(DataAdapter):
                     "specify the number of steps to run."
                 )
             else:
-                size = tf.data.experimental.cardinality(self._dataset).numpy()
-                if size == tf.data.experimental.INFINITE_CARDINALITY:
+                size = self._dataset.cardinality().numpy()
+                if size == tf.data.INFINITE_CARDINALITY:
                     if pss_evaluation_shards:
                         raise ValueError(
                             "When performing exact evaluation, the dataset "
@@ -1481,8 +1493,8 @@ class DataHandler:
         if not isinstance(dataset, tf.data.Dataset):
             return None

-        size = tf.data.experimental.cardinality(dataset)
-        if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
+        size = dataset.cardinality()
+        if size == tf.data.INFINITE_CARDINALITY and steps is None:
             raise ValueError(
                 "When passing an infinitely repeating dataset, please specify "
                 "a `steps_per_epoch` value so that epoch level "
@@ -1945,14 +1957,18 @@ def single_batch_iterator(


 def _check_data_cardinality(data):
-    num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
+    num_samples = set(
+        int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
+    )
     if len(num_samples) > 1:
         msg = "Data cardinality is ambiguous:\n"
         for label, single_data in zip(["x", "y", "sample_weight"], data):
             msg += "  {} sizes: {}\n".format(
                 label,
                 ", ".join(
-                    str(i.shape[0]) for i in tf.nest.flatten(single_data)
+                    str(i.shape[0])
+                    for i in tf.nest.flatten(single_data)
+                    if i is not None
                 ),
             )
         msg += "Make sure all arrays contain the same number of samples."
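
The cardinality changes above migrate off the deprecated tf.data.experimental API onto the Dataset method. A minimal standalone sketch of the equivalence (plain TensorFlow, not part of the diff):

    import tensorflow as tf

    ds = tf.data.Dataset.range(8).batch(2)
    # Was: tf.data.experimental.cardinality(ds)
    assert ds.cardinality().numpy() == 4

    # Infinitely repeating datasets report INFINITE_CARDINALITY, which is why
    # the adapter insists on an explicit `steps` value for them.
    assert ds.repeat().cardinality().numpy() == tf.data.INFINITE_CARDINALITY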

tf_keras/src/engine/functional.py

@@ -351,25 +351,45 @@ class Functional(training_lib.Model):
         if isinstance(self._nested_inputs, dict):
             # Case where `_nested_inputs` is a plain dict of Inputs.
             names = sorted(self._nested_inputs.keys())
-            return [
-                input_spec.InputSpec(
-                    shape=shape_with_no_batch_size(self._nested_inputs[name]),
-                    allow_last_axis_squeeze=True,
-                    name=name,
+            specs = []
+            for name in names:
+                layer = self._nested_inputs[name]._keras_history.layer
+                optional = (
+                    layer.optional
+                    if isinstance(layer, input_layer_module.InputLayer)
+                    else False
                 )
-                for name in names
-            ]
+                specs.append(
+                    input_spec.InputSpec(
+                        shape=shape_with_no_batch_size(
+                            self._nested_inputs[name]
+                        ),
+                        allow_last_axis_squeeze=True,
+                        name=name,
+                        optional=optional,
+                    )
+                )
+            return specs
         else:
             # Single input, or list / tuple of inputs.
             # The data may be passed as a dict keyed by input name.
-            return [
-                input_spec.InputSpec(
-                    shape=shape_with_no_batch_size(x),
-                    allow_last_axis_squeeze=True,
-                    name=x._keras_history.layer.name,
+            specs = []
+            for x in self.inputs:
+                layer = x._keras_history.layer
+                optional = (
+                    layer.optional
+                    if isinstance(layer, input_layer_module.InputLayer)
+                    else False
                 )
-                for x in self.inputs
-            ]
+                specs.append(
+                    input_spec.InputSpec(
+                        shape=shape_with_no_batch_size(x),
+                        allow_last_axis_squeeze=True,
+                        name=x._keras_history.layer.name,
+                        optional=optional,
+                    )
+                )
+            return specs

     @input_spec.setter
     def input_spec(self, value):
@@ -644,7 +664,8 @@ class Functional(training_lib.Model):
         else:
             masks = self._flatten_to_reference_inputs(mask)
         for input_t, mask in zip(inputs, masks):
-            input_t._keras_mask = mask
+            if input_t is not None:
+                input_t._keras_mask = mask

         # Dictionary mapping reference tensors to computed tensors.
         tensor_dict = {}

tf_keras/src/engine/input_layer.py

@@ -98,6 +98,8 @@ class InputLayer(base_layer.Layer):
             `tf.TypeSpec` represents the entire batch. When provided, all other
             args except name must be `None`.
         name: Optional name of the layer (string).
+        optional: Boolean, whether the input is optional or not. An optional
+            input can accept `None` values.
     """

     @traceback_utils.filter_traceback
@@ -111,6 +113,7 @@ class InputLayer(base_layer.Layer):
         name=None,
         ragged=None,
         type_spec=None,
+        optional=False,
         **kwargs,
     ):
         self._init_input_shape = input_shape
@@ -180,6 +183,7 @@ class InputLayer(base_layer.Layer):
         self.ragged = True if ragged else False
         self.batch_size = batch_size
         self.supports_masking = True
+        self.optional = optional

         if isinstance(input_shape, tf.TensorShape):
             input_shape = tuple(input_shape.as_list())
@@ -284,6 +288,7 @@ class InputLayer(base_layer.Layer):
             "sparse": self.sparse,
             "ragged": self.ragged,
             "name": self.name,
+            "optional": self.optional,
         }
         return config

@@ -303,6 +308,7 @@ def Input(
     tensor=None,
     ragged=None,
     type_spec=None,
+    optional=False,
     **kwargs,
 ):
     """`Input()` is used to instantiate a TF-Keras tensor.
@@ -341,6 +347,8 @@ def Input(
             [this guide](https://www.tensorflow.org/guide/ragged_tensor).
         type_spec: A `tf.TypeSpec` object to create the input placeholder from.
             When provided, all other args except name must be None.
+        optional: Boolean, whether the input is optional or not. An optional
+            input can accept `None` values.
         **kwargs: deprecated arguments support. Supports `batch_shape` and
             `batch_input_shape`.

@@ -415,6 +423,7 @@ def Input(
         "ragged": ragged,
         "input_tensor": tensor,
         "type_spec": type_spec,
+        "optional": optional,
     }

     batch_input_shape = kwargs.pop(
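
The new `optional` flag is user-facing through `Input()`, mirroring the optional-input feature in Keras 3. A hedged sketch of the intended usage, assuming a nightly build that includes this diff (exact call semantics may differ):

    import tensorflow as tf
    import tf_keras as keras

    x1 = keras.Input(shape=(4,), name="x1")
    x2 = keras.Input(shape=(4,), name="x2", optional=True)

    class MaybeAdd(keras.layers.Layer):
        def call(self, a, b=None):
            # An optional input arrives as None when the caller omits it.
            return a if b is None else a + b

    model = keras.Model([x1, x2], MaybeAdd()(x1, x2))
    out = model([tf.ones((2, 4)), None])  # None is tolerated for x2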

tf_keras/src/engine/input_spec.py

@@ -56,6 +56,8 @@ class InputSpec:
             as long as the last axis of the spec is 1.
         name: Expected key corresponding to this input when passing data as
             a dictionary.
+        optional: Boolean, whether the input is optional or not. An optional
+            input can accept `None` values.

     Example:

@@ -82,6 +84,7 @@ class InputSpec:
         axes=None,
         allow_last_axis_squeeze=False,
         name=None,
+        optional=False,
     ):
         self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
         shape = tf.TensorShape(shape)
@@ -99,6 +102,7 @@ class InputSpec:
         self.min_ndim = min_ndim
         self.name = name
         self.allow_last_axis_squeeze = allow_last_axis_squeeze
+        self.optional = optional
         try:
             axes = axes or {}
             self.axes = {int(k): axes[k] for k in axes}
@@ -204,7 +208,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         inputs = list_inputs

     inputs = tf.nest.flatten(inputs)
-    for x in inputs:
+    for _, (x, spec) in enumerate(zip(inputs, input_spec)):
+        if spec is None:
+            continue
+        if x is None and spec.optional:
+            continue
         # Having a shape/dtype is the only commonality of the various
         # tensor-like objects that may be passed. The most common kind of
         # invalid type we are guarding for is a Layer instance (Functional API),
@@ -224,6 +232,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
     for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
         if spec is None:
             continue
+        if x is None and spec.optional:
+            continue

         shape = tf.TensorShape(x.shape)
         if shape.rank is None:
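
Taken together, a `None` input paired with an optional spec is now simply skipped during validation. A small sketch against the internal module (path per the file list above; internal API, subject to change):

    import tensorflow as tf
    from tf_keras.src.engine.input_spec import (
        InputSpec,
        assert_input_compatibility,
    )

    specs = [
        InputSpec(shape=(None, 4)),
        InputSpec(shape=(None, 2), optional=True),
    ]
    # The second input may be None without raising; previously this would
    # fail when the checker touched `x.shape`.
    assert_input_compatibility(specs, [tf.zeros((8, 4)), None], "demo")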

tf_keras/src/engine/sequential.py

@@ -285,12 +285,16 @@ class Sequential(functional.Functional):
         ):
             # Determine whether the input shape is novel, i.e. whether the model
             # should be rebuilt.
-            input_shape = tuple(input_shape)
+            input_shape = tf_utils.convert_shapes(input_shape)
             if self._inferred_input_shape is None:
                 new_shape = input_shape
             else:
-                new_shape = relax_input_shape(
-                    self._inferred_input_shape, input_shape
+                new_shape = tf.nest.map_structure(
+                    _relax_input_shape,
+                    tf_utils.convert_shapes(
+                        self._inferred_input_shape, to_tuples=False
+                    ),
+                    tf_utils.convert_shapes(input_shape, to_tuples=False),
                 )
             if (
                 new_shape is not None
@@ -299,10 +303,13 @@ class Sequential(functional.Functional):
                 # A novel shape has been received: we need to rebuild the model.
                 # In case we are inside a graph function, we step out of it.
                 with tf.init_scope():
-                    inputs = input_layer.Input(
-                        batch_shape=new_shape,
-                        dtype=input_dtype,
-                        name=self.layers[0].name + "_input",
+                    inputs = tf.nest.map_structure(
+                        lambda s: input_layer.Input(
+                            batch_shape=tf_utils.convert_shapes(s),
+                            dtype=input_dtype,
+                            name=self.layers[0].name + "_input",
+                        ),
+                        tf_utils.convert_shapes(new_shape, to_tuples=False),
                     )
                     layer_input = inputs
                     created_nodes = set()
@@ -370,7 +377,7 @@ class Sequential(functional.Functional):
             raise ValueError("You must provide an `input_shape` argument.")
         self._build_graph_network_for_inferred_shape(input_shape)
         if not self.built:
-            input_shape = tuple(input_shape)
+            input_shape = tf_utils.convert_shapes(input_shape)
             self._build_input_shape = input_shape
             super().build(input_shape)
         self.built = True
@@ -435,7 +442,8 @@ class Sequential(functional.Functional):
     def get_config(self):
         layer_configs = []
         serialize_obj_fn = serialization_lib.serialize_keras_object
-        if getattr(self, "use_legacy_config", None):
+        use_legacy_config = getattr(self, "use_legacy_config", False)
+        if use_legacy_config:
             serialize_obj_fn = legacy_serialization.serialize_keras_object
         for layer in super().layers:
             # `super().layers` include the InputLayer if available (it is
@@ -446,7 +454,11 @@ class Sequential(functional.Functional):
         config = training.Model.get_config(self)
         config["name"] = self.name
         config["layers"] = copy.deepcopy(layer_configs)
-        if not self._is_graph_network and self._build_input_shape is not None:
+        if (
+            use_legacy_config
+            and not self._is_graph_network
+            and self._build_input_shape
+        ):
             config["build_input_shape"] = self._build_input_shape
         return config

@@ -458,6 +470,7 @@ class Sequential(functional.Functional):
             layer_configs = config["layers"]
         else:
             name = None
+            build_input_shape = None
             layer_configs = config
         model = cls(name=name)
         for layer_config in layer_configs:
@@ -519,11 +532,15 @@ def _get_shape_tuple(t):
     return None


-def relax_input_shape(shape_1, shape_2):
+def _relax_input_shape(shape_1, shape_2):
     if shape_1 is None or shape_2 is None:
         return None
-    if len(shape_1) != len(shape_2):
+    if shape_1.rank is None or shape_2.rank is None:
+        return None
+    if shape_1.rank != shape_2.rank:
         return None
+    shape_1 = shape_1.as_list()
+    shape_2 = shape_2.as_list()
     return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2))

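
The shape-relaxation helper now operates on `tf.TensorShape` values and is mapped over possibly nested input structures. A standalone sketch of its behavior, transcribed from the diff above:

    import tensorflow as tf

    def _relax_input_shape(shape_1, shape_2):
        if shape_1 is None or shape_2 is None:
            return None
        if shape_1.rank is None or shape_2.rank is None:
            return None
        if shape_1.rank != shape_2.rank:
            return None
        shape_1, shape_2 = shape_1.as_list(), shape_2.as_list()
        return tuple(
            None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2)
        )

    # Dimensions that differ between two observed input shapes relax to None,
    # so the model is only rebuilt for genuinely incompatible shapes.
    relaxed = _relax_input_shape(
        tf.TensorShape([None, 32, 3]), tf.TensorShape([None, 28, 3])
    )
    assert relaxed == (None, None, 3)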

tf_keras/src/layers/activation/softmax.py

@@ -70,6 +70,8 @@ class Softmax(Layer):
     Args:
         axis: Integer, or list of Integers, axis along which the softmax
             normalization is applied.
+        robust_masking: Bool, if true will use a more robust implementation when
+            dealing with masks.
     Call arguments:
         inputs: The inputs, or logits to the softmax layer.
         mask: A boolean mask of the same shape as `inputs`. The mask
@@ -80,23 +82,34 @@ class Softmax(Layer):
         Softmaxed output with the same shape as `inputs`.
     """

-    def __init__(self, axis=-1, **kwargs):
+    def __init__(self, axis=-1, robust_masking=False, **kwargs):
         super().__init__(**kwargs)
         self.supports_masking = True
+        self.robust_masking = robust_masking
         self.axis = axis

     def call(self, inputs, mask=None):
         if mask is not None:
-            # Since mask is 1.0 for positions we want to keep and 0.0 for masked
-            # positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and -1e.9 for masked positions.
-            adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
-                _large_compatible_negative(inputs.dtype)
-            )
-
-            # Since we are adding it to the raw scores before the softmax, this
-            # is effectively the same as removing these entirely.
-            inputs += adder
+            if self.robust_masking:
+                # We keep the positions where the mask is True or > 0.5, and set
+                # the other (masked) positions to -1e.9.
+                if mask.dtype is not tf.bool:
+                    mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
+                inputs = tf.where(
+                    mask, inputs, _large_compatible_negative(inputs.dtype)
+                )
+            else:
+                # Since mask is 1.0 for positions we want to keep and 0.0 for
+                # masked positions, this operation will create a tensor which is
+                # 0.0 for positions we want to attend and -1e.9 for masked
+                # positions.
+                adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
+                    _large_compatible_negative(inputs.dtype)
+                )
+
+                # Since we are adding it to the raw scores before the softmax,
+                # this is effectively the same as removing these entirely.
+                inputs += adder
         if isinstance(self.axis, (tuple, list)):
             if len(self.axis) > 1:
                 return tf.exp(
@@ -109,6 +122,8 @@ class Softmax(Layer):

     def get_config(self):
         config = {"axis": self.axis}
+        if self.robust_masking:
+            config["robust_masking"] = True
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))

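
The robust path replaces additive masking with an overwrite via tf.where, which also handles boolean masks and float masks that are not exactly 0/1. A standalone sketch of the two strategies using plain tf.nn.softmax rather than the layer itself:

    import numpy as np
    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 3.0]])
    mask = tf.constant([[True, True, False]])

    # Default path: add a large negative number to the masked positions.
    adder = (1.0 - tf.cast(mask, logits.dtype)) * -1e9
    additive = tf.nn.softmax(logits + adder)

    # robust_masking path: overwrite masked positions instead of adding.
    robust = tf.nn.softmax(tf.where(mask, logits, -1e9))

    np.testing.assert_allclose(additive.numpy(), robust.numpy(), atol=1e-6)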

tf_keras/src/layers/attention/multi_head_attention.py

@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        softmax_robust_masking: If true will use a more numerically robust
+            masking impl.

     Call arguments:
         query: Query `Tensor` of shape `(B, T, dim)`.
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        softmax_robust_masking=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
         self._activity_regularizer = regularizers.get(activity_regularizer)
         self._kernel_constraint = constraints.get(kernel_constraint)
         self._bias_constraint = constraints.get(bias_constraint)
+        self._softmax_robust_masking = softmax_robust_masking
         if attention_axes is not None and not isinstance(
             attention_axes, collections.abc.Sized
         ):
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
             "query_shape": self._query_shape,
             "key_shape": self._key_shape,
             "value_shape": self._value_shape,
+            "softmax_robust_masking": self._softmax_robust_masking,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
             )
         )
         self._softmax = activation.Softmax(
-            axis=norm_axes, dtype=self._dtype_policy
+            axis=norm_axes,
+            robust_masking=self._softmax_robust_masking,
+            dtype=self._dtype_policy,
         )
         self._dropout_layer = regularization.Dropout(
             rate=self._dropout, dtype=self._dtype_policy
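
Downstream, the flag only needs to be passed at construction time; it is threaded into the internal Softmax layer and round-trips through get_config(). A hedged usage sketch, assuming a nightly wheel containing this diff:

    import numpy as np
    import tf_keras as keras

    mha = keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=8, softmax_robust_masking=True
    )
    query = np.random.rand(1, 4, 16).astype("float32")
    value = np.random.rand(1, 6, 16).astype("float32")
    attention_mask = np.ones((1, 4, 6), dtype=bool)
    out = mha(query, value, attention_mask=attention_mask)
    print(out.shape)  # (1, 4, 16)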

tf_keras/src/layers/core/tf_op_layer.py

@@ -259,6 +259,10 @@ class TFOpLambda(Layer):

         self._call_spec.expects_training_arg = False
         self._call_spec.expects_mask_arg = False
+        # Clear the call-context arguments for the layer's call method.
+        # Otherwise, Keras ends up injecting context arguments into the op-call
+        # when the call method accepts kwargs.
+        self._call_spec._expected_context_args.clear()

     def _call_wrapper(self, *args, **kwargs):
         created_variables = []
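
For context: TFOpLambda layers are created implicitly when a raw TensorFlow op is applied to a symbolic Keras tensor, and their wrapped calls accept **kwargs, which is what made them susceptible to injected context arguments. A small illustration of where such a layer appears (standard tf-keras behavior, not specific to this diff):

    import tensorflow as tf
    import tf_keras as keras

    inputs = keras.Input(shape=(4,))
    outputs = tf.abs(inputs)  # wrapped in a TFOpLambda layer behind the scenes
    model = keras.Model(inputs, outputs)
    print(type(model.layers[1]).__name__)  # TFOpLambda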

tf_keras/src/layers/normalization/spectral_normalization.py

@@ -95,7 +95,7 @@ class SpectralNormalization(Wrapper):

     def call(self, inputs, training=False):
         if training:
-            self.normalize_weights()
+            self._update_weights()

         output = self.layer(inputs)
         return output
@@ -105,35 +105,42 @@ class SpectralNormalization(Wrapper):
             self.layer.compute_output_shape(input_shape).as_list()
         )

+    def _update_weights(self):
+        weights = self.kernel
+        vector_u = self.vector_u
+
+        kernel_weights, vector_u = tf.cond(
+            tf.reduce_all(tf.equal(weights, 0)),
+            lambda: (weights, vector_u),
+            lambda: self.normalize_weights(),
+        )
+        self.kernel.assign(kernel_weights)
+        self.vector_u.assign(vector_u)
+
     def normalize_weights(self):
         """Generate spectral normalized weights.

         This method will update the value of `self.kernel` with the
         spectral normalized value, so that the layer is ready for `call()`.
         """
-
-        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        # Initialize vector_v to hint the compiler it always exist.
         vector_u = self.vector_u
-
-        # check for zeroes weights
-        if not tf.reduce_all(tf.equal(weights, 0.0)):
-            for _ in range(self.power_iterations):
-                vector_v = tf.math.l2_normalize(
-                    tf.matmul(vector_u, weights, transpose_b=True)
-                )
-                vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
-            vector_u = tf.stop_gradient(vector_u)
-            vector_v = tf.stop_gradient(vector_v)
-            sigma = tf.matmul(
-                tf.matmul(vector_v, weights), vector_u, transpose_b=True
-            )
-            self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
-            self.kernel.assign(
-                tf.cast(
-                    tf.reshape(self.kernel / sigma, self.kernel_shape),
-                    self.kernel.dtype,
-                )
+        vector_v = self.vector_u
+        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        for _ in range(self.power_iterations):
+            vector_v = tf.math.l2_normalize(
+                tf.matmul(vector_u, weights, transpose_b=True)
             )
+        vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
+        vector_u = tf.stop_gradient(vector_u)
+        vector_v = tf.stop_gradient(vector_v)
+        sigma = tf.matmul(
+            tf.matmul(vector_v, weights),
+            vector_u,
+            transpose_b=True,
+        )
+        weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
+        return weights_normalized, vector_u

     def get_config(self):
         config = {"power_iterations": self.power_iterations}
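
The behavioral contract is unchanged: wrapping a layer and calling it with training=True renormalizes the kernel by its estimated largest singular value, except that the zero-kernel check now runs through tf.cond (a Python `if` on a tensor would fail when traced inside tf.function) and leaves all-zero kernels untouched. A short usage sketch (SpectralNormalization is a public tf-keras layer):

    import numpy as np
    import tf_keras as keras

    layer = keras.layers.SpectralNormalization(keras.layers.Dense(8))
    x = np.random.rand(2, 4).astype("float32")
    y = layer(x, training=True)  # triggers the kernel update path
    print(y.shape)  # (2, 8)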

tf_keras/src/layers/rnn/cell_wrappers.py

@@ -52,9 +52,21 @@ class _RNNCellWrapper(AbstractRNNCell):
         super().__init__(*args, **kwargs)
         self.cell = cell
         cell_call_spec = tf_inspect.getfullargspec(cell.call)
+        accepts_kwargs = cell_call_spec.varkw is not None
+
         self._call_spec.expects_training_arg = (
             "training" in cell_call_spec.args
-        ) or (cell_call_spec.varkw is not None)
+        ) or accepts_kwargs
+
+        # Filter _expects_context_arg. An argument is kept if:
+        # 1. It's an explicit argument in cell_call_spec.args OR
+        # 2. The cell accepts arbitrary keyword arguments (**kwargs),
+        #    meaning it could potentially handle the context argument.
+        self._call_spec._expected_context_args = {
+            arg
+            for arg in self._call_spec._expected_context_args
+            if (arg in cell_call_spec.args) or accepts_kwargs
+        }

     def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
         """Calls the wrapped cell and performs the wrapping logic.