tf-keras-nightly 2.21.0.dev2025111410__py3-none-any.whl → 2.21.0.dev2025111610__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- {tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/METADATA +1 -1
- {tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/RECORD +7 -7
- {tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/WHEEL +0 -0
- {tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED

tf_keras/src/layers/activation/softmax.py
CHANGED
@@ -70,6 +70,8 @@ class Softmax(Layer):
     Args:
         axis: Integer, or list of Integers, axis along which the softmax
             normalization is applied.
+        robust_masking: Bool, if true will use a more robust implementation when
+            dealing with masks.
     Call arguments:
         inputs: The inputs, or logits to the softmax layer.
         mask: A boolean mask of the same shape as `inputs`. The mask
@@ -80,23 +82,34 @@ class Softmax(Layer):
         Softmaxed output with the same shape as `inputs`.
     """
 
-    def __init__(self, axis=-1, **kwargs):
+    def __init__(self, axis=-1, robust_masking=False, **kwargs):
         super().__init__(**kwargs)
         self.supports_masking = True
+        self.robust_masking = robust_masking
         self.axis = axis
 
     def call(self, inputs, mask=None):
         if mask is not None:
-            # Since mask is 1.0 for positions we want to keep and 0.0 for masked
-            # positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and -1e.9 for masked positions.
-            adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
-                _large_compatible_negative(inputs.dtype)
-            )
-
-            # Since we are adding it to the raw scores before the softmax, this
-            # is effectively the same as removing these entirely.
-            inputs += adder
+            if self.robust_masking:
+                # We keep the positions where the mask is True or > 0.5, and set
+                # the other (masked) positions to -1e.9.
+                if mask.dtype is not tf.bool:
+                    mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
+                inputs = tf.where(
+                    mask, inputs, _large_compatible_negative(inputs.dtype)
+                )
+            else:
+                # Since mask is 1.0 for positions we want to keep and 0.0 for
+                # masked positions, this operation will create a tensor which is
+                # 0.0 for positions we want to attend and -1e.9 for masked
+                # positions.
+                adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
+                    _large_compatible_negative(inputs.dtype)
+                )
+
+                # Since we are adding it to the raw scores before the softmax,
+                # this is effectively the same as removing these entirely.
+                inputs += adder
         if isinstance(self.axis, (tuple, list)):
             if len(self.axis) > 1:
                 return tf.exp(
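The two paths above differ when a masked logit is itself extreme: the additive -1e9 penalty can be cancelled by a large positive score (or overflow in float16), whereas tf.where overwrites the masked value unconditionally. A minimal standalone sketch of the distinction (not taken from the package; LARGE_NEGATIVE stands in for the module's _large_compatible_negative helper, assumed to return -1e9 for float32):

    import tensorflow as tf

    LARGE_NEGATIVE = -1e9  # stand-in for _large_compatible_negative(tf.float32)

    logits = tf.constant([1e9, 1.0, 2.0])  # extreme score at a masked position
    mask = tf.constant([False, True, True])

    # Additive path (robust_masking=False): the -1e9 penalty is cancelled by
    # the +1e9 logit, so the masked position leaks probability mass.
    adder = (1.0 - tf.cast(mask, logits.dtype)) * LARGE_NEGATIVE
    print(tf.nn.softmax(logits + adder).numpy())  # ~[0.09, 0.24, 0.67]

    # tf.where path (robust_masking=True): the logit is replaced outright, so
    # the masked position is suppressed regardless of its magnitude.
    masked = tf.where(mask, logits, LARGE_NEGATIVE)
    print(tf.nn.softmax(masked).numpy())  # ~[0.00, 0.27, 0.73]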
@@ -109,6 +122,8 @@ class Softmax(Layer):
 
     def get_config(self):
         config = {"axis": self.axis}
+        if self.robust_masking:
+            config["robust_masking"] = True
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
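Note that get_config writes robust_masking only when it is set, presumably so that configs serialized with the default stay identical to those from earlier builds. A quick check, assuming this nightly is installed:

    from tf_keras.layers import Softmax

    assert "robust_masking" not in Softmax(axis=-1).get_config()
    assert Softmax(axis=-1, robust_masking=True).get_config()["robust_masking"]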
tf_keras/src/layers/attention/multi_head_attention.py
CHANGED

@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        softmax_robust_masking: If true will use a more numerically robust
+            masking impl.
 
     Call arguments:
         query: Query `Tensor` of shape `(B, T, dim)`.
@@ -247,6 +249,7 @@ class MultiHeadAttention(Layer):
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        softmax_robust_masking=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -264,6 +267,7 @@ class MultiHeadAttention(Layer):
         self._activity_regularizer = regularizers.get(activity_regularizer)
         self._kernel_constraint = constraints.get(kernel_constraint)
         self._bias_constraint = constraints.get(bias_constraint)
+        self._softmax_robust_masking = softmax_robust_masking
         if attention_axes is not None and not isinstance(
             attention_axes, collections.abc.Sized
         ):
@@ -298,6 +302,7 @@ class MultiHeadAttention(Layer):
             "query_shape": self._query_shape,
             "key_shape": self._key_shape,
             "value_shape": self._value_shape,
+            "softmax_robust_masking": self._softmax_robust_masking,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -476,7 +481,9 @@ class MultiHeadAttention(Layer):
             )
         )
         self._softmax = activation.Softmax(
-            axis=norm_axes,
+            axis=norm_axes,
+            robust_masking=self._softmax_robust_masking,
+            dtype=self._dtype_policy,
         )
         self._dropout_layer = regularization.Dropout(
             rate=self._dropout, dtype=self._dtype_policy
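Taken together, these hunks thread a single constructor flag through to the attention softmax (and the build step now also passes the layer's dtype policy to that sublayer explicitly). A usage sketch, assuming this nightly build is installed; shapes are illustrative:

    import numpy as np
    import tf_keras as keras

    query = np.random.rand(2, 8, 16).astype("float32")  # (batch, target, dim)
    value = np.random.rand(2, 4, 16).astype("float32")  # (batch, source, dim)
    attention_mask = np.ones((2, 8, 4), dtype=bool)     # (batch, target, source)

    mha = keras.layers.MultiHeadAttention(
        num_heads=2,
        key_dim=16,
        softmax_robust_masking=True,  # new flag; defaults to False
    )
    print(mha(query, value, attention_mask=attention_mask).shape)  # (2, 8, 16)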
{tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-tf_keras/__init__.py,sha256=
+tf_keras/__init__.py,sha256=FYSNmTtnHJHq-r-PH9N0iF73Ckbmclt0R91HoEZaCG0,911
 tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
 tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
 tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -314,13 +314,13 @@ tf_keras/src/layers/activation/elu.py,sha256=n-WAE6NjC9mbqcV7Kxgpt8tTbvwCQIGsoCV
 tf_keras/src/layers/activation/leaky_relu.py,sha256=cJmpwgg4KEu--iK9gFuJT7uEGpDArB8q-XNBmJfC7_U,2618
 tf_keras/src/layers/activation/prelu.py,sha256=D2yhneQrYQP6aHSK8nvnMKa1hIeuPZO_XCB2Cu9Cl4Y,4440
 tf_keras/src/layers/activation/relu.py,sha256=JklQuReRiR3huAGr3QRtuGL0URpdspDFzBNjZgv0HDw,4281
-tf_keras/src/layers/activation/softmax.py,sha256=
+tf_keras/src/layers/activation/softmax.py,sha256=0g8uN5N8QDW8lj6nGabR-EBk58njbiNdhDzglv9rxXU,4861
 tf_keras/src/layers/activation/thresholded_relu.py,sha256=rQLn9cr-w6hVJET2mS7OIQ9diiUiqUrX4CysXKNYbmg,2503
 tf_keras/src/layers/attention/__init__.py,sha256=6HjPSyLhs_bf4erT65KyhSCHQF7WeWZe9YTH7iW6Nek,945
 tf_keras/src/layers/attention/additive_attention.py,sha256=jie0cAXJEjU4xXK_Ur1SrEL9RqDIIAPyaAkK8O71TEs,7485
 tf_keras/src/layers/attention/attention.py,sha256=TCnoOWAfh6i275TvudxyjosczBmL_zz9ByEUi-xXkAU,8682
 tf_keras/src/layers/attention/base_dense_attention.py,sha256=cEzBldjwQfuJfNZRimW5s-NqyENU2-lmqaNNxAGxhKw,10856
-tf_keras/src/layers/attention/multi_head_attention.py,sha256=
+tf_keras/src/layers/attention/multi_head_attention.py,sha256=FQX0YtXRy5kg8OlShA7cp2kfczzeWb9Oj3tbzkukLRw,30618
 tf_keras/src/layers/convolutional/__init__.py,sha256=U-4tja5JhSUva2G9uMmsZyZty2N2N9jT6EJRu5HAo-Y,3355
 tf_keras/src/layers/convolutional/base_conv.py,sha256=jvm4elEyIVSNfYZxh4inzQ1Q2CKS_f8VawvXMIJFSC4,17574
 tf_keras/src/layers/convolutional/base_depthwise_conv.py,sha256=SVgR2Y8dpeX4eDEF1e0UY0Mxh4A47eGHhJCQ1peGwNQ,9661
@@ -584,7 +584,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
 tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
 tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
 tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
-tf_keras_nightly-2.21.0.
-tf_keras_nightly-2.21.0.
-tf_keras_nightly-2.21.0.
-tf_keras_nightly-2.21.0.
+tf_keras_nightly-2.21.0.dev2025111610.dist-info/METADATA,sha256=SF87YdU9NoBNrX-R3R0DLg9bst7FP5Fz0jCx_RtbGas,1857
+tf_keras_nightly-2.21.0.dev2025111610.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tf_keras_nightly-2.21.0.dev2025111610.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
+tf_keras_nightly-2.21.0.dev2025111610.dist-info/RECORD,,

{tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/WHEEL
File without changes

{tf_keras_nightly-2.21.0.dev2025111410.dist-info → tf_keras_nightly-2.21.0.dev2025111610.dist-info}/top_level.txt
File without changes