tf-keras-nightly 2.20.0.dev2025051109__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/__init__.py +1 -1
  6. tf_keras/src/engine/base_layer.py +234 -96
  7. tf_keras/src/engine/base_layer_utils.py +17 -5
  8. tf_keras/src/engine/base_layer_v1.py +12 -3
  9. tf_keras/src/engine/data_adapter.py +30 -13
  10. tf_keras/src/engine/functional.py +36 -15
  11. tf_keras/src/engine/input_layer.py +9 -0
  12. tf_keras/src/engine/input_spec.py +11 -1
  13. tf_keras/src/layers/activation/softmax.py +26 -11
  14. tf_keras/src/layers/attention/multi_head_attention.py +8 -1
  15. tf_keras/src/layers/core/tf_op_layer.py +4 -0
  16. tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
  17. tf_keras/src/metrics/confusion_metrics.py +51 -4
  18. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  19. tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
  20. tf_keras/src/saving/legacy/saving_utils.py +14 -2
  21. tf_keras/src/saving/saving_lib.py +1 -1
  22. tf_keras/src/utils/layer_utils.py +45 -3
  23. tf_keras/src/utils/metrics_utils.py +4 -1
  24. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +2 -2
  25. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +27 -49
  26. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
  27. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  28. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  29. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  30. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  31. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  32. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  33. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  34. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  35. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  36. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  37. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  38. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  39. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  40. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  41. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  42. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  43. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  44. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  45. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  46. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  47. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  48. tf_keras/src/tests/keras_doctest.py +0 -159
  49. {tf_keras_nightly-2.20.0.dev2025051109.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
@@ -132,6 +132,7 @@ class Layer(base_layer.Layer):
132
132
  self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
133
133
  ):
134
134
  self._instrument_layer_creation()
135
+ self._called = False
135
136
 
136
137
  # These properties should be set by the user via keyword arguments.
137
138
  # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -165,6 +166,8 @@ class Layer(base_layer.Layer):
165
166
  self._input_spec = None
166
167
  self.supports_masking = False
167
168
 
169
+ self._call_context_args = {"training"}
170
+
168
171
  self._init_set_name(name)
