tf-models-nightly 2.18.0.dev20240916__py2.py3-none-any.whl → 2.18.0.dev20240917__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -114,6 +114,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  tgt_block_size=None,
  use_sigmoid_attn=False,
  sigmoid_attn_bias=None,
+ linformer_dim=None,
  **kwargs):
  """Initializes `TransformerEncoderBlock`.

@@ -191,6 +192,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  `block_sparse_attention.MultiHeadAttention`
  sigmoid_attn_bias: This param is only used in
  `block_sparse_attention.MultiHeadAttention`
+ linformer_dim: Applies low-rank factorization on keys/values as in
+ https://arxiv.org/pdf/2006.04768.
  **kwargs: keyword arguments.
  """
  util.filter_kwargs(kwargs)
@@ -230,6 +233,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  self._tgt_block_size = tgt_block_size
  self._use_sigmoid_attn = use_sigmoid_attn
  self._sigmoid_attn_bias = sigmoid_attn_bias
+ self._linformer_dim = linformer_dim
  if self._num_kv_heads is not None and self._src_block_size is not None:
  raise ValueError(
  "Block sparse attention does not support Multi-query attention."
@@ -366,16 +370,31 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  name="output",
  kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
  bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
- **common_kwargs)
+ **common_kwargs,
+ )
  self._output_dropout = tf_keras.layers.Dropout(
- rate=self._output_dropout_rate)
+ rate=self._output_dropout_rate
+ )
  # Use float32 in layernorm for numeric stability.
  self._output_layer_norm = tf_keras.layers.LayerNormalization(
  name="output_layer_norm",
  axis=-1,
  epsilon=self._norm_epsilon,
- dtype=tf.float32)
-
+ dtype=tf.float32,
+ )
+ if self._linformer_dim is not None:
+ # Current implementation uses the same weights for keys and values.
+ # TODO(akandoor): Explore using different weights for keys and values.
+ self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+ "...bc,cd->...bd",
+ output_shape=(None, self._linformer_dim),
+ kernel_initializer=tf_utils.clone_initializer(
+ self._kernel_initializer
+ ),
+ bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+ name="lowrank_kv_projection",
+ **common_kwargs,
+ )
  super().build(input_shape)

  def get_config(self):
@@ -480,6 +499,19 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  if key_value is None:
  key_value = input_tensor

+ if self._linformer_dim is not None:
+ if attention_mask is not None:
+ # Applying mask before the low rank factorization so that padding is
+ # accounted for.
+ query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+ target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+ key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+ key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+ attention_mask = None
+ key_value = tf.transpose(key_value, [0, 2, 1])
+ key_value = self._lowrank_kv_projection(key_value)
+ key_value = tf.transpose(key_value, [0, 2, 1])
+
  if self._return_attention_scores:
  attention_output, attention_scores = self._attention_layer(
  query=target_tensor,
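The call() changes above fold the attention mask into the inputs (by zeroing padded query and key positions), drop the mask, and then project the key/value sequence axis from seq_len down to linformer_dim via transpose, EinsumDense, transpose. A minimal standalone sketch of that shape flow, using plain TensorFlow ops and a random stand-in `proj` for the layer's learned kernel (the names below are illustrative, not from the diff):

    import tensorflow as tf

    batch, seq_len, hidden, linformer_dim = 2, 21, 80, 7
    key_value = tf.random.normal([batch, seq_len, hidden])
    # Stand-in for the learned "cd" kernel of shape [seq_len, linformer_dim].
    proj = tf.random.normal([seq_len, linformer_dim])

    kv_t = tf.transpose(key_value, [0, 2, 1])      # [batch, hidden, seq_len]
    kv_low = tf.einsum('bhs,sk->bhk', kv_t, proj)  # contract seq_len down to linformer_dim
    key_value = tf.transpose(kv_low, [0, 2, 1])    # [batch, linformer_dim, hidden]
    print(key_value.shape)  # (2, 7, 80)

Because the projection mixes information across sequence positions, a per-position key mask no longer lines up with the projected keys; that is why the mask is applied by zeroing padded rows before the factorization and then set to None.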
@@ -800,6 +800,42 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
  output_tensor[1].shape.as_list(), expected_attention_scores_shape
  )

+ def test_low_rank_attention(self):
+ num_attention_heads = 8
+ sequence_length = 21
+ linformer_dim = 7
+ width = 80
+
+ test_layer = TransformerEncoderBlock(
+ num_attention_heads=num_attention_heads,
+ inner_dim=2048,
+ inner_activation='relu',
+ return_attention_scores=True,
+ linformer_dim=linformer_dim,
+ )
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf_keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ expected_layer_output_shape = [None, sequence_length, width]
+ expected_attention_scores_shape = [
+ None,
+ num_attention_heads,
+ sequence_length,
+ linformer_dim,
+ ]
+
+ self.assertIsInstance(output_tensor, tuple)
+ self.assertLen(output_tensor, 2)
+ # First is the standard output.
+ self.assertEqual(
+ output_tensor[0].shape.as_list(), expected_layer_output_shape
+ )
+ # Second is the attention scores.
+ self.assertEqual(
+ output_tensor[1].shape.as_list(), expected_attention_scores_shape
+ )
+

  if __name__ == '__main__':
  tf.test.main()
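The new test covers the unmasked Keras-functional path; the attention scores it checks have shape [batch, heads, seq_len, linformer_dim] because only the keys/values are projected. As a rough eager-mode usage sketch (not part of the diff), the masked branch added in call() could be exercised along these lines, assuming TransformerEncoderBlock is importable from official.nlp.modeling.layers:

    import numpy as np
    import tensorflow as tf
    from official.nlp.modeling import layers

    seq_len, width = 21, 80
    layer = layers.TransformerEncoderBlock(
        num_attention_heads=8,
        inner_dim=2048,
        inner_activation='relu',
        linformer_dim=7,
    )
    x = tf.random.normal([2, seq_len, width])
    # Mark the last 5 positions of each example as padding.
    valid = np.ones([2, seq_len], dtype=np.float32)
    valid[:, -5:] = 0.0
    attention_mask = tf.convert_to_tensor(valid[:, :, None] * valid[:, None, :])
    y = layer([x, attention_mask])  # mask is folded in before the low-rank projection
    print(y.shape)  # (2, 21, 80)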
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tf-models-nightly
- Version: 2.18.0.dev20240916
+ Version: 2.18.0.dev20240917
  Summary: TensorFlow Official Models
  Home-page: https://github.com/tensorflow/models
  Author: Google Inc.
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
  official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
  official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
  official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
- official/nlp/modeling/layers/transformer_encoder_block.py,sha256=n7_HgFjCye7ZNxzQ67CtgboDKPIE-28796Y2aW8Zk_U,22566
- official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=5B_h8iNweUiRJR2IH1zxFelsfhVPEJJ4dEzL_pHPjI0,30968
+ official/nlp/modeling/layers/transformer_encoder_block.py,sha256=dxUCn9LckIJCpxJ8DRmiAU-4ycCmddXBTifLdihDmiU,24047
+ official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=nbrfktOe0_WNhYYk0IlEJPf5d-9xtBoi2wDdO_FWF_k,32068
  official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
  official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
  official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
  tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
  tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
  tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
- tf_models_nightly-2.18.0.dev20240916.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
- tf_models_nightly-2.18.0.dev20240916.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
- tf_models_nightly-2.18.0.dev20240916.dist-info/METADATA,sha256=3n7Gfhr0DnLjSH7idnMTxIGrR_G8evj10Yp7riWCjTo,1432
- tf_models_nightly-2.18.0.dev20240916.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
- tf_models_nightly-2.18.0.dev20240916.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
- tf_models_nightly-2.18.0.dev20240916.dist-info/RECORD,,
+ tf_models_nightly-2.18.0.dev20240917.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+ tf_models_nightly-2.18.0.dev20240917.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+ tf_models_nightly-2.18.0.dev20240917.dist-info/METADATA,sha256=l3aNnMlgUyV26Zw7dWix41njTiJ4a6o7gA-8SiI6Qq4,1432
+ tf_models_nightly-2.18.0.dev20240917.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+ tf_models_nightly-2.18.0.dev20240917.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+ tf_models_nightly-2.18.0.dev20240917.dist-info/RECORD,,