tf-models-nightly 2.20.0.dev20251028__py2.py3-none-any.whl → 2.20.0.dev20251104__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/nlp/modeling/layers/multi_query_attention.py +216 -3
- official/nlp/modeling/layers/multi_query_attention_test.py +200 -0
- official/nlp/modeling/layers/transformer_encoder_block.py +28 -1
- official/nlp/modeling/layers/transformer_encoder_block_test.py +22 -9
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/RECORD +10 -10
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/top_level.txt +0 -0
official/nlp/modeling/layers/multi_query_attention.py

@@ -18,10 +18,13 @@ Based on https://arxiv.org/pdf/1911.02150.pdf and
 https://arxiv.org/pdf/2305.13245.pdf.
 """
 
+import math
 import string
 from typing import Optional, Sequence, Union
 
+import gin
 import tensorflow as tf, tf_keras
+from official.modeling import tf_utils
 
 _CHR_IDX = string.ascii_lowercase
 
@@ -78,7 +81,9 @@ def _get_output_shape(
 class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
   """Multi-query attention layer."""
 
-  def __init__(
+  def __init__(
+      self, num_kv_heads=None, enable_gqa_optimization=False, **kwargs
+  ):
     # num_kv_heads defines the number of key/value heads. A value of 1 means
     # that the key/value heads are shared across all query heads. Any other
     # value must be less than num_heads and must divide num_heads exactly. If
@@ -86,6 +91,16 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
     # num_kv_heads.
     super().__init__(**kwargs)
     self._num_kv_heads = num_kv_heads or self._num_heads
+    # TODO(akandoor): Remove this flag once the GQA optimization is rolled out.
+    # This flag is used to enable order of K,G in the einsum equations.
+    # This optimization is only used in GQA, and is disabled by default.
+    # If enabled, the einsum equations are:
+    # 1. Dot product: "...SKH,...TKGH->...KGTS"
+    # 2. Combine: "...KGTS,...SKH->...TKGH"
+    # If disabled, the einsum equations are:
+    # 1. Dot product: "...SKH,...TKnH->...nKTS"
+    # 2. Combine: "...nKTS,...SKH->...TnKH"
+    self._enable_gqa_optimization = enable_gqa_optimization
     assert (
         self._num_kv_heads < self._num_heads
     ), "num_kv_heads must be less than num_heads."
@@ -143,8 +158,12 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
       output_dims = 2
       key_last_dims = [self._num_kv_heads, self._key_dim]
       value_last_dims = [self._num_kv_heads, self._value_dim]
-      self.
-
+      if self._enable_gqa_optimization:
+        self._dot_product_equation = "...SKH,...TKGH->...KGTS"
+        self._combine_equation = "...KGTS,...SKH->...TKGH"
+      else:
+        self._dot_product_equation = "...SKH,...TKnH->...nKTS"
+        self._combine_equation = "...nKTS,...SKH->...TnKH"
 
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         free_dims=self._key_shape.rank - 1,
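
For readers tracing the new flag: the two equation sets differ only in whether the kv-head axis (K) or the group axis (G, written `n` in the legacy equations) comes first in the score tensor. A minimal sketch, not part of the package, checking that equivalence on illustrative shapes:

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes: batch, target len, source len, kv heads, groups, head dim.
B, T, S, K, G, H = 2, 5, 7, 2, 3, 4
key = tf.random.normal([B, S, K, H])
query = tf.random.normal([B, T, K, G, H])  # num_heads = K * G query heads

# Default ordering: group axis first in the scores.
scores_plain = tf.einsum("...SKH,...TKnH->...nKTS", key, query)  # [B, G, K, T, S]
# enable_gqa_optimization=True ordering: kv-head axis first.
scores_gqa = tf.einsum("...SKH,...TKGH->...KGTS", key, query)    # [B, K, G, T, S]

# Same numbers, with the K and G axes swapped.
np.testing.assert_allclose(
    tf.transpose(scores_plain, [0, 2, 1, 3, 4]).numpy(),
    scores_gqa.numpy(),
    rtol=1e-5,
)
```
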
@@ -170,6 +189,9 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
         name="value",
         **self._get_common_kwargs_for_sublayer(),
     )
+    self._qkv_rank = (
+        output_rank if self._num_kv_heads > 1 else output_rank + 1
+    )
 
   def _compute_attention(
       self, query, key, value, attention_mask=None, training=None
@@ -211,3 +233,194 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
         ],
     )
     return attention_output, attention_scores
+
+
+@tf_keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class TalkingHeadsMultiQueryAttention(MultiHeadAttention):
+  """Implements Talking-Heads Attention combined with Multi-Query Attention.
+
+  See https://arxiv.org/pdf/2003.02436 for more details.
+  TODO(akandoor): Make num talking heads configurable. Currently, num talking
+  heads is fixed to num query heads.
+
+  This class inherits from MultiQueryAttention to get the MQA-specific
+  logic for __init__, get_config.
+
+  It then overrides _build_from_signature to add the talking-heads weights
+  and overrides _compute_attention to merge the MQA wrapper (for
+  reshaping) with the THA computation (for pre/post-softmax projections).
+  """
+
+  def _build_from_signature(
+      self,
+      query: Union[tf.Tensor, tf.TensorShape],
+      value: Union[tf.Tensor, tf.TensorShape],
+      key: Optional[Union[tf.Tensor, tf.TensorShape]] = None,
+  ):
+    """Builds layers and variables."""
+    # Call the parent (MultiQueryAttention) _build_from_signature.
+    super()._build_from_signature(query=query, value=value, key=key)
+    # Now, *after* all MQA setup is done, we add the THA setup logic.
+    qkv_rank = self._qkv_rank
+    # TalkingHeadsAttention logic to the MQA build logic.
+    num_batch_dims = qkv_rank - len(self._attention_axes) - 2
+    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
+    scores_notation = _CHR_IDX[:attn_scores_rank]
+    projection_notation = scores_notation[num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank])
+    projected_scores_notation = scores_notation[:num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
+    self._talking_heads_equation = "%s,%s->%s" % (
+        scores_notation, projection_notation, projected_scores_notation)
+
+    with tf.init_scope():
+      self._pre_softmax_weight = self.add_weight(
+          "pre_softmax_weight",
+          shape=(self._num_heads, self._num_heads),
+          initializer=tf_utils.clone_initializer(self._kernel_initializer),
+          regularizer=self._kernel_regularizer,
+          constraint=self._kernel_constraint,
+          dtype=self.dtype,
+          trainable=True)
+      self._post_softmax_weight = self.add_weight(
+          "post_softmax_weight",
+          shape=(self._num_heads, self._num_heads),
+          initializer=tf_utils.clone_initializer(self._kernel_initializer),
+          regularizer=self._kernel_regularizer,
+          constraint=self._kernel_constraint,
+          dtype=self.dtype,
+          trainable=True)
+
+  def _compute_attention(
+      self, query, key, value, attention_mask=None, training=None
+  ):
+    """Applies Dot-product attention, merging MQA wrapper and THA computation.
+
+    Args:
+      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
+      key: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
+      value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).
+
+    Returns:
+      attention_output: Multi-headed outputs of attention computation.
+      attention_scores: Multi-headed attention weights.
+    """
+    # This is the MQA "wrapper" logic for grouped queries
+    query_shape = tf.shape(query)
+    if self._num_kv_heads > 1:
+      query = tf.reshape(
+          query,
+          [
+              query_shape[0],
+              query_shape[1],
+              self._num_kv_heads,
+              self._num_heads // self._num_kv_heads,
+              query_shape[-1],
+          ],
+      )
+
+    # This is the THA "computation" logic
+    # Note: Applying scalar multiply at the smaller end of einsum improves
+    # XLA performance, but may introduce slight numeric differences in
+    # the Transformer attention head.
+    query = tf.multiply(
+        query, 1.0 / math.sqrt(float(self._key_dim))
+    )
+
+    # Note: self._dot_product_equation was set by _build_from_signature
+    # (from MQA) to be MQA-compatible.
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
+
+    # --- Talking-Heads modification for MQA ---
+    # The THA _talking_heads_equation expects scores of shape [B, N, T, S].
+    # The MQA _dot_product_equation produces [B, K, G, T, S].
+    # We must reshape before and after applying TH logic.
+    scores_shape = tf.shape(attention_scores)
+    if self._num_kv_heads > 1:
+      # Reshape from [B, K, G, T, S] to [B, N, T, S]
+      attention_scores = tf.reshape(
+          attention_scores,
+          [
+              scores_shape[0],  # Batch
+              self._num_heads,  # N = K * G
+              scores_shape[-2],  # T
+              scores_shape[-1]  # S
+          ]
+      )
+
+    # Apply linear projection before softmax
+    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+                                 self._pre_softmax_weight)
+
+    # Normalize the attention scores to probabilities.
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+    # Apply linear projection after softmax
+    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+                                 self._post_softmax_weight)
+
+    # Reshape back to MQA-compatible shape [B, K, G, T, S]
+    # before the final combine_equation
+    if self._num_kv_heads > 1:
+      if self._enable_gqa_optimization:
+        attention_scores = tf.reshape(
+            attention_scores,
+            [
+                scores_shape[0],  # B
+                self._num_kv_heads,  # K
+                self._num_heads // self._num_kv_heads,  # G
+                scores_shape[-2],  # T
+                scores_shape[-1]  # S
+            ]
+        )
+      else:
+        attention_scores = tf.reshape(
+            attention_scores,
+            [
+                scores_shape[0],  # B
+                self._num_heads // self._num_kv_heads,  # G
+                self._num_kv_heads,  # K
+                scores_shape[-2],  # T
+                scores_shape[-1]  # S
+            ]
+        )
+
+    # This is actually dropping out entire tokens to attend to.
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)
+
+    # Note: self._combine_equation was set by _build_from_signature
+    # (from MQA) to be MQA-compatible.
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_scores_dropout, value)
+
+    # This is the MQA "wrapper" logic for grouped queries
+    if self._num_kv_heads > 1:
+      attention_output = tf.reshape(
+          attention_output,
+          [
+              query_shape[0],
+              query_shape[1],
+              self._num_heads,
+              tf.shape(attention_output)[-1],
+          ],
+      )
+      # We also need to reshape the final scores back to [B, N, T, S]
+      # for the return value.
+      attention_scores = tf.reshape(
+          attention_scores,
+          [
+              query_shape[0],
+              self._num_heads,
+              tf.shape(attention_scores)[-2],
+              tf.shape(attention_scores)[-1],
+          ],
+      )
+
+    return attention_output, attention_scores
+
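
A hedged usage sketch of the new layer (shapes and argument values are illustrative, chosen to mirror the tests below; they are not taken from the diff):

```python
import tensorflow as tf
from official.nlp.modeling.layers import multi_query_attention

layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
    num_heads=8,     # query heads; also the size of the talking-heads projections
    num_kv_heads=2,  # 1 -> multi-query, >1 -> grouped-query (must divide num_heads)
    key_dim=16,
)

query = tf.random.normal([2, 10, 64])  # [batch, target_len, hidden]
value = tf.random.normal([2, 6, 64])   # [batch, source_len, hidden]
out, scores = layer(query=query, value=value, return_attention_scores=True)

print(out.shape)     # (2, 10, 64)
print(scores.shape)  # (2, 8, 10, 6): [batch, num_heads, target_len, source_len]

# For this rank-4 case, _build_from_signature derives the talking-heads einsum
# "abcd,be->aecd": scores [B, N, T, S] x weight [N, N] -> projected [B, N, T, S].
print(layer._talking_heads_equation)  # inspecting an internal, for illustration only
```
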

official/nlp/modeling/layers/multi_query_attention_test.py

@@ -211,5 +211,205 @@ class MultiQueryAttentionTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose(masked_score, unmasked_score)
 
 
+class TalkingHeadsMultiQueryAttentionTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+
+  @parameterized.named_parameters(
+      ("key_value_same_proj_mqa", 1, None, None, [40, 80]),
+      ("key_value_different_proj_mqa", 1, 32, 60, [40, 60]),
+      ("key_value_same_proj_gqa", 3, None, None, [40, 80]),
+      ("key_value_different_proj_gqa", 3, 32, 60, [40, 60]),
+  )
+  def test_non_masked_attention(
+      self, num_kv_heads, value_dim, output_shape, output_dims
+  ):
+    """Test that the attention layer can be created without a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=64,
+        value_dim=value_dim,
+        output_shape=output_shape,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    value = tf_keras.Input(shape=(20, 80))
+    output = test_layer(query=query, value=value)
+    self.assertEqual(output.shape.as_list(), [None] + output_dims)
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_non_masked_self_attention(self, num_kv_heads):
+    """Test with one input (self-attenntion) and no mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=64,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output = test_layer(query, query)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_attention_scores(self, num_kv_heads):
+    """Test attention outputs with coefficients."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12, num_kv_heads=num_kv_heads, key_dim=64
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output, coef = test_layer(query, query, return_attention_scores=True)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_attention_scores_with_values(self, num_kv_heads):
+    """Test attention outputs with coefficients."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12, num_kv_heads=num_kv_heads, key_dim=64
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    value = tf_keras.Input(shape=(60, 80))
+    output, coef = test_layer(query, value, return_attention_scores=True)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])
+
+  @parameterized.named_parameters(
+      ("with_bias_mqa", 1, True),
+      ("no_bias_mqa", 1, False),
+      ("with_bias_gqa", 2, True),
+      ("no_bias_gqa", 2, False),
+  )
+  def test_masked_attention(self, num_kv_heads, use_bias):
+    """Test with a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=4,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=2,
+        use_bias=use_bias,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Tests the layer with three inputs: Q, K, V.
+    key = tf_keras.Input(shape=(2, 8))
+    output = test_layer(
+        query, value=value, key=key, attention_mask=mask_tensor
+    )
+    model = tf_keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict(
+        [from_data, to_data, to_data, mask_data]
+    )
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data]
+    )
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    if use_bias:
+      self.assertLen(test_layer._query_dense.trainable_variables, 2)
+      self.assertLen(test_layer._output_dense.trainable_variables, 2)
+    else:
+      self.assertLen(test_layer._query_dense.trainable_variables, 1)
+      self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 2),
+  )
+  def test_masked_attention_with_scores(self, num_kv_heads):
+    """Test with a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=4, num_kv_heads=num_kv_heads, key_dim=2
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Create a model containing attention scores.
+    output, scores = test_layer(
+        query=query,
+        value=value,
+        attention_mask=mask_tensor,
+        return_attention_scores=True,
+    )
+    model = tf_keras.Model([query, value, mask_tensor], [output, scores])
+    masked_output_data_score, masked_score = model.predict(
+        [from_data, to_data, mask_data]
+    )
+    unmasked_output_data_score, unmasked_score = model.predict(
+        [from_data, to_data, null_mask_data]
+    )
+    self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
+    self.assertAllClose(masked_output_data, masked_output_data_score)
+    self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
+    self.assertNotAllClose(masked_score, unmasked_score)
+
+
 if __name__ == "__main__":
   tf.test.main()

official/nlp/modeling/layers/transformer_encoder_block.py

@@ -20,6 +20,7 @@ import tensorflow as tf, tf_keras
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import block_sparse_attention
 from official.nlp.modeling.layers import multi_query_attention
+from official.nlp.modeling.layers import talking_heads_attention
 from official.nlp.modeling.layers import util
 
 
@@ -118,6 +119,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
       linformer_dim=None,
       linformer_shared_kv_projection=True,
       lowrank_query_seq_proj_dim=None,
+      enable_talking_heads=False,
+      enable_gqa_optimization=False,
       **kwargs,
   ):
     """Initializes `TransformerEncoderBlock`.
@@ -202,6 +205,10 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         keys and values.
       lowrank_query_seq_proj_dim: If set, applies a projection layer on query
         sequence to the given dimension. go/constformer-doc
+      enable_talking_heads: Enable talking heads as in
+        https://arxiv.org/pdf/2003.02436.
+      enable_gqa_optimization: Enable GQA optimization in multi-query attention.
+        This flag is valid only when num_kv_heads is set for GQA.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -244,6 +251,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._linformer_dim = linformer_dim
     self._linformer_shared_kv_projection = linformer_shared_kv_projection
     self._lowrank_query_seq_proj_dim = lowrank_query_seq_proj_dim
+    self._enable_talking_heads = enable_talking_heads
+    self._enable_gqa_optimization = enable_gqa_optimization
     if (
         self._src_block_size is not None
         and self._num_kv_heads is not None
@@ -314,6 +323,11 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         bias_constraint=self._bias_constraint,
     )
     if self._src_block_size is not None:
+      if self._enable_talking_heads:
+        raise ValueError(
+            "Block sparse attention does not support talking heads. Please"
+            " set enable_talking_heads to False."
+        )
       attention_layer_kwargs.update(
           src_block_size=self._src_block_size,
           tgt_block_size=self._tgt_block_size,
@@ -326,9 +340,22 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     elif self._num_kv_heads is not None:
       attention_layer_kwargs.update(
           num_kv_heads=self._num_kv_heads,
+          enable_gqa_optimization=self._enable_gqa_optimization,
           name="multi_query_attention",
       )
-
+      if self._enable_talking_heads:
+        attention_fn = (
+            multi_query_attention.TalkingHeadsMultiQueryAttention
+        )
+      else:
+        attention_fn = multi_query_attention.MultiHeadAttention
+    elif self._enable_talking_heads:
+      attention_layer_kwargs.update(
+          name="talking_heads_attention",
+      )
+      attention_fn = (
+          talking_heads_attention.TalkingHeadsAttention
+      )
     else:
       attention_fn = tf_keras.layers.MultiHeadAttention
     self._attention_layer = attention_fn(
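
A hedged configuration sketch for the two new encoder-block flags (values are illustrative and mirror the updated test below; nothing here is prescribed by the diff):

```python
import tensorflow as tf
from official.nlp.modeling.layers import transformer_encoder_block

block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    num_kv_heads=4,                # grouped-query attention; divides num_attention_heads
    enable_gqa_optimization=True,  # K-before-G einsum ordering; meaningful only with num_kv_heads
    enable_talking_heads=True,     # selects TalkingHeadsMultiQueryAttention above
)

x = tf.random.normal([2, 21, 80])  # [batch, seq_length, width]
y = block(x)
print(y.shape)  # (2, 21, 80)
```
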

official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -743,35 +743,46 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         num_attention_heads=num_attention_heads,
         inner_dim=2048,
         inner_activation='relu',
-        return_attention_scores=return_attention_scores
+        return_attention_scores=return_attention_scores,
+    )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
     output_tensor = test_layer(data_tensor)
 
     expected_layer_output_shape = [None, sequence_length, width]
     expected_attention_scores_shape = [
-        None,
+        None,
+        num_attention_heads,
+        sequence_length,
+        sequence_length,
     ]
 
     if return_attention_scores:
       self.assertIsInstance(output_tensor, tuple)
       self.assertLen(output_tensor, 2)
       # First is the standard output.
-      self.assertEqual(
-
+      self.assertEqual(
+          output_tensor[0].shape.as_list(), expected_layer_output_shape
+      )
       # Second is the attention scores.
-      self.assertEqual(
-
+      self.assertEqual(
+          output_tensor[1].shape.as_list(), expected_attention_scores_shape
+      )
     else:
       # Only the standard layer output.
-      self.assertEqual(
-
+      self.assertEqual(
+          output_tensor.shape.as_list(), expected_layer_output_shape
+      )
 
   @parameterized.named_parameters(
       ('mqa', 1),
       ('gqa', 4),
+      ('talking_heads_mqa', 1, True),
+      ('talking_heads_gqa', 4, True),
   )
-  def test_attention_with_kv_heads(
+  def test_attention_with_kv_heads(
+      self, num_kv_heads, enable_talking_heads=False
+  ):
     num_attention_heads = 8
     sequence_length = 21
     width = 80
@@ -782,6 +793,8 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         inner_activation='relu',
         return_attention_scores=True,
         num_kv_heads=num_kv_heads,
+        enable_talking_heads=enable_talking_heads,
+        enable_gqa_optimization=True,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))

{tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251104.dist-info}/RECORD

@@ -331,8 +331,8 @@ official/nlp/modeling/layers/moe.py,sha256=vG5jNwX2Vt3GjWDZ_viOCfgI77antXDAnBtho
 official/nlp/modeling/layers/moe_test.py,sha256=NmLln8M820Z4VX9BYFmBppiv8_NKoTWWzNGZ-g6eEBQ,9414
 official/nlp/modeling/layers/multi_channel_attention.py,sha256=_9py6IXTJpBM3lhK0Xh4zsmrO7S8xbQJ4CR5PcKaPk4,7322
 official/nlp/modeling/layers/multi_channel_attention_test.py,sha256=X-kLm-w_jbITZgGNSKAjEPOtjH2xJABI-mavBn2jbtA,1922
-official/nlp/modeling/layers/multi_query_attention.py,sha256=
-official/nlp/modeling/layers/multi_query_attention_test.py,sha256=
+official/nlp/modeling/layers/multi_query_attention.py,sha256=NvFBSiGZMdQnikRMSkG3kKX5e7nP2qvUt0nYFTajMqc,15384
+official/nlp/modeling/layers/multi_query_attention_test.py,sha256=2KLIynj9rXOcuvMxQ7s1JWo0B_C31RhBrsxGLX4ijg8,16455
 official/nlp/modeling/layers/on_device_embedding.py,sha256=RN01ud3FeInON_bXVtzItfTLivolbu45R6ajPp1_BIo,4582
 official/nlp/modeling/layers/on_device_embedding_test.py,sha256=i9esQYblo_uC-UnSkRs8Xgr6BldZ48iKoZhN4vrSTXI,8589
 official/nlp/modeling/layers/pack_optimization.py,sha256=0CoWos2h4dfqO3UvI2xLpHnWzvJpDiBJVHR2MTamLM4,10760
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=QWq1dJqQUPe5n69K3
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=omzTkCBEk2TOkHEYDEBwve6WsOitX7IIJHzeKXdqDq0,11012
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=pSCONEZRI4J9_6QLTJ3g_ynUYLrRXsJ1c2YMSiOV_60,8893
 official/nlp/modeling/layers/transformer.py,sha256=VjUO-gVj_PnavbT_vSrg5NDKMr0SRSiqSg5ktd42m5M,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=BiL8ErBs-m0UZ6ONVJV0ncfWX3LhMhPetIhfH2VvuD4,29910
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=g7oMDPvwg6Fv75SBdm6BInXPI8r5GcItBRjLFGuObyg,37821
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=qmzhCJvbbFVF9zDqnfO4Zs2JDXwKhK7iEBOhsU6-KpQ,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=dRJwesTBKm-mF5mDHrHfVpVNnxa-Wx-fj_4ZHDPTpE0,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=-pk9cdz9UlMpCIkGRkCKsMmjdRGi0seySaaB_2dwmXw,5522
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=yiAneltAW3NHSj3fUSvHNBjfq0MGZ
 tensorflow_models/nlp/__init__.py,sha256=8uQd4wI6Zc4IJMPjtQifMeWVbPFkTxqYh66wfivCOL4,807
 tensorflow_models/uplift/__init__.py,sha256=NzaweFf4ZmhRb2l_fuV6bP-2N8oSO3xu6xJqVb1UmpY,999
 tensorflow_models/vision/__init__.py,sha256=ks420Ooqzi0hU7HnQpM5rylLaE-YcJdJkBx_umVaXlE,833
-tf_models_nightly-2.20.0.
-tf_models_nightly-2.20.0.
-tf_models_nightly-2.20.0.
-tf_models_nightly-2.20.0.
-tf_models_nightly-2.20.0.
-tf_models_nightly-2.20.0.
+tf_models_nightly-2.20.0.dev20251104.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.20.0.dev20251104.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.20.0.dev20251104.dist-info/METADATA,sha256=5-g6lSvAXrAENi0P1IZeyQdFJ7JABiG_2O0sfvyZhLw,1432
+tf_models_nightly-2.20.0.dev20251104.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.20.0.dev20251104.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.20.0.dev20251104.dist-info/RECORD,,

The remaining dist-info files (AUTHORS, LICENSE, WHEEL, top_level.txt) are unchanged.