169
172
  self._activity_regularizer = regularizers.get(
170
173
  kwargs.pop("activity_regularizer", None)
@@ -705,6 +708,7 @@ class Layer(base_layer.Layer):
705
708
  RuntimeError: if `super().__init__()` was not called in the
706
709
  constructor.
707
710
  """
711
+ self._called = True
708
712
  self._assert_built_as_v1()
709
713
 
710
714
  if not hasattr(self, "_thread_local"):
@@ -803,7 +807,12 @@ class Layer(base_layer.Layer):
803
807
  if build_graph and base_layer_utils.needs_keras_history(inputs):
804
808
  base_layer_utils.create_keras_history(inputs)
805
809
 
806
- with call_context.enter(self, inputs, build_graph, training_value):
810
+ with call_context.enter(
811
+ self,
812
+ inputs,
813
+ build_graph,
814
+ call_context_args={"training": training_value},
815
+ ):
807
816
  # Check input assumptions set after layer building, e.g. input
808
817
  # shape.
809
818
  if build_graph:
@@ -2177,8 +2186,8 @@ class Layer(base_layer.Layer):
2177
2186
  else:
2178
2187
  self._set_dtype_policy(policy.Policy(dtype))
2179
2188
  input_shapes = None
2180
- if all(hasattr(x, "shape") for x in input_list):
2181
- input_shapes = tf.nest.map_structure(lambda x: x.shape, inputs)
2189
+ if any(hasattr(x, "shape") for x in input_list):
2190
+ input_shapes = tf_utils.get_shapes(inputs)
2182
2191
  # Only call `build` if the user has manually overridden the build
2183
2192
  # method.
2184
2193
  if not hasattr(self.build, "_is_default"):
@@ -231,7 +231,7 @@ class TensorLikeDataAdapter(DataAdapter):
231
231
  return True
232
232
  return False
233
233
 
234
- return all(_is_tensor(v) for v in flat_inputs)
234
+ return all(_is_tensor(v) for v in flat_inputs if v is not None)
235
235
 
236
236
  def __init__(
237
237
  self,
@@ -259,7 +259,7 @@ class TensorLikeDataAdapter(DataAdapter):
259
259
  inputs = pack_x_y_sample_weight(x, y, sample_weights)
260
260
 
261
261
  num_samples = set(
262
- int(i.shape[0]) for i in tf.nest.flatten(inputs)
262
+ int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
263
263
  ).pop()
264
264
  _check_data_cardinality(inputs)
265
265
 
@@ -386,7 +386,7 @@ class TensorLikeDataAdapter(DataAdapter):
386
386
 
387
387
  def grab_batch(i, data):
388
388
  return tf.nest.map_structure(
389
- lambda d: tf.gather(d, i, axis=0), data
389
+ lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
390
390
  )
391
391
 
392
392
  dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +459,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
459
459
  if not TensorLikeDataAdapter.can_handle(
460
460
  x, y
461
461
  ) and not CompositeTensorDataAdapter.can_handle(x, y):
462
- return all(_is_array_like(v) for v in flat_inputs)
462
+ return all(_is_array_like(v) for v in flat_inputs if v is not None)
463
463
  else:
464
464
  return False
465
465
 
@@ -496,7 +496,7 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
496
496
  shape[0] = None
497
497
  return tuple(shape)
498
498
 
499
- flat_dtypes = [inp.dtype for inp in flat_inputs]
499
+ flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
500
500
  contiguous = True
501
501
  if self._shuffle and self._shuffle != "batch":
502
502
  contiguous = False
@@ -509,15 +509,26 @@ class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
509
509
  # to a Tensor may force it into memory.
510
510
  def py_method(ind):
511
511
  def slice_array(data):
512
+ if data is None:
513
+ return None
512
514
  return training_utils.slice_arrays(
513
515
  data, ind.numpy(), contiguous=contiguous
514
516
  )
515
517
 
516
- return [slice_array(inp) for inp in flat_inputs]
518
+ return [
519
+ slice_array(inp) for inp in flat_inputs if inp is not None
520
+ ]
517
521
 
518
- flat_out = tf.py_function(py_method, [indices], flat_dtypes)
519
- for v, original_inp in zip(flat_out, flat_inputs):
520
- v.set_shape(dynamic_shape_like(original_inp))
522
+ results = tf.py_function(py_method, [indices], flat_dtypes)
523
+ results_it = iter(results)
524
+ flat_out = []
525
+ for original_inp in flat_inputs:
526
+ if original_inp is None:
527
+ flat_out.append(None)
528
+ else:
529
+ v = next(results_it)
530
+ v.set_shape(dynamic_shape_like(original_inp))
531
+ flat_out.append(v)
521
532
  return tf.nest.pack_sequence_as(inputs, flat_out)
522
533
 
523
534
  dataset = indices_dataset.map(
@@ -608,8 +619,10 @@ class CompositeTensorDataAdapter(DataAdapter):
608
619
  return True
609
620
  return _is_composite(v)
610
621
 
611
- return any(_is_composite(v) for v in flat_inputs) and all(
612
- _is_tensor_or_composite(v) for v in flat_inputs
622
+ return any(
623
+ _is_composite(v) for v in flat_inputs if v is not None
624
+ ) and all(
625
+ _is_tensor_or_composite(v) for v in flat_inputs if v is not None
613
626
  )
614
627
 
615
628
  def __init__(
@@ -1944,14 +1957,18 @@ def single_batch_iterator(
1944
1957
 
1945
1958
 
1946
1959
  def _check_data_cardinality(data):
1947
- num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
1960
+ num_samples = set(
1961
+ int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
1962
+ )
1948
1963
  if len(num_samples) > 1:
1949
1964
  msg = "Data cardinality is ambiguous:\n"
1950
1965
  for label, single_data in zip(["x", "y", "sample_weight"], data):
1951
1966
  msg += " {} sizes: {}\n".format(
1952
1967
  label,
1953
1968
  ", ".join(
1954
- str(i.shape[0]) for i in tf.nest.flatten(single_data)
1969
+ str(i.shape[0])
1970
+ for i in tf.nest.flatten(single_data)
1971
+ if i is not None
1955
1972
  ),
1956
1973
  )
1957
1974
  msg += "Make sure all arrays contain the same number of samples."
@@ -351,25 +351,45 @@ class Functional(training_lib.Model):
351
351
  if isinstance(self._nested_inputs, dict):
352
352
  # Case where `_nested_inputs` is a plain dict of Inputs.
353
353
  names = sorted(self._nested_inputs.keys())
354
- return [
355
- input_spec.InputSpec(
356
- shape=shape_with_no_batch_size(self._nested_inputs[name]),
357
- allow_last_axis_squeeze=True,
358
- name=name,
354
+ specs = []
355
+ for name in names:
356
+ layer = self._nested_inputs[name]._keras_history.layer
357
+ optional = (
358
+ layer.optional
359
+ if isinstance(layer, input_layer_module.InputLayer)
360
+ else False
359
361
  )
360
- for name in names
361
- ]
362
+ specs.append(
363
+ input_spec.InputSpec(
364
+ shape=shape_with_no_batch_size(
365
+ self._nested_inputs[name]
366
+ ),
367
+ allow_last_axis_squeeze=True,
368
+ name=name,
369
+ optional=optional,
370
+ )
371
+ )
372
+ return specs
362
373
  else:
363
374
  # Single input, or list / tuple of inputs.
364
375
  # The data may be passed as a dict keyed by input name.
365
- return [
366
- input_spec.InputSpec(
367
- shape=shape_with_no_batch_size(x),
368
- allow_last_axis_squeeze=True,
369
- name=x._keras_history.layer.name,
376
+ specs = []
377
+ for x in self.inputs:
378
+ layer = x._keras_history.layer
379
+ optional = (
380
+ layer.optional
381
+ if isinstance(layer, input_layer_module.InputLayer)
382
+ else False
370
383
  )
371
- for x in self.inputs
372
- ]
384
+ specs.append(
385
+ input_spec.InputSpec(
386
+ shape=shape_with_no_batch_size(x),
387
+ allow_last_axis_squeeze=True,
388
+ name=x._keras_history.layer.name,
389
+ optional=optional,
390
+ )
391
+ )
392
+ return specs
373
393
 
374
394
  @input_spec.setter
375
395
  def input_spec(self, value):
@@ -644,7 +664,8 @@ class Functional(training_lib.Model):
644
664
  else:
645
665
  masks = self._flatten_to_reference_inputs(mask)
646
666
  for input_t, mask in zip(inputs, masks):
647
- input_t._keras_mask = mask
667
+ if input_t is not None:
668
+ input_t._keras_mask = mask
648
669
 
649
670
  # Dictionary mapping reference tensors to computed tensors.
650
671
  tensor_dict = {}
@@ -98,6 +98,8 @@ class InputLayer(base_layer.Layer):
98
98
  `tf.TypeSpec` represents the entire batch. When provided, all other
99
99
  args except name must be `None`.
100
100
  name: Optional name of the layer (string).
101
+ optional: Boolean, whether the input is optional or not. An optional
102
+ input can accept `None` values.
101
103
  """
102
104
 
103
105
  @traceback_utils.filter_traceback
@@ -111,6 +113,7 @@ class InputLayer(base_layer.Layer):
111
113
  name=None,
112
114
  ragged=None,
113
115
  type_spec=None,
116
+ optional=False,
114
117
  **kwargs,
115
118
  ):
