tf-models-nightly 2.20.0.dev20251029__py2.py3-none-any.whl → 2.20.0.dev20251031__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

official/nlp/modeling/layers/transformer_encoder_block.py

@@ -20,6 +20,7 @@ import tensorflow as tf, tf_keras
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import block_sparse_attention
 from official.nlp.modeling.layers import multi_query_attention
+from official.nlp.modeling.layers import talking_heads_attention
 from official.nlp.modeling.layers import util
 
 
@@ -118,6 +119,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
       linformer_dim=None,
       linformer_shared_kv_projection=True,
       lowrank_query_seq_proj_dim=None,
+      enable_talking_heads=False,
+      enable_gqa_optimization=False,
       **kwargs,
   ):
     """Initializes `TransformerEncoderBlock`.
@@ -202,6 +205,10 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         keys and values.
       lowrank_query_seq_proj_dim: If set, applies a projection layer on query
         sequence to the given dimension. go/constformer-doc
+      enable_talking_heads: Enable talking heads as in
+        https://arxiv.org/pdf/2003.02436.
+      enable_gqa_optimization: Enable GQA optimization in multi-query attention.
+        This flag is valid only when num_kv_heads is set for GQA.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -244,6 +251,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._linformer_dim = linformer_dim
     self._linformer_shared_kv_projection = linformer_shared_kv_projection
     self._lowrank_query_seq_proj_dim = lowrank_query_seq_proj_dim
+    self._enable_talking_heads = enable_talking_heads
+    self._enable_gqa_optimization = enable_gqa_optimization
     if (
         self._src_block_size is not None
         and self._num_kv_heads is not None
@@ -314,6 +323,11 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         bias_constraint=self._bias_constraint,
     )
     if self._src_block_size is not None:
+      if self._enable_talking_heads:
+        raise ValueError(
+            "Block sparse attention does not support talking heads. Please"
+            " set enable_talking_heads to False."
+        )
       attention_layer_kwargs.update(
           src_block_size=self._src_block_size,
           tgt_block_size=self._tgt_block_size,
@@ -326,9 +340,22 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     elif self._num_kv_heads is not None:
       attention_layer_kwargs.update(
           num_kv_heads=self._num_kv_heads,
+          enable_gqa_optimization=self._enable_gqa_optimization,
           name="multi_query_attention",
       )
-      attention_fn = multi_query_attention.MultiHeadAttention
+      if self._enable_talking_heads:
+        attention_fn = (
+            multi_query_attention.TalkingHeadsMultiQueryAttention
+        )
+      else:
+        attention_fn = multi_query_attention.MultiHeadAttention
+    elif self._enable_talking_heads:
+      attention_layer_kwargs.update(
+          name="talking_heads_attention",
+      )
+      attention_fn = (
+          talking_heads_attention.TalkingHeadsAttention
+      )
     else:
       attention_fn = tf_keras.layers.MultiHeadAttention
     self._attention_layer = attention_fn(
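For orientation, here is a minimal sketch of how the two new constructor flags might be exercised. It is not part of the diff: the class and argument names come from the changed file above, while the concrete values mirror the updated test and are purely illustrative.

import tf_keras

from official.nlp.modeling.layers import transformer_encoder_block

# Illustrative configuration: grouped-query attention (num_kv_heads=4)
# with the new talking-heads and GQA-optimization flags enabled.
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    num_kv_heads=4,
    enable_talking_heads=True,     # selects TalkingHeadsMultiQueryAttention
    enable_gqa_optimization=True,  # only meaningful when num_kv_heads is set
)

# Calling the layer on a [batch, sequence_length, width] input builds the
# attention implementation chosen by the branches in the hunk above.
inputs = tf_keras.Input(shape=(21, 80))
outputs = block(inputs)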
official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -743,35 +743,46 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         num_attention_heads=num_attention_heads,
         inner_dim=2048,
         inner_activation='relu',
-        return_attention_scores=return_attention_scores)
+        return_attention_scores=return_attention_scores,
+    )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
     output_tensor = test_layer(data_tensor)
 
     expected_layer_output_shape = [None, sequence_length, width]
     expected_attention_scores_shape = [
-        None, num_attention_heads, sequence_length, sequence_length
+        None,
+        num_attention_heads,
+        sequence_length,
+        sequence_length,
     ]
 
     if return_attention_scores:
       self.assertIsInstance(output_tensor, tuple)
       self.assertLen(output_tensor, 2)
       # First is the standard output.
-      self.assertEqual(output_tensor[0].shape.as_list(),
-                       expected_layer_output_shape)
+      self.assertEqual(
+          output_tensor[0].shape.as_list(), expected_layer_output_shape
+      )
       # Second is the attention scores.
-      self.assertEqual(output_tensor[1].shape.as_list(),
-                       expected_attention_scores_shape)
+      self.assertEqual(
+          output_tensor[1].shape.as_list(), expected_attention_scores_shape
+      )
     else:
       # Only the standard layer output.
-      self.assertEqual(output_tensor.shape.as_list(),
-                       expected_layer_output_shape)
+      self.assertEqual(
+          output_tensor.shape.as_list(), expected_layer_output_shape
+      )
 
   @parameterized.named_parameters(
       ('mqa', 1),
       ('gqa', 4),
+      ('talking_heads_mqa', 1, True),
+      ('talking_heads_gqa', 4, True),
   )
-  def test_attention_with_kv_heads(self, num_kv_heads):
+  def test_attention_with_kv_heads(
+      self, num_kv_heads, enable_talking_heads=False
+  ):
     num_attention_heads = 8
     sequence_length = 21
     width = 80
@@ -782,6 +793,8 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         inner_activation='relu',
         return_attention_scores=True,
         num_kv_heads=num_kv_heads,
+        enable_talking_heads=enable_talking_heads,
+        enable_gqa_optimization=True,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
tf_models_nightly-2.20.0.dev20251031.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.20.0.dev20251029
+Version: 2.20.0.dev20251031
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.
tf_models_nightly-2.20.0.dev20251031.dist-info/RECORD

@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=QWq1dJqQUPe5n69K3
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=omzTkCBEk2TOkHEYDEBwve6WsOitX7IIJHzeKXdqDq0,11012
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=pSCONEZRI4J9_6QLTJ3g_ynUYLrRXsJ1c2YMSiOV_60,8893
 official/nlp/modeling/layers/transformer.py,sha256=VjUO-gVj_PnavbT_vSrg5NDKMr0SRSiqSg5ktd42m5M,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=5GJgtK1mdTxMDYVWfUoBAI_GvjDL0zO9AWtKCovSZiU,28789
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=7yBgv1UNmfOFre6txF_Rq93RLc1TJwnJ7-Dz4p55sy4,37602
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=BiL8ErBs-m0UZ6ONVJV0ncfWX3LhMhPetIhfH2VvuD4,29910
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=g7oMDPvwg6Fv75SBdm6BInXPI8r5GcItBRjLFGuObyg,37821
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=qmzhCJvbbFVF9zDqnfO4Zs2JDXwKhK7iEBOhsU6-KpQ,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=dRJwesTBKm-mF5mDHrHfVpVNnxa-Wx-fj_4ZHDPTpE0,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=-pk9cdz9UlMpCIkGRkCKsMmjdRGi0seySaaB_2dwmXw,5522
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=yiAneltAW3NHSj3fUSvHNBjfq0MGZ
 tensorflow_models/nlp/__init__.py,sha256=8uQd4wI6Zc4IJMPjtQifMeWVbPFkTxqYh66wfivCOL4,807
 tensorflow_models/uplift/__init__.py,sha256=NzaweFf4ZmhRb2l_fuV6bP-2N8oSO3xu6xJqVb1UmpY,999
 tensorflow_models/vision/__init__.py,sha256=ks420Ooqzi0hU7HnQpM5rylLaE-YcJdJkBx_umVaXlE,833
-tf_models_nightly-2.20.0.dev20251029.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.20.0.dev20251029.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.20.0.dev20251029.dist-info/METADATA,sha256=hTl-klu4MFPyEUTW4FX3TyByeXJs4Hp7BJpSjTWFaCw,1432
-tf_models_nightly-2.20.0.dev20251029.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.20.0.dev20251029.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.20.0.dev20251029.dist-info/RECORD,,
+tf_models_nightly-2.20.0.dev20251031.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.20.0.dev20251031.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.20.0.dev20251031.dist-info/METADATA,sha256=Nkc6PAbFGnKBfKCfhkqghJ4risJv55iE9dKZblkfEq0,1432
+tf_models_nightly-2.20.0.dev20251031.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.20.0.dev20251031.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.20.0.dev20251031.dist-info/RECORD,,