tf-models-nightly 2.18.0.dev20240916__py2.py3-none-any.whl → 2.18.0.dev20240918__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -114,6 +114,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
                tgt_block_size=None,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
+               linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.

@@ -191,6 +193,10 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         `block_sparse_attention.MultiHeadAttention`
       sigmoid_attn_bias: This param is only used in
         `block_sparse_attention.MultiHeadAttention`
+      linformer_dim: Applies low-rank factorization on keys/values as in
+        https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, projection layer is shared for
+        keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
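For context on the two new arguments: Linformer replaces the full n x n attention matrix with an n x k one by projecting keys and values along the sequence axis down to a fixed length k (`linformer_dim`), and `linformer_shared_kv_projection` controls whether keys and values share that projection. Below is a minimal single-head sketch of the idea only, not the layer's code; every name and shape in it is made up for illustration.

# Illustrative Linformer sketch (https://arxiv.org/pdf/2006.04768), assumed shapes.
import tensorflow as tf

batch, n, d, k = 2, 128, 64, 16            # k << n is the low rank
query = tf.random.normal([batch, n, d])
key_value = tf.random.normal([batch, n, d])
proj = tf.random.normal([n, k])            # a learned (n, k) projection in practice

low_rank_kv = tf.einsum('bnd,nk->bkd', key_value, proj)   # (batch, k, d)
scores = tf.einsum('bnd,bkd->bnk', query, low_rank_kv)    # (batch, n, k) instead of (batch, n, n)
probs = tf.nn.softmax(scores / d ** 0.5, axis=-1)
context = tf.einsum('bnk,bkd->bnd', probs, low_rank_kv)   # back to (batch, n, d)

The attention cost therefore scales with n * k rather than n * n, which is what the shape assertions in the test hunk further below check.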
@@ -230,6 +236,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._tgt_block_size = tgt_block_size
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
+    self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -366,16 +374,33 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
-        **common_kwargs)
+        **common_kwargs,
+    )
     self._output_dropout = tf_keras.layers.Dropout(
-        rate=self._output_dropout_rate)
+        rate=self._output_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf_keras.layers.LayerNormalization(
         name="output_layer_norm",
         axis=-1,
         epsilon=self._norm_epsilon,
-        dtype=tf.float32)
-
+        dtype=tf.float32,
+    )
+    if self._linformer_dim is not None:
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
+      self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+          "...bc,cd->...bd",
+          output_shape=(None, low_rank_dim),
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer
+          ),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+          name="lowrank_kv_projection",
+          **common_kwargs,
+      )
     super().build(input_shape)

   def get_config(self):
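A note on the `"...bc,cd->...bd"` projection added above: in call() the keys/values are first transposed to (batch, hidden, seq_len), so this EinsumDense contracts the trailing sequence axis down to `low_rank_dim` (which is `2 * linformer_dim` when the projection is not shared, so it can later be split into a key half and a value half). A shape-only sketch under assumed sizes, using the standard `tf.keras.layers.EinsumDense` in place of the package's `tf_keras` alias:

import tensorflow as tf

seq_len, hidden, low_rank_dim = 21, 80, 7
proj = tf.keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, low_rank_dim))
kv_transposed = tf.random.normal([4, hidden, seq_len])  # (batch, hidden, seq_len)
print(proj(kv_transposed).shape)                        # (4, 80, 7)

The `None` entry in `output_shape` leaves the hidden axis unchanged; only the last axis is projected.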
@@ -425,6 +450,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -480,15 +507,41 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor

+    key = key_value
+    value = key_value
+    if self._linformer_dim is not None:
+      if attention_mask is not None:
+        # Applying mask before the low rank factorization so that padding is
+        # accounted for.
+        query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+        target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+        key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+        key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+        attention_mask = None
+      key_value = tf.transpose(key_value, [0, 2, 1])
+      key_value = self._lowrank_kv_projection(key_value)
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)

     if self._norm_first:
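The reason the mask is consumed before the projection: a (batch, q_len, kv_len) attention mask no longer lines up once the key/value sequence axis has been shortened to linformer_dim, so padded positions are zeroed out multiplicatively first and attention_mask is then set to None. A standalone sketch of that flow under assumed shapes (a plain Dense stands in for the layer's EinsumDense projection; none of these names are the library's):

import tensorflow as tf

batch, seq_len, hidden, k = 2, 21, 80, 7
x = tf.random.normal([batch, seq_len, hidden])
mask = tf.ones([batch, seq_len, seq_len])        # (batch, q_len, kv_len)
proj = tf.keras.layers.Dense(k)                  # stand-in for the low-rank projection

key_mask = tf.cast(mask[:, 0, :], x.dtype)       # (batch, kv_len)
kv = x * tf.expand_dims(key_mask, axis=-1)       # zero out padded tokens up front
kv = tf.transpose(kv, [0, 2, 1])                 # (batch, hidden, seq_len)
kv = proj(kv)                                    # (batch, hidden, k)
kv = tf.transpose(kv, [0, 2, 1])                 # (batch, k, hidden)
# Attention now scores q_len x k positions, so no explicit mask is passed on.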
@@ -800,6 +800,47 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )

+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
+    num_attention_heads = 8
+    sequence_length = 21
+    linformer_dim = 7
+    width = 80
+
+    test_layer = TransformerEncoderBlock(
+        num_attention_heads=num_attention_heads,
+        inner_dim=2048,
+        inner_activation='relu',
+        return_attention_scores=True,
+        linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf_keras.Input(shape=(sequence_length, width))
+    output_tensor = test_layer(data_tensor)
+
+    expected_layer_output_shape = [None, sequence_length, width]
+    expected_attention_scores_shape = [
+        None,
+        num_attention_heads,
+        sequence_length,
+        linformer_dim,
+    ]
+
+    self.assertIsInstance(output_tensor, tuple)
+    self.assertLen(output_tensor, 2)
+    # First is the standard output.
+    self.assertEqual(
+        output_tensor[0].shape.as_list(), expected_layer_output_shape
+    )
+    # Second is the attention scores.
+    self.assertEqual(
+        output_tensor[1].shape.as_list(), expected_attention_scores_shape
+    )
+

 if __name__ == '__main__':
   tf.test.main()
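The new test exercises the feature end to end; expressed as user-facing code it looks roughly like the sketch below. This assumes the nightly wheel's module layout shown in the RECORD section further down; the concrete sizes mirror the test and the printed shapes follow from its assertions.

import tensorflow as tf
from official.nlp.modeling.layers import transformer_encoder_block as teb

layer = teb.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    return_attention_scores=True,
    linformer_dim=7,
    linformer_shared_kv_projection=False,
)
x = tf.random.normal([2, 21, 80])
output, scores = layer(x)
print(output.shape)   # (2, 21, 80)
print(scores.shape)   # (2, 8, 21, 7), the key axis is linformer_dim rather than 21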
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.18.0.dev20240916
+Version: 2.18.0.dev20240918
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
 official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=n7_HgFjCye7ZNxzQ67CtgboDKPIE-28796Y2aW8Zk_U,22566
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=5B_h8iNweUiRJR2IH1zxFelsfhVPEJJ4dEzL_pHPjI0,30968
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=kiCQ4yGejmwRsJBKpmrwA1As4rFUekNYf9xGS052kyU,24766
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=cIunagl03W1tPkkt1BDVpGEpd-7ZwCqc3sPdzQOmpuc,32269
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.18.0.dev20240916.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.18.0.dev20240916.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.18.0.dev20240916.dist-info/METADATA,sha256=3n7Gfhr0DnLjSH7idnMTxIGrR_G8evj10Yp7riWCjTo,1432
-tf_models_nightly-2.18.0.dev20240916.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.18.0.dev20240916.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.18.0.dev20240916.dist-info/RECORD,,
+tf_models_nightly-2.18.0.dev20240918.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.18.0.dev20240918.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.18.0.dev20240918.dist-info/METADATA,sha256=krnAS7Dd_7oQqXP0my3C8NRwFzAv5zuclh6kFBkRxqw,1432
+tf_models_nightly-2.18.0.dev20240918.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.18.0.dev20240918.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.18.0.dev20240918.dist-info/RECORD,,