116
119
  self._init_input_shape = input_shape
@@ -180,6 +183,7 @@ class InputLayer(base_layer.Layer):
180
183
  self.ragged = True if ragged else False
181
184
  self.batch_size = batch_size
182
185
  self.supports_masking = True
186
+ self.optional = optional
183
187
 
184
188
  if isinstance(input_shape, tf.TensorShape):
185
189
  input_shape = tuple(input_shape.as_list())
@@ -284,6 +288,7 @@ class InputLayer(base_layer.Layer):
284
288
  "sparse": self.sparse,
285
289
  "ragged": self.ragged,
286
290
  "name": self.name,
291
+ "optional": self.optional,
287
292
  }
288
293
  return config
289
294
 
@@ -303,6 +308,7 @@ def Input(
303
308
  tensor=None,
304
309
  ragged=None,
305
310
  type_spec=None,
311
+ optional=False,
306
312
  **kwargs,
307
313
  ):
308
314
  """`Input()` is used to instantiate a TF-Keras tensor.
@@ -341,6 +347,8 @@ def Input(
341
347
  [this guide](https://www.tensorflow.org/guide/ragged_tensor).
342
348
  type_spec: A `tf.TypeSpec` object to create the input placeholder from.
343
349
  When provided, all other args except name must be None.
350
+ optional: Boolean, whether the input is optional or not. An optional
351
+ input can accept `None` values.
344
352
  **kwargs: deprecated arguments support. Supports `batch_shape` and
345
353
  `batch_input_shape`.
346
354
 
@@ -415,6 +423,7 @@ def Input(
415
423
  "ragged": ragged,
416
424
  "input_tensor": tensor,
417
425
  "type_spec": type_spec,
426
+ "optional": optional,
418
427
  }
419
428
 
420
429
  batch_input_shape = kwargs.pop(
@@ -56,6 +56,8 @@ class InputSpec:
56
56
  as long as the last axis of the spec is 1.
57
57
  name: Expected key corresponding to this input when passing data as
58
58
  a dictionary.
59
+ optional: Boolean, whether the input is optional or not. An optional input
60
+ can accept `None` values.
59
61
 
60
62
  Example:
61
63
 
@@ -82,6 +84,7 @@ class InputSpec:
82
84
  axes=None,
83
85
  allow_last_axis_squeeze=False,
84
86
  name=None,
87
+ optional=False,
85
88
  ):
86
89
  self.dtype = tf.as_dtype(dtype).name if dtype is not None else None
87
90
  shape = tf.TensorShape(shape)
@@ -99,6 +102,7 @@ class InputSpec:
99
102
  self.min_ndim = min_ndim
100
103
  self.name = name
101
104
  self.allow_last_axis_squeeze = allow_last_axis_squeeze
105
+ self.optional = optional
102
106
  try:
103
107
  axes = axes or {}
104
108
  self.axes = {int(k): axes[k] for k in axes}
@@ -204,7 +208,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
204
208
  inputs = list_inputs
205
209
 
206
210
  inputs = tf.nest.flatten(inputs)
207
- for x in inputs:
211
+ for _, (x, spec) in enumerate(zip(inputs, input_spec)):
212
+ if spec is None:
213
+ continue
214
+ if x is None and spec.optional:
215
+ continue
208
216
  # Having a shape/dtype is the only commonality of the various
209
217
  # tensor-like objects that may be passed. The most common kind of
210
218
  # invalid type we are guarding for is a Layer instance (Functional API),
@@ -224,6 +232,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
224
232
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
225
233
  if spec is None:
226
234
  continue
235
+ if x is None and spec.optional:
236
+ continue
227
237
 
228
238
  shape = tf.TensorShape(x.shape)
229
239
  if shape.rank is None:
@@ -70,6 +70,8 @@ class Softmax(Layer):
70
70
  Args:
71
71
  axis: Integer, or list of Integers, axis along which the softmax
72
72
  normalization is applied.
73
+ robust_masking: Bool, if true will use a more robust implementation when
74
+ dealing with masks.
73
75
  Call arguments:
74
76
  inputs: The inputs, or logits to the softmax layer.
75
77
  mask: A boolean mask of the same shape as `inputs`. The mask
@@ -80,23 +82,34 @@ class Softmax(Layer):
80
82
  Softmaxed output with the same shape as `inputs`.
81
83
  """
