tf-models-nightly 2.20.0.dev20251028__py2.py3-none-any.whl → 2.20.0.dev20251029__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tf-models-nightly might be problematic. Click here for more details.
- official/nlp/modeling/layers/multi_query_attention.py +216 -3
- official/nlp/modeling/layers/multi_query_attention_test.py +200 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/RECORD +8 -8
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.20.0.dev20251028.dist-info → tf_models_nightly-2.20.0.dev20251029.dist-info}/top_level.txt +0 -0
|
@@ -18,10 +18,13 @@ Based on https://arxiv.org/pdf/1911.02150.pdf and
|
|
|
18
18
|
https://arxiv.org/pdf/2305.13245.pdf.
|
|
19
19
|
"""
|
|
20
20
|
|
|
21
|
+
import math
|
|
21
22
|
import string
|
|
22
23
|
from typing import Optional, Sequence, Union
|
|
23
24
|
|
|
25
|
+
import gin
|
|
24
26
|
import tensorflow as tf, tf_keras
|
|
27
|
+
from official.modeling import tf_utils
|
|
25
28
|
|
|
26
29
|
_CHR_IDX = string.ascii_lowercase
|
|
27
30
|
|
|
@@ -78,7 +81,9 @@ def _get_output_shape(
|
|
|
78
81
|
class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
|
|
79
82
|
"""Multi-query attention layer."""
|
|
80
83
|
|
|
81
|
-
def __init__(
|
|
84
|
+
def __init__(
|
|
85
|
+
self, num_kv_heads=None, enable_gqa_optimization=False, **kwargs
|
|
86
|
+
):
|
|
82
87
|
# num_kv_heads defines the number of key/value heads. A value of 1 means
|
|
83
88
|
# that the key/value heads are shared across all query heads. Any other
|
|
84
89
|
# value must be less than num_heads and must divide num_heads exactly. If
|
|
@@ -86,6 +91,16 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
|
|
|
86
91
|
# num_kv_heads.
|
|
87
92
|
super().__init__(**kwargs)
|
|
88
93
|
self._num_kv_heads = num_kv_heads or self._num_heads
|
|
94
|
+
# TODO(akandoor): Remove this flag once the GQA optimization is rolled out.
|
|
95
|
+
# This flag is used to enable order of K,G in the einsum equations.
|
|
96
|
+
# This optimization is only used in GQA, and is disabled by default.
|
|
97
|
+
# If enabled, the einsum equations are:
|
|
98
|
+
# 1. Dot product: "...SKH,...TKGH->...KGTS"
|
|
99
|
+
# 2. Combine: "...KGTS,...SKH->...TKGH"
|
|
100
|
+
# If disabled, the einsum equations are:
|
|
101
|
+
# 1. Dot product: "...SKH,...TKnH->...nKTS"
|
|
102
|
+
# 2. Combine: "...nKTS,...SKH->...TnKH"
|
|
103
|
+
self._enable_gqa_optimization = enable_gqa_optimization
|
|
89
104
|
assert (
|
|
90
105
|
self._num_kv_heads < self._num_heads
|
|
91
106
|
), "num_kv_heads must be less than num_heads."
|
|
@@ -143,8 +158,12 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
|
|
|
143
158
|
output_dims = 2
|
|
144
159
|
key_last_dims = [self._num_kv_heads, self._key_dim]
|
|
145
160
|
value_last_dims = [self._num_kv_heads, self._value_dim]
|
|
146
|
-
self.
|
|
147
|
-
|
|
161
|
+
if self._enable_gqa_optimization:
|
|
162
|
+
self._dot_product_equation = "...SKH,...TKGH->...KGTS"
|
|
163
|
+
self._combine_equation = "...KGTS,...SKH->...TKGH"
|
|
164
|
+
else:
|
|
165
|
+
self._dot_product_equation = "...SKH,...TKnH->...nKTS"
|
|
166
|
+
self._combine_equation = "...nKTS,...SKH->...TnKH"
|
|
148
167
|
|
|
149
168
|
einsum_equation, bias_axes, output_rank = _build_proj_equation(
|
|
150
169
|
free_dims=self._key_shape.rank - 1,
|
|
@@ -170,6 +189,9 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
|
|
|
170
189
|
name="value",
|
|
171
190
|
**self._get_common_kwargs_for_sublayer(),
|
|
172
191
|
)
|
|
192
|
+
self._qkv_rank = (
|
|
193
|
+
output_rank if self._num_kv_heads > 1 else output_rank + 1
|
|
194
|
+
)
|
|
173
195
|
|
|
174
196
|
def _compute_attention(
|
|
175
197
|
self, query, key, value, attention_mask=None, training=None
|
|
@@ -211,3 +233,194 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
|
|
|
211
233
|
],
|
|
212
234
|
)
|
|
213
235
|
return attention_output, attention_scores
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@tf_keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsMultiQueryAttention(MultiHeadAttention):
  """Talking-Heads Attention combined with multi-query / grouped-query attention.

  See https://arxiv.org/pdf/2003.02436 for details on talking-heads attention.

  This class inherits from this module's `MultiHeadAttention` (the
  multi-query attention layer) to reuse its MQA/GQA-specific `__init__`
  arguments (`num_kv_heads`, `enable_gqa_optimization`) and its projection
  and einsum-equation setup. On top of that it:
    * extends `_build_from_signature` to create the talking-heads
      pre-softmax and post-softmax `[num_heads, num_heads]` projection
      weights, and
    * overrides `_compute_attention` to combine the grouped-query reshapes
      with the talking-heads projections applied around the softmax.

  TODO(akandoor): Make the number of talking heads configurable. Currently it
  is fixed to the number of query heads.
  """

  def _build_from_signature(
      self,
      query: Union[tf.Tensor, tf.TensorShape],
      value: Union[tf.Tensor, tf.TensorShape],
      key: Optional[Union[tf.Tensor, tf.TensorShape]] = None,
  ):
    """Builds the MQA/GQA sublayers, then the talking-heads weights.

    Args:
      query: Query tensor or its shape.
      value: Value tensor or its shape.
      key: Optional key tensor or its shape.

    Raises:
      ValueError: If `num_kv_heads > 1` and the inputs do not have exactly one
        batch dimension and one attention axis. `_compute_attention` reshapes
        the grouped heads assuming rank-4 `[B, T, N, H]` projected queries, so
        any other layout would otherwise fail later inside `tf.reshape`.
    """
    # The parent builds the projections and picks the einsum equations first.
    # The setup below reads `_qkv_rank` and `_attention_axes`, which are only
    # available after that call.
    super()._build_from_signature(query=query, value=value, key=key)

    qkv_rank = self._qkv_rank
    num_batch_dims = qkv_rank - len(self._attention_axes) - 2
    if self._num_kv_heads > 1 and (
        num_batch_dims != 1 or len(self._attention_axes) != 1
    ):
      raise ValueError(
          "TalkingHeadsMultiQueryAttention with num_kv_heads > 1 supports "
          "exactly one batch dimension and one attention axis; got "
          f"{num_batch_dims} batch dimension(s) and "
          f"attention_axes={self._attention_axes}."
      )

    # Build an einsum that mixes the heads axis of the attention scores with a
    # [num_heads, num_heads] matrix. The heads axis sits right after the batch
    # dims, and a fresh letter names the projected heads axis.
    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
    scores_notation = _CHR_IDX[:attn_scores_rank]
    new_heads_char = _CHR_IDX[attn_scores_rank]
    projection_notation = scores_notation[num_batch_dims] + new_heads_char
    projected_scores_notation = (
        scores_notation[:num_batch_dims]
        + new_heads_char
        + scores_notation[num_batch_dims + 1:]
    )
    self._talking_heads_equation = (
        f"{scores_notation},{projection_notation}->{projected_scores_notation}"
    )

    with tf.init_scope():
      self._pre_softmax_weight = self.add_weight(
          "pre_softmax_weight",
          shape=(self._num_heads, self._num_heads),
          initializer=tf_utils.clone_initializer(self._kernel_initializer),
          regularizer=self._kernel_regularizer,
          constraint=self._kernel_constraint,
          dtype=self.dtype,
          trainable=True)
      self._post_softmax_weight = self.add_weight(
          "post_softmax_weight",
          shape=(self._num_heads, self._num_heads),
          initializer=tf_utils.clone_initializer(self._kernel_initializer),
          regularizer=self._kernel_regularizer,
          constraint=self._kernel_constraint,
          dtype=self.dtype,
          trainable=True)

  def _compute_attention(
      self, query, key, value, attention_mask=None, training=None
  ):
    """Applies dot-product attention with talking-heads projections.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor`. For GQA (`num_kv_heads > 1`) this is
        `[B, S, K, key_dim]`. For MQA it presumably has no head axis,
        `[B, S, key_dim]` (TODO: confirm against the parent layer).
      value: Projected value `Tensor`, laid out like `key` but with
        `value_dim` as the last axis.
      attention_mask: Optional boolean mask of shape `[B, T, S]` that
        prevents attention to certain positions.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed attention output. For GQA it is
        explicitly reshaped to `[B, T, N, value_dim]`.
      attention_scores: Attention weights of shape `[B, N, T, S]`, taken
        after softmax and both talking-heads projections and before dropout.
        With GQA and `enable_gqa_optimization=False`, the head axis is
        ordered group-major (index `g * K + k`).
    """
    query_shape = tf.shape(query)
    is_grouped = self._num_kv_heads > 1
    num_groups = self._num_heads // self._num_kv_heads

    if is_grouped:
      # [B, T, N, H] -> [B, T, K, G, H], so each group of query heads lines
      # up with its shared key/value head.
      query = tf.reshape(
          query,
          [
              query_shape[0],
              query_shape[1],
              self._num_kv_heads,
              num_groups,
              query_shape[-1],
          ],
      )

    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # `_dot_product_equation` was chosen by the parent's
    # `_build_from_signature` to match the MQA/GQA layout.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # The talking-heads equation expects scores of shape [B, N, T, S]. For
    # GQA the dot product yields [B, K, G, T, S] (optimized) or
    # [B, G, K, T, S] (default), so the two head axes are flattened around the
    # projections and restored afterwards.
    scores_shape = tf.shape(attention_scores)
    if is_grouped:
      attention_scores = tf.reshape(
          attention_scores,
          [scores_shape[0], self._num_heads, scores_shape[-2],
           scores_shape[-1]],
      )

    # Linear projection across heads before the softmax.
    attention_scores = tf.einsum(
        self._talking_heads_equation, attention_scores,
        self._pre_softmax_weight)

    # Normalize the attention scores to probabilities.
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # Linear projection across heads after the softmax.
    attention_scores = tf.einsum(
        self._talking_heads_equation, attention_scores,
        self._post_softmax_weight)

    if is_grouped:
      # Restore the grouped layout that `_combine_equation` consumes. This
      # must match the order produced by `_dot_product_equation` above.
      if self._enable_gqa_optimization:
        grouped_dims = [self._num_kv_heads, num_groups]  # [K, G]
      else:
        grouped_dims = [num_groups, self._num_kv_heads]  # [G, K]
      attention_scores = tf.reshape(
          attention_scores,
          [scores_shape[0]] + grouped_dims
          + [scores_shape[-2], scores_shape[-1]],
      )

    # This is actually dropping out entire tokens to attend to.
    attention_scores_dropout = self._dropout_layer(
        attention_scores, training=training)

    # `_combine_equation` was also chosen by the parent to be MQA-compatible.
    attention_output = tf.einsum(
        self._combine_equation, attention_scores_dropout, value)

    if is_grouped:
      # Merge the grouped head axes back into a single N axis for both the
      # output and the returned scores.
      attention_output = tf.reshape(
          attention_output,
          [
              query_shape[0],
              query_shape[1],
              self._num_heads,
              tf.shape(attention_output)[-1],
          ],
      )
      attention_scores = tf.reshape(
          attention_scores,
          [
              query_shape[0],
              self._num_heads,
              tf.shape(attention_scores)[-2],
              tf.shape(attention_scores)[-1],
          ],
      )

    return attention_output, attention_scores
|
|
426
|
+
|
|
@@ -211,5 +211,205 @@ class MultiQueryAttentionTest(tf.test.TestCase, parameterized.TestCase):
|
|
|
211
211
|
self.assertNotAllClose(masked_score, unmasked_score)
|
|
212
212
|
|
|
213
213
|
|
|
214
|
+
class TalkingHeadsMultiQueryAttentionTest(
    tf.test.TestCase, parameterized.TestCase
):
  """Tests for `multi_query_attention.TalkingHeadsMultiQueryAttention`."""

  @parameterized.named_parameters(
      ("key_value_same_proj_mqa", 1, None, None, [40, 80]),
      ("key_value_different_proj_mqa", 1, 32, 60, [40, 60]),
      ("key_value_same_proj_gqa", 3, None, None, [40, 80]),
      ("key_value_different_proj_gqa", 3, 32, 60, [40, 60]),
  )
  def test_non_masked_attention(
      self, num_kv_heads, value_dim, output_shape, output_dims
  ):
    """Test that the attention layer can be created without a mask tensor."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=12,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=num_kv_heads > 1,
        key_dim=64,
        value_dim=value_dim,
        output_shape=output_shape,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf_keras.Input(shape=(40, 80))
    value = tf_keras.Input(shape=(20, 80))
    output = test_layer(query=query, value=value)
    self.assertEqual(output.shape.as_list(), [None] + output_dims)

  @parameterized.named_parameters(
      ("_mqa", 1),
      ("_gqa", 3),
  )
  def test_non_masked_self_attention(self, num_kv_heads):
    """Test with one input (self-attention) and no mask tensor."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=12,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=num_kv_heads > 1,
        key_dim=64,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf_keras.Input(shape=(40, 80))
    output = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

  @parameterized.named_parameters(
      ("_mqa", 1),
      ("_gqa", 3),
      ("_gqa_optimized", 3, True),
  )
  def test_attention_scores(self, num_kv_heads, enable_gqa_optimization=False):
    """Test self-attention output and returned attention-score shapes."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=12,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=enable_gqa_optimization,
        key_dim=64,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf_keras.Input(shape=(40, 80))
    output, coef = test_layer(query, query, return_attention_scores=True)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])

  @parameterized.named_parameters(
      ("_mqa", 1),
      ("_gqa", 3),
      ("_gqa_optimized", 3, True),
  )
  def test_attention_scores_with_values(
      self, num_kv_heads, enable_gqa_optimization=False
  ):
    """Test cross-attention shapes when value has a different length."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=12,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=enable_gqa_optimization,
        key_dim=64,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    query = tf_keras.Input(shape=(40, 80))
    value = tf_keras.Input(shape=(60, 80))
    output, coef = test_layer(query, value, return_attention_scores=True)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])
    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])

  @parameterized.named_parameters(
      ("with_bias_mqa", 1, True),
      ("no_bias_mqa", 1, False),
      ("with_bias_gqa", 2, True),
      ("no_bias_gqa", 2, False),
  )
  def test_masked_attention(self, num_kv_heads, use_bias):
    """Test that a mask changes the output, with and without a key input."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=4,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=num_kv_heads > 1,
        key_dim=2,
        use_bias=use_bias,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    batch_size = 3
    query = tf_keras.Input(shape=(4, 8))
    value = tf_keras.Input(shape=(2, 8))
    mask_tensor = tf_keras.Input(shape=(4, 2))
    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

    # Create a model containing the test layer.
    model = tf_keras.Model([query, value, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

    # Invoke the data with a random set of mask data. Force at least one
    # masked position so the comparison below is deterministic.
    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
    mask_data[0, 0, 0] = 0
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are
    # masked).
    null_mask_data = np.ones((batch_size, 4, 2))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be
    # the same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    # Tests the layer with three inputs: Q, K, V.
    key = tf_keras.Input(shape=(2, 8))
    output = test_layer(
        query, value=value, key=key, attention_mask=mask_tensor
    )
    model = tf_keras.Model([query, value, key, mask_tensor], output)

    masked_output_data = model.predict(
        [from_data, to_data, to_data, mask_data]
    )
    unmasked_output_data = model.predict(
        [from_data, to_data, to_data, null_mask_data]
    )
    # Because one data is masked and one is not, the outputs should not be
    # the same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    if use_bias:
      self.assertLen(test_layer._query_dense.trainable_variables, 2)
      self.assertLen(test_layer._output_dense.trainable_variables, 2)
    else:
      self.assertLen(test_layer._query_dense.trainable_variables, 1)
      self.assertLen(test_layer._output_dense.trainable_variables, 1)

  @parameterized.named_parameters(
      ("_mqa", 1),
      ("_gqa", 2),
      ("_gqa_optimized", 2, True),
  )
  def test_masked_attention_with_scores(
      self, num_kv_heads, enable_gqa_optimization=False
  ):
    """Test that a mask changes outputs and scores, which stay consistent."""
    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=4,
        num_kv_heads=num_kv_heads,
        enable_gqa_optimization=enable_gqa_optimization,
        key_dim=2,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    batch_size = 3
    query = tf_keras.Input(shape=(4, 8))
    value = tf_keras.Input(shape=(2, 8))
    mask_tensor = tf_keras.Input(shape=(4, 2))
    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

    # Create a model containing the test layer.
    model = tf_keras.Model([query, value, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
    to_data = 10 * np.random.random_sample((batch_size, 2, 8))

    # Invoke the data with a random set of mask data. Force at least one
    # masked position so the comparison below is deterministic.
    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
    mask_data[0, 0, 0] = 0
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are
    # masked).
    null_mask_data = np.ones((batch_size, 4, 2))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be
    # the same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    # Create a model containing attention scores.
    output, scores = test_layer(
        query=query,
        value=value,
        attention_mask=mask_tensor,
        return_attention_scores=True,
    )
    model = tf_keras.Model([query, value, mask_tensor], [output, scores])
    masked_output_data_score, masked_score = model.predict(
        [from_data, to_data, mask_data]
    )
    unmasked_output_data_score, unmasked_score = model.predict(
        [from_data, to_data, null_mask_data]
    )
    self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
    self.assertAllClose(masked_output_data, masked_output_data_score)
    self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
    self.assertNotAllClose(masked_score, unmasked_score)
|
|
412
|
+
|
|
413
|
+
|
|
214
414
|
if __name__ == "__main__":
|
|
215
415
|
tf.test.main()
|
|
@@ -331,8 +331,8 @@ official/nlp/modeling/layers/moe.py,sha256=vG5jNwX2Vt3GjWDZ_viOCfgI77antXDAnBtho
|
|
|
331
331
|
official/nlp/modeling/layers/moe_test.py,sha256=NmLln8M820Z4VX9BYFmBppiv8_NKoTWWzNGZ-g6eEBQ,9414
|
|
332
332
|
official/nlp/modeling/layers/multi_channel_attention.py,sha256=_9py6IXTJpBM3lhK0Xh4zsmrO7S8xbQJ4CR5PcKaPk4,7322
|
|
333
333
|
official/nlp/modeling/layers/multi_channel_attention_test.py,sha256=X-kLm-w_jbITZgGNSKAjEPOtjH2xJABI-mavBn2jbtA,1922
|
|
334
|
-
official/nlp/modeling/layers/multi_query_attention.py,sha256=
|
|
335
|
-
official/nlp/modeling/layers/multi_query_attention_test.py,sha256=
|
|
334
|
+
official/nlp/modeling/layers/multi_query_attention.py,sha256=NvFBSiGZMdQnikRMSkG3kKX5e7nP2qvUt0nYFTajMqc,15384
|
|
335
|
+
official/nlp/modeling/layers/multi_query_attention_test.py,sha256=2KLIynj9rXOcuvMxQ7s1JWo0B_C31RhBrsxGLX4ijg8,16455
|
|
336
336
|
official/nlp/modeling/layers/on_device_embedding.py,sha256=RN01ud3FeInON_bXVtzItfTLivolbu45R6ajPp1_BIo,4582
|
|
337
337
|
official/nlp/modeling/layers/on_device_embedding_test.py,sha256=i9esQYblo_uC-UnSkRs8Xgr6BldZ48iKoZhN4vrSTXI,8589
|
|
338
338
|
official/nlp/modeling/layers/pack_optimization.py,sha256=0CoWos2h4dfqO3UvI2xLpHnWzvJpDiBJVHR2MTamLM4,10760
|
|
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=yiAneltAW3NHSj3fUSvHNBjfq0MGZ
|
|
|
1248
1248
|
tensorflow_models/nlp/__init__.py,sha256=8uQd4wI6Zc4IJMPjtQifMeWVbPFkTxqYh66wfivCOL4,807
|
|
1249
1249
|
tensorflow_models/uplift/__init__.py,sha256=NzaweFf4ZmhRb2l_fuV6bP-2N8oSO3xu6xJqVb1UmpY,999
|
|
1250
1250
|
tensorflow_models/vision/__init__.py,sha256=ks420Ooqzi0hU7HnQpM5rylLaE-YcJdJkBx_umVaXlE,833
|
|
1251
|
-
tf_models_nightly-2.20.0.
|
|
1252
|
-
tf_models_nightly-2.20.0.
|
|
1253
|
-
tf_models_nightly-2.20.0.
|
|
1254
|
-
tf_models_nightly-2.20.0.
|
|
1255
|
-
tf_models_nightly-2.20.0.
|
|
1256
|
-
tf_models_nightly-2.20.0.
|
|
1251
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
|
|
1252
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
|
|
1253
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/METADATA,sha256=hTl-klu4MFPyEUTW4FX3TyByeXJs4Hp7BJpSjTWFaCw,1432
|
|
1254
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
|
|
1255
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
|
|
1256
|
+
tf_models_nightly-2.20.0.dev20251029.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|