tf-keras-nightly 2.21.0.dev2025111310-py3-none-any.whl → 2.21.0.dev2025111510-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
 from tf_keras.src.engine.training import Model
 
 
-__version__ = "2.21.0.dev2025111310"
+__version__ = "2.21.0.dev2025111510"
tf_keras/src/layers/activation/softmax.py CHANGED
@@ -70,6 +70,8 @@ class Softmax(Layer):
     Args:
         axis: Integer, or list of Integers, axis along which the softmax
            normalization is applied.
+        robust_masking: Bool, if true will use a more robust implementation when
+            dealing with masks.
     Call arguments:
         inputs: The inputs, or logits to the softmax layer.
         mask: A boolean mask of the same shape as `inputs`. The mask
@@ -80,23 +82,34 @@ class Softmax(Layer):
         Softmaxed output with the same shape as `inputs`.
     """
 
-    def __init__(self, axis=-1, **kwargs):
+    def __init__(self, axis=-1, robust_masking=False, **kwargs):
         super().__init__(**kwargs)
         self.supports_masking = True
+        self.robust_masking = robust_masking
         self.axis = axis
 
     def call(self, inputs, mask=None):
         if mask is not None:
-            # Since mask is 1.0 for positions we want to keep and 0.0 for masked
-            # positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and -1e.9 for masked positions.
-            adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
-                _large_compatible_negative(inputs.dtype)
-            )
-
-            # Since we are adding it to the raw scores before the softmax, this
-            # is effectively the same as removing these entirely.
-            inputs += adder
+            if self.robust_masking:
+                # We keep the positions where the mask is True or > 0.5, and set
+                # the other (masked) positions to -1e.9.
+                if mask.dtype is not tf.bool:
+                    mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
+                inputs = tf.where(
+                    mask, inputs, _large_compatible_negative(inputs.dtype)
+                )
+            else:
+                # Since mask is 1.0 for positions we want to keep and 0.0 for
+                # masked positions, this operation will create a tensor which is
+                # 0.0 for positions we want to attend and -1e.9 for masked
+                # positions.
+                adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
+                    _large_compatible_negative(inputs.dtype)
+                )
+
+                # Since we are adding it to the raw scores before the softmax,
+                # this is effectively the same as removing these entirely.
+                inputs += adder
         if isinstance(self.axis, (tuple, list)):
             if len(self.axis) > 1:
                 return tf.exp(
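
For reference, the two masking strategies in call() above can be reproduced with plain TensorFlow ops. A minimal sketch, using -1e9 as a stand-in for _large_compatible_negative and made-up float32 logits:

import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
mask = tf.constant([[True, True, False, False]])
large_negative = -1e9  # stand-in for _large_compatible_negative(tf.float32)

# Legacy path: add a large negative offset to the masked positions.
adder = (1.0 - tf.cast(mask, logits.dtype)) * large_negative
legacy = tf.nn.softmax(logits + adder, axis=-1)

# robust_masking path: overwrite the masked scores instead of adding to them.
robust = tf.nn.softmax(tf.where(mask, logits, large_negative), axis=-1)

# For ordinary finite logits both yield ~[[0.269, 0.731, 0.0, 0.0]]; the
# tf.where variant avoids stacking another huge negative value onto scores
# that may already be extreme, which is what makes it the more robust option.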
@@ -109,6 +122,8 @@ class Softmax(Layer):
 
     def get_config(self):
         config = {"axis": self.axis}
+        if self.robust_masking:
+            config["robust_masking"] = True
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
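
A usage sketch for the new flag (not part of the diff): the robust_masking argument name comes from the __init__ signature above, while the tf_keras.layers.Softmax import path and the sample data are assumptions.

import tensorflow as tf
import tf_keras

logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
mask = tf.constant([[True, True, False, False]])

softmax = tf_keras.layers.Softmax(robust_masking=True)
print(softmax(logits, mask).numpy())  # masked positions come out as 0.0

# get_config() only records robust_masking when it is True, so configs saved
# by builds without the flag deserialize unchanged.
print(softmax.get_config())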
tf_keras/src/layers/attention/multi_head_attention.py CHANGED
@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        softmax_robust_masking: If true will use a more numerically robust
+            masking impl.
 
     Call arguments:
         query: Query `Tensor` of shape `(B, T, dim)`.
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        softmax_robust_masking=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
         self._activity_regularizer = regularizers.get(activity_regularizer)
         self._kernel_constraint = constraints.get(kernel_constraint)
         self._bias_constraint = constraints.get(bias_constraint)
+        self._softmax_robust_masking = softmax_robust_masking
         if attention_axes is not None and not isinstance(
             attention_axes, collections.abc.Sized
         ):
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
             "query_shape": self._query_shape,
             "key_shape": self._key_shape,
             "value_shape": self._value_shape,
+            "softmax_robust_masking": self._softmax_robust_masking,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
            )
        )
         self._softmax = activation.Softmax(
-            axis=norm_axes, dtype=self._dtype_policy
+            axis=norm_axes,
+            robust_masking=self._softmax_robust_masking,
+            dtype=self._dtype_policy,
         )
         self._dropout_layer = regularization.Dropout(
             rate=self._dropout, dtype=self._dtype_policy
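
A corresponding sketch for the MultiHeadAttention wiring above; only the softmax_robust_masking argument is taken from the diff, and the shapes, mask, and tf_keras.layers.MultiHeadAttention path are illustrative assumptions.

import tensorflow as tf
import tf_keras

mha = tf_keras.layers.MultiHeadAttention(
    num_heads=2,
    key_dim=8,
    softmax_robust_masking=True,  # forwarded to activation.Softmax(robust_masking=...)
)

query = tf.random.normal((1, 4, 16))
value = tf.random.normal((1, 6, 16))
# attention_mask uses True for positions to attend to, shape (batch, T_query, T_value).
attention_mask = tf.cast(tf.constant([[[1, 1, 1, 1, 0, 0]] * 4]), tf.bool)

output = mha(query, value, attention_mask=attention_mask)
print(output.shape)  # (1, 4, 16)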
tf_keras_nightly-2.21.0.dev2025111310.dist-info/METADATA → tf_keras_nightly-2.21.0.dev2025111510.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tf_keras-nightly
-Version: 2.21.0.dev2025111310
+Version: 2.21.0.dev2025111510
 Summary: Deep learning for humans.
 Home-page: https://keras.io/
 Download-URL: https://github.com/keras-team/tf-keras/tags
tf_keras_nightly-2.21.0.dev2025111310.dist-info/RECORD → tf_keras_nightly-2.21.0.dev2025111510.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
-tf_keras/__init__.py,sha256=eFirk4FsoBp1NWtnKRktl-K6FFqDZOEA6TaG5Su-dhk,911
+tf_keras/__init__.py,sha256=rLVd1GJJDlrAKzFenVo0OokgsLbWW7-3OnWYZHftXgM,911
 tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
 tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
 tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -314,13 +314,13 @@ tf_keras/src/layers/activation/elu.py,sha256=n-WAE6NjC9mbqcV7Kxgpt8tTbvwCQIGsoCV
 tf_keras/src/layers/activation/leaky_relu.py,sha256=cJmpwgg4KEu--iK9gFuJT7uEGpDArB8q-XNBmJfC7_U,2618
 tf_keras/src/layers/activation/prelu.py,sha256=D2yhneQrYQP6aHSK8nvnMKa1hIeuPZO_XCB2Cu9Cl4Y,4440
 tf_keras/src/layers/activation/relu.py,sha256=JklQuReRiR3huAGr3QRtuGL0URpdspDFzBNjZgv0HDw,4281
-tf_keras/src/layers/activation/softmax.py,sha256=G6MfTCogGTKwyP7b6ByxeIHFNQtUKgrZXB8MP9hNstQ,4105
+tf_keras/src/layers/activation/softmax.py,sha256=0g8uN5N8QDW8lj6nGabR-EBk58njbiNdhDzglv9rxXU,4861
 tf_keras/src/layers/activation/thresholded_relu.py,sha256=rQLn9cr-w6hVJET2mS7OIQ9diiUiqUrX4CysXKNYbmg,2503
 tf_keras/src/layers/attention/__init__.py,sha256=6HjPSyLhs_bf4erT65KyhSCHQF7WeWZe9YTH7iW6Nek,945
 tf_keras/src/layers/attention/additive_attention.py,sha256=jie0cAXJEjU4xXK_Ur1SrEL9RqDIIAPyaAkK8O71TEs,7485
 tf_keras/src/layers/attention/attention.py,sha256=TCnoOWAfh6i275TvudxyjosczBmL_zz9ByEUi-xXkAU,8682
 tf_keras/src/layers/attention/base_dense_attention.py,sha256=cEzBldjwQfuJfNZRimW5s-NqyENU2-lmqaNNxAGxhKw,10856
-tf_keras/src/layers/attention/multi_head_attention.py,sha256=05RC-2BSmCcBFtVY2loQPeiMYp8XArmbvovPl8kpiEA,30279
+tf_keras/src/layers/attention/multi_head_attention.py,sha256=FQX0YtXRy5kg8OlShA7cp2kfczzeWb9Oj3tbzkukLRw,30618
 tf_keras/src/layers/convolutional/__init__.py,sha256=U-4tja5JhSUva2G9uMmsZyZty2N2N9jT6EJRu5HAo-Y,3355
 tf_keras/src/layers/convolutional/base_conv.py,sha256=jvm4elEyIVSNfYZxh4inzQ1Q2CKS_f8VawvXMIJFSC4,17574
 tf_keras/src/layers/convolutional/base_depthwise_conv.py,sha256=SVgR2Y8dpeX4eDEF1e0UY0Mxh4A47eGHhJCQ1peGwNQ,9661
@@ -584,7 +584,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
 tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
 tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
 tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
-tf_keras_nightly-2.21.0.dev2025111310.dist-info/METADATA,sha256=FZioTXJ-tYTS3AzxSICXoXvuUFlTim7r_SMKT0oGOyU,1857
-tf_keras_nightly-2.21.0.dev2025111310.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tf_keras_nightly-2.21.0.dev2025111310.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
-tf_keras_nightly-2.21.0.dev2025111310.dist-info/RECORD,,
+tf_keras_nightly-2.21.0.dev2025111510.dist-info/METADATA,sha256=oS3tHc8RYnnpb-O3e1kWcj4TyAEdALVAe-HwJV9RN1k,1857
+tf_keras_nightly-2.21.0.dev2025111510.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tf_keras_nightly-2.21.0.dev2025111510.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
+tf_keras_nightly-2.21.0.dev2025111510.dist-info/RECORD,,