82
84
 
83
- def __init__(self, axis=-1, **kwargs):
85
+ def __init__(self, axis=-1, robust_masking=False, **kwargs):
84
86
  super().__init__(**kwargs)
85
87
  self.supports_masking = True
88
+ self.robust_masking = robust_masking
86
89
  self.axis = axis
87
90
 
88
91
  def call(self, inputs, mask=None):
89
92
  if mask is not None:
90
- # Since mask is 1.0 for positions we want to keep and 0.0 for masked
91
- # positions, this operation will create a tensor which is 0.0 for
92
- # positions we want to attend and -1e.9 for masked positions.
93
- adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
94
- _large_compatible_negative(inputs.dtype)
95
- )
96
-
97
- # Since we are adding it to the raw scores before the softmax, this
98
- # is effectively the same as removing these entirely.
99
- inputs += adder
93
+ if self.robust_masking:
94
+ # We keep the positions where the mask is True or > 0.5, and set
95
+ # the other (masked) positions to -1e.9.
96
+ if mask.dtype is not tf.bool:
97
+ mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
98
+ inputs = tf.where(
99
+ mask, inputs, _large_compatible_negative(inputs.dtype)
100
+ )
101
+ else:
102
+ # Since mask is 1.0 for positions we want to keep and 0.0 for
103
+ # masked positions, this operation will create a tensor which is
104
+ # 0.0 for positions we want to attend and -1e.9 for masked
105
+ # positions.
106
+ adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
107
+ _large_compatible_negative(inputs.dtype)
108
+ )
109
+
110
+ # Since we are adding it to the raw scores before the softmax,
111
+ # this is effectively the same as removing these entirely.
112
+ inputs += adder
100
113
  if isinstance(self.axis, (tuple, list)):
