tf-models-nightly 2.20.0.dev20251028__py2.py3-none-any.whl → 2.20.0.dev20251030__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release is flagged as potentially problematic.
--- a/official/nlp/modeling/layers/multi_query_attention.py
+++ b/official/nlp/modeling/layers/multi_query_attention.py
@@ -18,10 +18,13 @@ Based on https://arxiv.org/pdf/1911.02150.pdf and
 https://arxiv.org/pdf/2305.13245.pdf.
 """
 
+import math
 import string
 from typing import Optional, Sequence, Union
 
+import gin
 import tensorflow as tf, tf_keras
+from official.modeling import tf_utils
 
 _CHR_IDX = string.ascii_lowercase
 
@@ -78,7 +81,9 @@ def _get_output_shape(
 class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
   """Multi-query attention layer."""
 
-  def __init__(self, num_kv_heads=None, **kwargs):
+  def __init__(
+      self, num_kv_heads=None, enable_gqa_optimization=False, **kwargs
+  ):
     # num_kv_heads defines the number of key/value heads. A value of 1 means
     # that the key/value heads are shared across all query heads. Any other
     # value must be less than num_heads and must divide num_heads exactly. If
@@ -86,6 +91,16 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
     # num_kv_heads.
     super().__init__(**kwargs)
     self._num_kv_heads = num_kv_heads or self._num_heads
+    # TODO(akandoor): Remove this flag once the GQA optimization is rolled out.
+    # This flag selects the K,G axis order in the einsum equations. The
+    # optimization applies only to GQA and is disabled by default.
+    # If enabled, the einsum equations are:
+    # 1. Dot product: "...SKH,...TKGH->...KGTS"
+    # 2. Combine: "...KGTS,...SKH->...TKGH"
+    # If disabled, the einsum equations are:
+    # 1. Dot product: "...SKH,...TKnH->...nKTS"
+    # 2. Combine: "...nKTS,...SKH->...TnKH"
+    self._enable_gqa_optimization = enable_gqa_optimization
     assert (
         self._num_kv_heads < self._num_heads
     ), "num_kv_heads must be less than num_heads."
@@ -143,8 +158,12 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
       output_dims = 2
       key_last_dims = [self._num_kv_heads, self._key_dim]
       value_last_dims = [self._num_kv_heads, self._value_dim]
-      self._dot_product_equation = "...SKH,...TKnH->...nKTS"
-      self._combine_equation = "...nKTS,...SKH->...TnKH"
+      if self._enable_gqa_optimization:
+        self._dot_product_equation = "...SKH,...TKGH->...KGTS"
+        self._combine_equation = "...KGTS,...SKH->...TKGH"
+      else:
+        self._dot_product_equation = "...SKH,...TKnH->...nKTS"
+        self._combine_equation = "...nKTS,...SKH->...TnKH"
 
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         free_dims=self._key_shape.rank - 1,
@@ -170,6 +189,9 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
         name="value",
         **self._get_common_kwargs_for_sublayer(),
     )
+    self._qkv_rank = (
+        output_rank if self._num_kv_heads > 1 else output_rank + 1
+    )
 
   def _compute_attention(
       self, query, key, value, attention_mask=None, training=None
@@ -211,3 +233,194 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
           ],
       )
     return attention_output, attention_scores
+
+
+@tf_keras.utils.register_keras_serializable(package="Text")
+@gin.configurable
+class TalkingHeadsMultiQueryAttention(MultiHeadAttention):
+  """Implements Talking-Heads Attention combined with Multi-Query Attention.
+
+  See https://arxiv.org/pdf/2003.02436 for more details.
+  TODO(akandoor): Make num talking heads configurable. Currently, num talking
+  heads is fixed to num query heads.
+
+  This class inherits from MultiQueryAttention to get the MQA-specific
+  logic for __init__ and get_config.
+
+  It then overrides _build_from_signature to add the talking-heads weights
+  and overrides _compute_attention to merge the MQA wrapper (for
+  reshaping) with the THA computation (for pre/post-softmax projections).
+  """
+
+  def _build_from_signature(
+      self,
+      query: Union[tf.Tensor, tf.TensorShape],
+      value: Union[tf.Tensor, tf.TensorShape],
+      key: Optional[Union[tf.Tensor, tf.TensorShape]] = None,
+  ):
+    """Builds layers and variables."""
+    # Call the parent (MultiQueryAttention) _build_from_signature.
+    super()._build_from_signature(query=query, value=value, key=key)
+    # Now, *after* all MQA setup is done, add the THA setup logic.
+    qkv_rank = self._qkv_rank
+    # TalkingHeadsAttention logic added to the MQA build logic.
+    num_batch_dims = qkv_rank - len(self._attention_axes) - 2
+    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
+    scores_notation = _CHR_IDX[:attn_scores_rank]
+    projection_notation = scores_notation[num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank])
+    projected_scores_notation = scores_notation[:num_batch_dims] + (
+        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
+    self._talking_heads_equation = "%s,%s->%s" % (
+        scores_notation, projection_notation, projected_scores_notation)
+
+    with tf.init_scope():
+      self._pre_softmax_weight = self.add_weight(
+          "pre_softmax_weight",
+          shape=(self._num_heads, self._num_heads),
+          initializer=tf_utils.clone_initializer(self._kernel_initializer),
+          regularizer=self._kernel_regularizer,
+          constraint=self._kernel_constraint,
+          dtype=self.dtype,
+          trainable=True)
+      self._post_softmax_weight = self.add_weight(
+          "post_softmax_weight",
+          shape=(self._num_heads, self._num_heads),
+          initializer=tf_utils.clone_initializer(self._kernel_initializer),
+          regularizer=self._kernel_regularizer,
+          constraint=self._kernel_constraint,
+          dtype=self.dtype,
+          trainable=True)
+
+  def _compute_attention(
+      self, query, key, value, attention_mask=None, training=None
+  ):
+    """Applies dot-product attention, merging MQA wrapper and THA computation.
+
+    Args:
+      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
+      key: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
+      value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).
+
+    Returns:
+      attention_output: Multi-headed outputs of attention computation.
+      attention_scores: Multi-headed attention weights.
+    """
+    # This is the MQA "wrapper" logic for grouped queries.
+    query_shape = tf.shape(query)
+    if self._num_kv_heads > 1:
+      query = tf.reshape(
+          query,
+          [
+              query_shape[0],
+              query_shape[1],
+              self._num_kv_heads,
+              self._num_heads // self._num_kv_heads,
+              query_shape[-1],
+          ],
+      )
+
+    # This is the THA "computation" logic.
+    # Note: Applying scalar multiply at the smaller end of einsum improves
+    # XLA performance, but may introduce slight numeric differences in
+    # the Transformer attention head.
+    query = tf.multiply(
+        query, 1.0 / math.sqrt(float(self._key_dim))
+    )
+
+    # Note: self._dot_product_equation was set by _build_from_signature
+    # (from MQA) to be MQA-compatible.
+    attention_scores = tf.einsum(self._dot_product_equation, key, query)
+
+    # --- Talking-Heads modification for MQA ---
+    # The THA _talking_heads_equation expects scores of shape [B, N, T, S].
+    # The MQA _dot_product_equation produces [B, K, G, T, S].
+    # We must reshape before and after applying TH logic.
+    scores_shape = tf.shape(attention_scores)
+    if self._num_kv_heads > 1:
+      # Reshape from [B, K, G, T, S] to [B, N, T, S].
+      attention_scores = tf.reshape(
+          attention_scores,
+          [
+              scores_shape[0],   # Batch
+              self._num_heads,   # N = K * G
+              scores_shape[-2],  # T
+              scores_shape[-1],  # S
+          ],
+      )
+
+    # Apply linear projection before softmax.
+    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+                                 self._pre_softmax_weight)
+
+    # Normalize the attention scores to probabilities.
+    attention_scores = self._masked_softmax(attention_scores, attention_mask)
+
+    # Apply linear projection after softmax.
+    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
+                                 self._post_softmax_weight)
+
+    # Reshape back to the MQA-compatible shape [B, K, G, T, S]
+    # before the final combine_equation.
+    if self._num_kv_heads > 1:
+      if self._enable_gqa_optimization:
+        attention_scores = tf.reshape(
+            attention_scores,
+            [
+                scores_shape[0],                         # B
+                self._num_kv_heads,                      # K
+                self._num_heads // self._num_kv_heads,   # G
+                scores_shape[-2],                        # T
+                scores_shape[-1],                        # S
+            ],
+        )
+      else:
+        attention_scores = tf.reshape(
+            attention_scores,
+            [
+                scores_shape[0],                         # B
+                self._num_heads // self._num_kv_heads,   # G
+                self._num_kv_heads,                      # K
+                scores_shape[-2],                        # T
+                scores_shape[-1],                        # S
+            ],
+        )
+
+    # This is actually dropping out entire tokens to attend to.
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)
+
+    # Note: self._combine_equation was set by _build_from_signature
+    # (from MQA) to be MQA-compatible.
+    attention_output = tf.einsum(self._combine_equation,
+                                 attention_scores_dropout, value)
+
+    # This is the MQA "wrapper" logic for grouped queries.
+    if self._num_kv_heads > 1:
+      attention_output = tf.reshape(
+          attention_output,
+          [
+              query_shape[0],
+              query_shape[1],
+              self._num_heads,
+              tf.shape(attention_output)[-1],
+          ],
+      )
+      # We also need to reshape the final scores back to [B, N, T, S]
+      # for the return value.
+      attention_scores = tf.reshape(
+          attention_scores,
+          [
+              query_shape[0],
+              self._num_heads,
+              tf.shape(attention_scores)[-2],
+              tf.shape(attention_scores)[-1],
+          ],
+      )
+
+    return attention_output, attention_scores
+
--- a/official/nlp/modeling/layers/multi_query_attention_test.py
+++ b/official/nlp/modeling/layers/multi_query_attention_test.py
@@ -211,5 +211,205 @@ class MultiQueryAttentionTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotAllClose(masked_score, unmasked_score)
 
 
+class TalkingHeadsMultiQueryAttentionTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+
+  @parameterized.named_parameters(
+      ("key_value_same_proj_mqa", 1, None, None, [40, 80]),
+      ("key_value_different_proj_mqa", 1, 32, 60, [40, 60]),
+      ("key_value_same_proj_gqa", 3, None, None, [40, 80]),
+      ("key_value_different_proj_gqa", 3, 32, 60, [40, 60]),
+  )
+  def test_non_masked_attention(
+      self, num_kv_heads, value_dim, output_shape, output_dims
+  ):
+    """Test that the attention layer can be created without a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=64,
+        value_dim=value_dim,
+        output_shape=output_shape,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    value = tf_keras.Input(shape=(20, 80))
+    output = test_layer(query=query, value=value)
+    self.assertEqual(output.shape.as_list(), [None] + output_dims)
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_non_masked_self_attention(self, num_kv_heads):
+    """Test with one input (self-attention) and no mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=64,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output = test_layer(query, query)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_attention_scores(self, num_kv_heads):
+    """Test attention outputs with coefficients."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12, num_kv_heads=num_kv_heads, key_dim=64
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    output, coef = test_layer(query, query, return_attention_scores=True)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 3),
+  )
+  def test_attention_scores_with_values(self, num_kv_heads):
+    """Test attention outputs with coefficients."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=12, num_kv_heads=num_kv_heads, key_dim=64
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    query = tf_keras.Input(shape=(40, 80))
+    value = tf_keras.Input(shape=(60, 80))
+    output, coef = test_layer(query, value, return_attention_scores=True)
+    self.assertEqual(output.shape.as_list(), [None, 40, 80])
+    self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])
+
+  @parameterized.named_parameters(
+      ("with_bias_mqa", 1, True),
+      ("no_bias_mqa", 1, False),
+      ("with_bias_gqa", 2, True),
+      ("no_bias_gqa", 2, False),
+  )
+  def test_masked_attention(self, num_kv_heads, use_bias):
+    """Test with a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=4,
+        num_kv_heads=num_kv_heads,
+        enable_gqa_optimization=num_kv_heads > 1,
+        key_dim=2,
+        use_bias=use_bias,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Tests the layer with three inputs: Q, K, V.
+    key = tf_keras.Input(shape=(2, 8))
+    output = test_layer(
+        query, value=value, key=key, attention_mask=mask_tensor
+    )
+    model = tf_keras.Model([query, value, key, mask_tensor], output)
+
+    masked_output_data = model.predict(
+        [from_data, to_data, to_data, mask_data]
+    )
+    unmasked_output_data = model.predict(
+        [from_data, to_data, to_data, null_mask_data]
+    )
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    if use_bias:
+      self.assertLen(test_layer._query_dense.trainable_variables, 2)
+      self.assertLen(test_layer._output_dense.trainable_variables, 2)
+    else:
+      self.assertLen(test_layer._query_dense.trainable_variables, 1)
+      self.assertLen(test_layer._output_dense.trainable_variables, 1)
+
+  @parameterized.named_parameters(
+      ("_mqa", 1),
+      ("_gqa", 2),
+  )
+  def test_masked_attention_with_scores(self, num_kv_heads):
+    """Test with a mask tensor."""
+    test_layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
+        num_heads=4, num_kv_heads=num_kv_heads, key_dim=2
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    batch_size = 3
+    query = tf_keras.Input(shape=(4, 8))
+    value = tf_keras.Input(shape=(2, 8))
+    mask_tensor = tf_keras.Input(shape=(4, 2))
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)
+
+    # Create a model containing the test layer.
+    model = tf_keras.Model([query, value, mask_tensor], output)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((batch_size, 4, 8))
+    to_data = 10 * np.random.random_sample((batch_size, 2, 8))
+
+    # Invoke the data with a random set of mask data. This should mask at
+    # least one element.
+    mask_data = np.random.randint(2, size=(batch_size, 4, 2))
+    masked_output_data = model.predict([from_data, to_data, mask_data])
+
+    # Invoke the same data, but with a null mask (where no elements are
+    # masked).
+    null_mask_data = np.ones((batch_size, 4, 2))
+    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
+
+    # Because one data is masked and one is not, the outputs should not be
+    # the same.
+    self.assertNotAllClose(masked_output_data, unmasked_output_data)
+
+    # Create a model containing attention scores.
+    output, scores = test_layer(
+        query=query,
+        value=value,
+        attention_mask=mask_tensor,
+        return_attention_scores=True,
+    )
+    model = tf_keras.Model([query, value, mask_tensor], [output, scores])
+    masked_output_data_score, masked_score = model.predict(
+        [from_data, to_data, mask_data]
+    )
+    unmasked_output_data_score, unmasked_score = model.predict(
+        [from_data, to_data, null_mask_data]
+    )
+    self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
+    self.assertAllClose(masked_output_data, masked_output_data_score)
+    self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
+    self.assertNotAllClose(masked_score, unmasked_score)
+
+
 if __name__ == "__main__":
   tf.test.main()
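A minimal usage sketch mirroring the new tests above (assumes the updated tf-models-nightly is installed; the head counts and shapes are illustrative, not prescribed by the package):

    import numpy as np
    import tensorflow as tf
    from official.nlp.modeling.layers import multi_query_attention

    layer = multi_query_attention.TalkingHeadsMultiQueryAttention(
        num_heads=8, num_kv_heads=2, key_dim=16, enable_gqa_optimization=True)
    query = tf.constant(np.random.rand(3, 10, 64), dtype=tf.float32)
    value = tf.constant(np.random.rand(3, 6, 64), dtype=tf.float32)
    output, scores = layer(query, value, return_attention_scores=True)
    print(output.shape)   # (3, 10, 64)
    print(scores.shape)   # (3, 8, 10, 6), i.e. [B, num_heads, T, S]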
--- a/official/nlp/modeling/layers/transformer_encoder_block.py
+++ b/official/nlp/modeling/layers/transformer_encoder_block.py
@@ -20,6 +20,7 @@ import tensorflow as tf, tf_keras
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import block_sparse_attention
 from official.nlp.modeling.layers import multi_query_attention
+from official.nlp.modeling.layers import talking_heads_attention
 from official.nlp.modeling.layers import util
 
 
@@ -118,6 +119,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
       linformer_dim=None,
       linformer_shared_kv_projection=True,
       lowrank_query_seq_proj_dim=None,
+      enable_talking_heads=False,
+      enable_gqa_optimization=False,
       **kwargs,
   ):
     """Initializes `TransformerEncoderBlock`.
@@ -202,6 +205,10 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
        keys and values.
      lowrank_query_seq_proj_dim: If set, applies a projection layer on query
        sequence to the given dimension. go/constformer-doc
+      enable_talking_heads: Enable talking heads as in
+        https://arxiv.org/pdf/2003.02436.
+      enable_gqa_optimization: Enable GQA optimization in multi-query attention.
+        This flag is valid only when num_kv_heads is set for GQA.
      **kwargs: keyword arguments.
    """
    util.filter_kwargs(kwargs)
@@ -244,6 +251,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
    self._linformer_dim = linformer_dim
    self._linformer_shared_kv_projection = linformer_shared_kv_projection
    self._lowrank_query_seq_proj_dim = lowrank_query_seq_proj_dim
+    self._enable_talking_heads = enable_talking_heads
+    self._enable_gqa_optimization = enable_gqa_optimization
    if (
        self._src_block_size is not None
        and self._num_kv_heads is not None
@@ -314,6 +323,11 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         bias_constraint=self._bias_constraint,
     )
     if self._src_block_size is not None:
+      if self._enable_talking_heads:
+        raise ValueError(
+            "Block sparse attention does not support talking heads. Please"
+            " set enable_talking_heads to False."
+        )
       attention_layer_kwargs.update(
           src_block_size=self._src_block_size,
           tgt_block_size=self._tgt_block_size,
@@ -326,9 +340,22 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     elif self._num_kv_heads is not None:
       attention_layer_kwargs.update(
           num_kv_heads=self._num_kv_heads,
+          enable_gqa_optimization=self._enable_gqa_optimization,
           name="multi_query_attention",
       )
-      attention_fn = multi_query_attention.MultiHeadAttention
+      if self._enable_talking_heads:
+        attention_fn = (
+            multi_query_attention.TalkingHeadsMultiQueryAttention
+        )
+      else:
+        attention_fn = multi_query_attention.MultiHeadAttention
+    elif self._enable_talking_heads:
+      attention_layer_kwargs.update(
+          name="talking_heads_attention",
+      )
+      attention_fn = (
+          talking_heads_attention.TalkingHeadsAttention
+      )
     else:
       attention_fn = tf_keras.layers.MultiHeadAttention
     self._attention_layer = attention_fn(
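Per the dispatch logic above, setting num_kv_heads together with enable_talking_heads=True makes the block build the new TalkingHeadsMultiQueryAttention layer, while enable_talking_heads=True without num_kv_heads selects talking_heads_attention.TalkingHeadsAttention. A minimal construction sketch matching the updated test below (argument values are the test's, not defaults):

    from official.nlp.modeling.layers import transformer_encoder_block

    block = transformer_encoder_block.TransformerEncoderBlock(
        num_attention_heads=8,
        inner_dim=2048,
        inner_activation='relu',
        num_kv_heads=4,                 # grouped-query attention
        enable_talking_heads=True,      # builds TalkingHeadsMultiQueryAttention
        enable_gqa_optimization=True,
    )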
--- a/official/nlp/modeling/layers/transformer_encoder_block_test.py
+++ b/official/nlp/modeling/layers/transformer_encoder_block_test.py
@@ -743,35 +743,46 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         num_attention_heads=num_attention_heads,
         inner_dim=2048,
         inner_activation='relu',
-        return_attention_scores=return_attention_scores)
+        return_attention_scores=return_attention_scores,
+    )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
     output_tensor = test_layer(data_tensor)
 
     expected_layer_output_shape = [None, sequence_length, width]
     expected_attention_scores_shape = [
-        None, num_attention_heads, sequence_length, sequence_length
+        None,
+        num_attention_heads,
+        sequence_length,
+        sequence_length,
     ]
 
     if return_attention_scores:
       self.assertIsInstance(output_tensor, tuple)
       self.assertLen(output_tensor, 2)
       # First is the standard output.
-      self.assertEqual(output_tensor[0].shape.as_list(),
-                       expected_layer_output_shape)
+      self.assertEqual(
+          output_tensor[0].shape.as_list(), expected_layer_output_shape
+      )
       # Second is the attention scores.
-      self.assertEqual(output_tensor[1].shape.as_list(),
-                       expected_attention_scores_shape)
+      self.assertEqual(
+          output_tensor[1].shape.as_list(), expected_attention_scores_shape
+      )
     else:
       # Only the standard layer output.
-      self.assertEqual(output_tensor.shape.as_list(),
-                       expected_layer_output_shape)
+      self.assertEqual(
+          output_tensor.shape.as_list(), expected_layer_output_shape
+      )
 
   @parameterized.named_parameters(
       ('mqa', 1),
       ('gqa', 4),
+      ('talking_heads_mqa', 1, True),
+      ('talking_heads_gqa', 4, True),
   )
-  def test_attention_with_kv_heads(self, num_kv_heads):
+  def test_attention_with_kv_heads(
+      self, num_kv_heads, enable_talking_heads=False
+  ):
     num_attention_heads = 8
     sequence_length = 21
     width = 80
@@ -782,6 +793,8 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         inner_activation='relu',
         return_attention_scores=True,
         num_kv_heads=num_kv_heads,
+        enable_talking_heads=enable_talking_heads,
+        enable_gqa_optimization=True,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
--- a/tf_models_nightly-2.20.0.dev20251028.dist-info/METADATA
+++ b/tf_models_nightly-2.20.0.dev20251030.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.20.0.dev20251028
+Version: 2.20.0.dev20251030
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.
--- a/tf_models_nightly-2.20.0.dev20251028.dist-info/RECORD
+++ b/tf_models_nightly-2.20.0.dev20251030.dist-info/RECORD
@@ -331,8 +331,8 @@ official/nlp/modeling/layers/moe.py,sha256=vG5jNwX2Vt3GjWDZ_viOCfgI77antXDAnBtho
 official/nlp/modeling/layers/moe_test.py,sha256=NmLln8M820Z4VX9BYFmBppiv8_NKoTWWzNGZ-g6eEBQ,9414
 official/nlp/modeling/layers/multi_channel_attention.py,sha256=_9py6IXTJpBM3lhK0Xh4zsmrO7S8xbQJ4CR5PcKaPk4,7322
 official/nlp/modeling/layers/multi_channel_attention_test.py,sha256=X-kLm-w_jbITZgGNSKAjEPOtjH2xJABI-mavBn2jbtA,1922
-official/nlp/modeling/layers/multi_query_attention.py,sha256=Tsv1tg3GMcze9cpCpNddamL4dHWFQr4uyqO4n4cAJJk,7111
-official/nlp/modeling/layers/multi_query_attention_test.py,sha256=HplLhNbhr24r3wBD5Dzz7Y_VxylQ5DEa1prmXYmSN3g,8512
+official/nlp/modeling/layers/multi_query_attention.py,sha256=NvFBSiGZMdQnikRMSkG3kKX5e7nP2qvUt0nYFTajMqc,15384
+official/nlp/modeling/layers/multi_query_attention_test.py,sha256=2KLIynj9rXOcuvMxQ7s1JWo0B_C31RhBrsxGLX4ijg8,16455
 official/nlp/modeling/layers/on_device_embedding.py,sha256=RN01ud3FeInON_bXVtzItfTLivolbu45R6ajPp1_BIo,4582
 official/nlp/modeling/layers/on_device_embedding_test.py,sha256=i9esQYblo_uC-UnSkRs8Xgr6BldZ48iKoZhN4vrSTXI,8589
 official/nlp/modeling/layers/pack_optimization.py,sha256=0CoWos2h4dfqO3UvI2xLpHnWzvJpDiBJVHR2MTamLM4,10760
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=QWq1dJqQUPe5n69K3
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=omzTkCBEk2TOkHEYDEBwve6WsOitX7IIJHzeKXdqDq0,11012
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=pSCONEZRI4J9_6QLTJ3g_ynUYLrRXsJ1c2YMSiOV_60,8893
 official/nlp/modeling/layers/transformer.py,sha256=VjUO-gVj_PnavbT_vSrg5NDKMr0SRSiqSg5ktd42m5M,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=5GJgtK1mdTxMDYVWfUoBAI_GvjDL0zO9AWtKCovSZiU,28789
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=7yBgv1UNmfOFre6txF_Rq93RLc1TJwnJ7-Dz4p55sy4,37602
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=BiL8ErBs-m0UZ6ONVJV0ncfWX3LhMhPetIhfH2VvuD4,29910
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=g7oMDPvwg6Fv75SBdm6BInXPI8r5GcItBRjLFGuObyg,37821
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=qmzhCJvbbFVF9zDqnfO4Zs2JDXwKhK7iEBOhsU6-KpQ,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=dRJwesTBKm-mF5mDHrHfVpVNnxa-Wx-fj_4ZHDPTpE0,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=-pk9cdz9UlMpCIkGRkCKsMmjdRGi0seySaaB_2dwmXw,5522
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=yiAneltAW3NHSj3fUSvHNBjfq0MGZ
 tensorflow_models/nlp/__init__.py,sha256=8uQd4wI6Zc4IJMPjtQifMeWVbPFkTxqYh66wfivCOL4,807
 tensorflow_models/uplift/__init__.py,sha256=NzaweFf4ZmhRb2l_fuV6bP-2N8oSO3xu6xJqVb1UmpY,999
 tensorflow_models/vision/__init__.py,sha256=ks420Ooqzi0hU7HnQpM5rylLaE-YcJdJkBx_umVaXlE,833
-tf_models_nightly-2.20.0.dev20251028.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.20.0.dev20251028.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.20.0.dev20251028.dist-info/METADATA,sha256=uajPDCXLRHjjP9c1TuxbTh2QH1LBO84MLfydg3W5J50,1432
-tf_models_nightly-2.20.0.dev20251028.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.20.0.dev20251028.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.20.0.dev20251028.dist-info/RECORD,,
+tf_models_nightly-2.20.0.dev20251030.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.20.0.dev20251030.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.20.0.dev20251030.dist-info/METADATA,sha256=I2M0rXHA_2EU7FPchNNQCp9kEk7kXD5A35__qN9eB0E,1432
+tf_models_nightly-2.20.0.dev20251030.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.20.0.dev20251030.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.20.0.dev20251030.dist-info/RECORD,,