tf-models-nightly 2.17.0.dev20240606__py2.py3-none-any.whl → 2.17.0.dev20240608__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/nlp/modeling/layers/block_sparse_attention.py +216 -0
- official/nlp/modeling/layers/block_sparse_attention_test.py +303 -0
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/RECORD +8 -6
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/top_level.txt +0 -0
official/nlp/modeling/layers/block_sparse_attention.py (new file)
@@ -0,0 +1,216 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Block sparse attention converts query/key/value into blocks and performs diagonal block sparse attention."""
+import collections
+
+import tensorflow as tf, tf_keras
+
+
+class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
+  """Multi-head block sparse attention layer."""
+
+  def __init__(self, src_block_size=None, tgt_block_size=None, **kwargs):
+    """Initializes the block sparse attention layer.
+
+    Args:
+      src_block_size: The block size of the query. An integer that divides the
+        sequence length into blocks.
+      tgt_block_size: The block size of the key/value. An integer that divides
+        the sequence length into blocks. The number of blocks in the source and
+        target must be the same.
+      **kwargs: Args passed to the base class.
+    """
+    super().__init__(**kwargs)
+    if src_block_size is None or src_block_size <= 0:
+      raise ValueError("src_block_size must be specified.")
+    self._src_block_size = src_block_size
+    self._tgt_block_size = tgt_block_size or self._src_block_size
+
+  def _build_from_signature(self, query, value, key=None):
+    # pytype: disable=attribute-error
+    super()._build_from_signature(query, value, key)
+    # pytype: enable=attribute-error
+    # The following capital letters are used to denote the tensor dimension
+    # parameters:
+    # B = batch size
+    # S = length of the key/value (target)
+    # D = model dimension.
+    # T = length of the query (source)
+    # t = block size of the source.
+    # s = block size of the target.
+    # L = number of blocks in the source/target.
+    # N = number of attention heads
+    # H = dimensions of each attention head.
+    with tf.init_scope():
+      proj_einsum_eqn = "BTD,DNH->BNTH"
+      bias_axes = "NH"
+      qk_output_shape = [
+          self._num_heads,
+          None,
+          self._key_dim,
+      ]
+      v_output_shape = [
+          self._num_heads,
+          None,
+          self._value_dim,
+      ]
+      self._query_dense = tf_keras.layers.EinsumDense(
+          proj_einsum_eqn,
+          output_shape=qk_output_shape,
+          bias_axes=bias_axes if self._use_bias else None,
+          name="query",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+      self._key_dense = tf_keras.layers.EinsumDense(
+          proj_einsum_eqn,
+          output_shape=qk_output_shape,
+          bias_axes=bias_axes if self._use_bias else None,
+          name="key",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+      self._value_dense = tf_keras.layers.EinsumDense(
+          proj_einsum_eqn,
+          output_shape=v_output_shape,
+          bias_axes=bias_axes if self._use_bias else None,
+          name="value",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+      self._dot_product_equation = "BNLsH,BNLtH->BNLts"
+      self._combine_equation = "BNLts,BNLsH->BNLtH"
+      if self._output_shape:
+        if not isinstance(self._output_shape, collections.abc.Sized):
+          output_shape = [self._output_shape]
+        else:
+          output_shape = self._output_shape
+      else:
+        output_shape = [self._query_shape[-1]]
+      output_shape = [None] + output_shape
+      self._output_dense = tf_keras.layers.EinsumDense(
+          "BNTH,DNH->BTD",
+          output_shape=output_shape,
+          bias_axes="D" if self._use_bias else None,
+          name="attention_output",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+
+  def _block_diagonal_mask(self, attention_mask, dtype=None):
+    """Converts the attention mask to block diagonal."""
+    # Uses the same key mask for the entire query sequence since softmax
+    # is applied only on the key axis.
+    attention_mask = tf.cast(attention_mask[:, 0, :], dtype=dtype)
+    tgt_num_blocks = self._key_shape[-2] // self._tgt_block_size
+    attention_mask = tf.reshape(
+        attention_mask,
+        [
+            -1,
+            tgt_num_blocks,
+            self._tgt_block_size,
+        ],
+    )
+    return tf.einsum("BLQ,BLK->BLQK", attention_mask, attention_mask)
+
+  def _masked_softmax(self, attention_scores, attention_mask=None):
+    # Normalize the attention scores to probabilities.
+    # `attention_scores` = [B, N, L, T, S]
+    if attention_mask is not None:
+      # `attention_mask` = [B, 1, L, T, S]
+      attention_mask = tf.expand_dims(attention_mask, axis=1)
+    return self._softmax(attention_scores, attention_mask)
+
+  def _compute_attention(
+      self, query, key, value, attention_mask=None, training=None
+  ):
+    # src_num_blocks and tgt_num_blocks are the number of blocks in the source
+    # and target. Care should be taken to ensure that the number of blocks in
+    # the source and target are the same.
+    if self._query_shape[-2] % self._src_block_size != 0:
+      raise ValueError(
+          "query_shape[-2] must be divisible by src_block_size."
+      )
+    if self._key_shape[-2] % self._tgt_block_size != 0:
+      raise ValueError(
+          "key_shape[-2] must be divisible by tgt_block_size."
+      )
+    src_num_blocks = self._query_shape[-2] // self._src_block_size
+    tgt_num_blocks = self._key_shape[-2] // self._tgt_block_size
+
+    if src_num_blocks != tgt_num_blocks:
+      raise ValueError(
+          "src_num_blocks must be equal to tgt_num_blocks."
+      )
+    # Convert the query/key/value into blocks to perform block diagonal
+    # attention.
+    query_blocks = tf.reshape(query, [
+        -1,
+        self._num_heads,
+        src_num_blocks,
+        self._src_block_size,
+        self._key_dim,
+    ])
+    key_blocks = tf.reshape(key, [
+        -1,
+        self._num_heads,
+        tgt_num_blocks,
+        self._tgt_block_size,
+        self._key_dim,
+    ])
+    value_blocks = tf.reshape(value, [
+        -1,
+        self._num_heads,
+        tgt_num_blocks,
+        self._tgt_block_size,
+        self._value_dim,
+    ])
+    if attention_mask is not None:
+      attention_mask = self._block_diagonal_mask(attention_mask, key.dtype)
+    # pytype: disable=attribute-error
+    attention_output, attention_scores = super()._compute_attention(
+        query_blocks,
+        key_blocks,
+        value_blocks,
+        attention_mask=attention_mask,
+        training=training,
+    )
+    # pytype: enable=attribute-error
+    # Reshape the attention output to the original shape.
+    attention_output = tf.reshape(attention_output, [
+        -1,
+        self._num_heads,
+        self._query_shape[1],
+        self._value_dim,
+    ])
+    return attention_output, attention_scores
+
+  def call(
+      self,
+      query,
+      value,
+      key=None,
+      attention_mask=None,
+      return_attention_scores=False,
+      training=None,
+      use_causal_mask=False,
+  ):
+    if use_causal_mask:
+      raise ValueError("use_causal_mask is not supported.")
+    return super().call(
+        query,
+        value,
+        key=key,
+        attention_mask=attention_mask,
+        return_attention_scores=return_attention_scores,
+        training=training,
+        use_causal_mask=use_causal_mask,
+    )
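As context for the new layer (not part of the diff itself), a minimal usage sketch follows. It assumes a build of tf-models-nightly that contains this change (2.17.0.dev20240608 or later); the shapes and parameter values are illustrative only.

# Minimal sketch: 4 query blocks of size 8 attend to 4 key/value blocks of
# size 4 (the number of source and target blocks must match).
import tensorflow as tf

from official.nlp.modeling.layers import block_sparse_attention

layer = block_sparse_attention.MultiHeadAttention(
    num_heads=4,
    key_dim=16,
    src_block_size=8,  # query length (32) must be divisible by this
    tgt_block_size=4,  # key/value length (16) must be divisible by this
)
query = tf.random.normal((2, 32, 64))  # [batch, query_len, model_dim]
value = tf.random.normal((2, 16, 64))  # [batch, kv_len, model_dim]
output = layer(query=query, value=value)
print(output.shape)  # (2, 32, 64): output dim defaults to the query's last dim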
official/nlp/modeling/layers/block_sparse_attention_test.py (new file)
@@ -0,0 +1,303 @@
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for block sparse attention layer."""
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.modeling.layers import block_sparse_attention
+
+
+class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ("key_value_same_proj", None, None, [40, 80]),
+      ("key_value_different_proj", 32, 60, [40, 60]),
+  )
+  def test_non_masked_attention(self, value_dim, output_shape, output_dims):
+    """Test that the attention layer can be created without a mask tensor."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=12,
+        key_dim=64,
+        value_dim=value_dim,
+        output_shape=output_shape,
+        src_block_size=10,
+        tgt_block_size=5,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    value = tf_keras.Input(shape=(20, 80))
+    output = test_layer(query=query, value=value)
+    self.assertEqual(output.shape.as_list(), [None] + output_dims)
+
+  def test_non_masked_self_attention(self):
+    """Test with one input (self-attenntion) and no mask tensor."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=12, key_dim=64, src_block_size=10, tgt_block_size=10
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output = test_layer(query, query)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+  @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
+  def test_masked_attention(self, use_bias):
+    """Test with a mask tensor."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=4, key_dim=2, use_bias=use_bias, src_block_size=2,
+        tgt_block_size=1,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Tests the layer with three inputs: Q, K, V.
+    key = tf_keras.Input(shape=(2, 8))
+    output = test_layer(
+        query, value=value, key=key, attention_mask=mask_tensor
+    )
+    model = tf_keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict(
+        [from_data, to_data, to_data, mask_data]
+    )
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data]
+    )
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    if use_bias:
+      self.assertLen(test_layer._query_dense.trainable_variables, 2)
+      self.assertLen(test_layer._output_dense.trainable_variables, 2)
+    else:
+      self.assertLen(test_layer._query_dense.trainable_variables, 1)
+      self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
+  def test_masked_attention_with_scores(self):
+    """Test with a mask tensor."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=4, key_dim=2, src_block_size=2, tgt_block_size=1,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Create a model containing attention scores.
+    output, scores = test_layer(
+        query=query,
+        value=value,
+        attention_mask=mask_tensor,
+        return_attention_scores=True,
+    )
+    model = tf_keras.Model([query, value, mask_tensor], [output, scores])
+    masked_output_data_score, masked_score = model.predict(
+        [from_data, to_data, mask_data]
+    )
+    unmasked_output_data_score, unmasked_score = model.predict(
+        [from_data, to_data, null_mask_data]
+    )
+    self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
+    self.assertAllClose(masked_output_data, masked_output_data_score)
+    self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
+    self.assertNotAllClose(masked_score, unmasked_score)
+
+  def test_initializer(self):
+    """Test with a specified initializer."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=12,
+        key_dim=64,
+        src_block_size=10,
+        kernel_initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02),
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output = test_layer(query, query)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+    # Make sure the sub layers have different kernel init value, and not
+    # reusing the initializers.
+    self.assertNotAllClose(
+        tf_keras.backend.eval(test_layer._query_dense.kernel),
+        tf_keras.backend.eval(test_layer._key_dense.kernel),
+    )
+    self.assertNotAllClose(
+        tf_keras.backend.eval(test_layer._query_dense.kernel),
+        tf_keras.backend.eval(test_layer._value_dense.kernel),
+    )
+    self.assertNotAllClose(
+        tf_keras.backend.eval(test_layer._query_dense.kernel),
+        tf_keras.backend.eval(test_layer._output_dense.kernel),
+    )
+
+  @parameterized.named_parameters(
+      ("bfloat16", tf.bfloat16),
+      ("float16", tf.float16),
+      ("float32", tf.float32),
+      ("float64", tf.float64),
+  )
+  def test_sublayer_dtypes(self, dtype):
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=12, key_dim=64, src_block_size=10, dtype=dtype
+    )
+
+    query = tf_keras.Input(shape=(40, 80), dtype=dtype)
+    # Build the layer
+    test_layer(query=query, value=query)
+
+    self.assertEqual(test_layer._query_dense.dtype, dtype)
+    self.assertEqual(test_layer._key_dense.dtype, dtype)
+    self.assertEqual(test_layer._value_dense.dtype, dtype)
+    self.assertEqual(test_layer._output_dense.dtype, dtype)
+
+  def test_dropout(self):
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, dropout=0.5, src_block_size=2, tgt_block_size=1,
+    )
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = tf_keras.backend.ones(shape=(32, 4, 8))
+    to_data = tf_keras.backend.ones(shape=(32, 2, 8))
+    train_out = test_layer(from_data, to_data, None, None, None, True)
+    test_out = test_layer(from_data, to_data, None, None, None, False)
+
+    # Output should be close when not in training mode,
+    # and should not be close when enabling dropout in training mode.
+    self.assertNotAllClose(
+        tf_keras.backend.eval(train_out), tf_keras.backend.eval(test_out)
+    )
+
+  def test_query_mask_progagation(self):
+    """Test automatic propagation of the query's mask."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=2,
+        key_dim=2,
+        src_block_size=2,
+        tgt_block_size=1,
+    )
+    self.assertTrue(test_layer.supports_masking)
+    query = tf.constant(
+        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
+    )
+    masked_query = tf_keras.layers.Embedding(4, 8, mask_zero=True)(query)
+    value = tf.random.normal((3, 3, 8))
+    output = test_layer(query=masked_query, value=value)
+    self.assertTrue(hasattr(output, "_keras_mask"))
+    self.assertAllEqual(masked_query._keras_mask, output._keras_mask)
+
+  def test_value_mask(self):
+    """Test that the value mask is taken into account."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=2,
+        key_dim=2,
+        src_block_size=2,
+        tgt_block_size=1,
+    )
+    query = tf.constant(
+        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
+    )
+    masked_query = tf_keras.layers.Embedding(4, 8, mask_zero=True)(query)
+    value = tf.constant([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
+    masked_value = tf_keras.layers.Embedding(6, 8, mask_zero=True)(value)
+    output = test_layer(
+        query=masked_query,
+        value=masked_value,
+    )
+    mask = tf.constant(
+        [[[True, True, False]] * 3 + [[False, False, False]] * 2]
+        + [[[True, False, False]] * 5]
+        + [[[True, True, True]] + [[False, False, False]] * 4]
+    )
+    del masked_query._keras_mask
+    del masked_value._keras_mask
+    output_with_manual_mask = test_layer(
+        query=masked_query, value=masked_value, attention_mask=mask
+    )
+    self.assertAllClose(output, output_with_manual_mask)
+
+  def test_masks_are_cast_to_bool(self):
+    """Test that the implicit and explicit masks are cast to bool."""
+    test_layer = block_sparse_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, src_block_size=2, tgt_block_size=1,
+    )
+    query = np.array(
+        [[1, 2, 3, 0, 0, 0], [3, 3, 1, 1, 2, 0], [1, 1, 0, 0, 0, 0]]
+    )
+    masked_query = tf_keras.layers.Embedding(4, 8, mask_zero=True)(query)
+    masked_query._keras_mask = tf.cast(masked_query._keras_mask, tf.float32)
+    value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])
+    masked_value = tf_keras.layers.Embedding(6, 8, mask_zero=True)(value)
+    masked_value._keras_mask = tf.cast(masked_value._keras_mask, tf.float32)
+    float_mask = tf.constant([[[1.0]]])
+    # if all works well, the following should not raise any exception:
+    _ = test_layer(
+        query=masked_query,
+        value=masked_value,
+        attention_mask=float_mask,
+    )
+
+
+if __name__ == "__main__":
+  tf.test.main()
{tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/RECORD
@@ -305,6 +305,8 @@ official/nlp/modeling/layers/bigbird_attention.py,sha256=dzutgRoQt2DFsYMpMILv_QF
 official/nlp/modeling/layers/bigbird_attention_test.py,sha256=cBYwK5k1rnykZ0gif-n7VaByLIoElA-N0_svCRKASoU,2206
 official/nlp/modeling/layers/block_diag_feedforward.py,sha256=FDEt-J_QjOxwar3eT5yjMs4hR41Ppke1zj7iswsZR4M,7243
 official/nlp/modeling/layers/block_diag_feedforward_test.py,sha256=wcg8In6FIOCxcKqe5rucftjJ_kUWTi9Ei7eEmlVCYpE,4181
+official/nlp/modeling/layers/block_sparse_attention.py,sha256=Vjy0JULOb9u6-EzD460kXCotsibqyD29imlmrb7aVSY,7580
+official/nlp/modeling/layers/block_sparse_attention_test.py,sha256=YF2_-I27INUFtu-WP7s7C1kpYmsobNIGOWM1iUvSD5Y,12041
 official/nlp/modeling/layers/cls_head.py,sha256=0X_gdjnAt6TZVrH_xkDcQCpwLuVz5Pb7d04wEVN_Kn8,16208
 official/nlp/modeling/layers/cls_head_test.py,sha256=01oMmiuyp1lDEXBYa9r3krn6BtH-QuSedGOca9LViEc,8888
 official/nlp/modeling/layers/factorized_embedding.py,sha256=4oFRYJbpoaSxqv8hTWY2JPGPllp-zhniz99IyRtlzV8,2902
@@ -1210,9 +1212,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
-tf_models_nightly-2.17.0.
+tf_models_nightly-2.17.0.dev20240608.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.17.0.dev20240608.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.17.0.dev20240608.dist-info/METADATA,sha256=7_DZ26Maeag9jKW02vKvFZuYmrffrDvIUMLFR0DVLFI,1432
+tf_models_nightly-2.17.0.dev20240608.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.17.0.dev20240608.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.17.0.dev20240608.dist-info/RECORD,,
(The removed dist-info entries are truncated to "tf_models_nightly-2.17.0." in the source diff and are shown as-is.)

{tf_models_nightly-2.17.0.dev20240606.dist-info → tf_models_nightly-2.17.0.dev20240608.dist-info}/AUTHORS, /LICENSE, /WHEEL, /top_level.txt: Files without changes