101
114
  if len(self.axis) > 1:
102
115
  return tf.exp(
@@ -109,6 +122,8 @@ class Softmax(Layer):
109
122
 
110
123
  def get_config(self):
111
124
  config = {"axis": self.axis}
125
+ if self.robust_masking:
126
+ config["robust_masking"] = True
112
127
  base_config = super().get_config()
113
128
  return dict(list(base_config.items()) + list(config.items()))
114
129
 
@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
198
198
  activity_regularizer: Regularizer for dense layer activity.
199
199
  kernel_constraint: Constraint for dense layer kernels.
200
200
  bias_constraint: Constraint for dense layer kernels.
201
+ softmax_robust_masking: If true will use a more numerically robust
202
+ masking impl.
201
203
 
202
204
  Call arguments:
203
205
  query: Query `Tensor` of shape `(B, T, dim)`.
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
247
249
  activity_regularizer=None,
248
250
  kernel_constraint=None,
249
251
  bias_constraint=None,
252
+ softmax_robust_masking=False,
250
253
  **kwargs,
251
254
  ):
252
255
  super().__init__(**kwargs)
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
264
267
  self._activity_regularizer = regularizers.get(activity_regularizer)
265
268
  self._kernel_constraint = constraints.get(kernel_constraint)
266
269
  self._bias_constraint = constraints.get(bias_constraint)
270
+ self._softmax_robust_masking = softmax_robust_masking
267
271
  if attention_axes is not None and not isinstance(
268
272
  attention_axes, collections.abc.Sized
269
273
  ):
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
298
302
  "query_shape": self._query_shape,
299
303
  "key_shape": self._key_shape,
300
304
  "value_shape": self._value_shape,
305
+ "softmax_robust_masking": self._softmax_robust_masking,
301
306
  }
302
307
  base_config = super().get_config()
303
308
  return dict(list(base_config.items()) + list(config.items()))
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
476
481
  )
477
482
  )
478
483
  self._softmax = activation.Softmax(
479
- axis=norm_axes, dtype=self._dtype_policy
484
+ axis=norm_axes,
485
+ robust_masking=self._softmax_robust_masking,
486
+ dtype=self._dtype_policy,
480
487
  )
481
488
  self._dropout_layer = regularization.Dropout(
482
489
  rate=self._dropout, dtype=self._dtype_policy
@@ -259,6 +259,10 @@ class TFOpLambda(Layer):
259
259
 
260
260
  self._call_spec.expects_training_arg = False
261
261
  self._call_spec.expects_mask_arg = False
262
+ # Clear the call-context arguments for the layer's call method.
263
+ # Otherwise, Keras ends up injecting context arguments into the op-call
264
+ # when the call method accepts kwargs.
265
+ self._call_spec._expected_context_args.clear()
262
266
 
263
267
  def _call_wrapper(self, *args, **kwargs):
264
268
  created_variables = []
@@ -52,9 +52,21 @@ class _RNNCellWrapper(AbstractRNNCell):
52
52
  super().__init__(*args, **kwargs)
53
53
  self.cell = cell
54
54
  cell_call_spec = tf_inspect.getfullargspec(cell.call)
55
+ accepts_kwargs = cell_call_spec.varkw is not None
56
+
55
57
  self._call_spec.expects_training_arg = (
56
58
  "training" in cell_call_spec.args
57
- ) or (cell_call_spec.varkw is not None)
59
+ ) or accepts_kwargs
60
+
61
+ # Filter _expects_context_arg. An argument is kept if:
62
+ # 1. It's an explicit argument in cell_call_spec.args OR
63
+ # 2. The cell accepts arbitrary keyword arguments (**kwargs),
64
+ # meaning it could potentially handle the context argument.
65
+ self._call_spec._expected_context_args = {
66
+ arg
67
+ for arg in self._call_spec._expected_context_args
68
+ if (arg in cell_call_spec.args) or accepts_kwargs
69
+ }
58
70
 
59
71
  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
60
72
  """Calls the wrapped cell and performs the wrapping logic.
@@ -1471,9 +1471,10 @@ class AUC(base_metric.Metric):
1471
1471
  # label_weights should be of length equal to the number of
1472
1472
  # labels.
1473
1473
  shapes.append((self.label_weights, ("L",)))
1474
- tf.debugging.assert_shapes(
1475
- shapes, message="Number of labels is not consistent."
1476
- )
1474
+
1475
+ tf.debugging.assert_shapes(
1476
+ shapes, message="Number of labels is not consistent."
1477
+ )
1477
1478
 
1478
1479
  # Only forward label_weights to update_confusion_matrix_variables when
1479
1480
  # multi_label is False. Otherwise the averaging of individual label AUCs
@@ -1611,13 +1612,59 @@ class AUC(base_metric.Metric):
1611
1612
  )
1612
1613
  x = fp_rate
1613
1614
  y = recall
1614
- else: # curve == 'PR'.
1615
+ elif self.curve == metrics_utils.AUCCurve.PR:
1615
1616
  precision = tf.math.divide_no_nan(
1616
1617
  self.true_positives,
1617
1618
  tf.math.add(self.true_positives, self.false_positives),
1618
1619
  )
1619
1620
  x = recall
1620
1621
  y = precision
1622
+ else: # curve == 'PR_GAIN'.
1623
+ # Due to the hyperbolic transform, this formula is less robust than
1624
+ # ROC or PR values. In particular
1625
+ # 1) Both measures diverge when there are no negative examples;
1626
+ # 2) Both measures diverge when there are no true positives;
1627
+ # 3) Recall gain becomes negative when the recall is lower than the
1628
+ # label average (i.e. when more negative examples are classified
1629
+ # positive than real positives).
1630
+ #
1631
+ # We ignore case 1 as it is easily communicated. For case 2 we set
1632
+ # recall_gain to 0 and precision_gain to 1. For case 3 we set the
1633
+ # recall_gain to 0. These fixes will result in an overastimation of
1634
+ # the AUC for estimateors that are anti-correlated with the label
1635
+ # (at some thresholds).
1636
+ #
1637
+ # The scaling factor $\frac{P}{N}$ that is used to form both
1638
+ # gain values.
1639
+ scaling_factor = tf.math.divide_no_nan(
1640
+ tf.math.add(self.true_positives, self.false_negatives),
1641
+ tf.math.add(self.false_positives, self.true_negatives),
1642
+ )
1643
+
1644
+ recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1645
+ self.false_negatives, self.true_positives
1646
+ )
1647
+ precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(
1648
+ self.false_positives, self.true_positives
1649
+ )
1650
+ # Handle case 2.
1651
+ recall_gain = tf.where(
1652
+ tf.equal(self.true_positives, 0.0),
1653
+ tf.zeros_like(recall_gain),
1654
+ recall_gain,
1655
+ )
1656
+ precision_gain = tf.where(
1657
+ tf.equal(self.true_positives, 0.0),
1658
+ tf.ones_like(precision_gain),
1659
+ precision_gain,
1660
+ )
1661
+ # Handle case 3.
1662
+ recall_gain = tf.math.maximum(
1663
+ recall_gain, tf.zeros_like(recall_gain)
1664
+ )
1665
+
1666
+ x = recall_gain
1667
+ y = precision_gain
1621
1668
 
1622
1669
  # Find the rectangle heights based on `summation_method`.
1623
1670
  if (
@@ -72,17 +72,27 @@ class SharpnessAwareMinimization(Model):
72
72
  if self.num_batch_splits is not None:
73
73
  x_split = tf.split(x, self.num_batch_splits)
74
74
  y_split = tf.split(y, self.num_batch_splits)
75
+ # Split the sample weight if it is provided.
76
+ if sample_weight is not None:
77
+ sample_weight_split = tf.split(
78
+ sample_weight, self.num_batch_splits
79
+ )
80
+ else:
81
+ sample_weight_split = [None] * self.num_batch_splits
75
82
  else:
76
83
  x_split = [x]
77
84
  y_split = [y]
85
+ sample_weight_split = [sample_weight]
78
86
 
79
87
  gradients_all_batches = []
80
88
  pred_all_batches = []
81
- for x_batch, y_batch in zip(x_split, y_split):
89
+ for x_batch, y_batch, sample_weight_batch in zip(
90
+ x_split, y_split, sample_weight_split
91
+ ):
82
92
  epsilon_w_cache = []
83
93
  with tf.GradientTape() as tape:
84
- pred = self.model(x_batch)
85
- loss = self.compiled_loss(y_batch, pred)
94
+ pred = self(x_batch)
95
+ loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
86
96
  pred_all_batches.append(pred)
87
97
  trainable_variables = self.model.trainable_variables
88
98
  gradients = tape.gradient(loss, trainable_variables)
@@ -98,8 +108,8 @@ class SharpnessAwareMinimization(Model):
98
108
  epsilon_w_cache.append(epsilon_w)
99
109
 
100
110
  with tf.GradientTape() as tape:
101
- pred = self(x_batch)
102
- loss = self.compiled_loss(y_batch, pred)
111
+ pred = self(x_batch, training=True)
112
+ loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
103
113
  gradients = tape.gradient(loss, trainable_variables)
104
114
  if len(gradients_all_batches) == 0:
105
115
  for gradient in gradients:
@@ -127,7 +137,7 @@ class SharpnessAwareMinimization(Model):
127
137
  self.compiled_metrics.update_state(y, pred, sample_weight)
128
138
  return {m.name: m.result() for m in self.metrics}
129
139
 
130
- def call(self, inputs):
140
+ def call(self, inputs, **kwargs):
131
141
  """Forward pass of SAM.
132
142
 
133
143
  SAM delegates the forward pass call to the wrapped model.
@@ -138,7 +148,7 @@ class SharpnessAwareMinimization(Model):
138
148
  Returns:
139
149
  A Tensor, the outputs of the wrapped model for given `inputs`.
140
150
  """
141
- return self.model(inputs)
151
+ return self.model(inputs, **kwargs)
142
152
 
143
153
  def get_config(self):
144
154
  config = super().get_config()