tf-models-nightly 2.18.0.dev20240917__py2.py3-none-any.whl → 2.18.0.dev20240919__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

official/nlp/modeling/layers/transformer_encoder_block.py

@@ -115,6 +115,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -194,6 +195,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, projection layer is shared for
+        keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
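
The new flag defaults to True, which keeps the previous behavior of reusing one low-rank projection for both keys and values; passing False gives keys and values their own projection weights. A minimal construction sketch, assuming this nightly and tf-keras are installed; the hyperparameter values below are illustrative, not taken from the package:

import tensorflow as tf
from official.nlp.modeling.layers import transformer_encoder_block

# Default (True): one shared low-rank projection for keys and values,
# matching the 20240917 build.
shared = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=64,  # illustrative low-rank dimension
)

# New in this build: separate projection weights for keys and values.
unshared = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=64,
    linformer_shared_kv_projection=False,
)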
@@ -234,6 +237,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
           dtype=tf.float32,
       )
     if self._linformer_dim is not None:
-      # Current implementation uses the same weights for keys and values.
-      # TODO(akandoor): Explore using different weights for keys and values.
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None, self._linformer_dim),
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
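
The build-time change only affects the projection's output size: with a shared projection the single EinsumDense kernel has shape (seq_len, linformer_dim); without sharing, one kernel of shape (seq_len, 2 * linformer_dim) packs two independent projections side by side, which call() later splits into a key half and a value half. A standalone shape check, sketched with made-up dimensions and assuming tf-keras is installed:

import tensorflow as tf
import tf_keras

seq_len, width, linformer_dim = 21, 80, 7

shared_proj = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, linformer_dim))
unshared_proj = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, 2 * linformer_dim))

# The layer feeds key/value transposed to (batch, width, seq_len), so the
# sequence axis is the one being compressed.
x = tf.random.normal([4, width, seq_len])
print(shared_proj(x).shape)    # (4, 80, 7)
print(unshared_proj(x).shape)  # (4, 80, 14)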
@@ -444,6 +450,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
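
Note that linformer_dim itself was not previously recorded in get_config, so both Linformer settings were lost on serialization; with this change the two keys round-trip through the Keras config, which is what lets tf_keras rebuild the layer from a saved config. A quick check, a sketch with illustrative values:

from official.nlp.modeling.layers import transformer_encoder_block

layer = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=7,
    linformer_shared_kv_projection=False,
)
config = layer.get_config()
# Both Linformer options are now part of the serialized config.
print(config['linformer_dim'])                   # 7
print(config['linformer_shared_kv_projection'])  # False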
@@ -499,6 +507,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor
 
+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-      key_value = tf.transpose(key_value, [0, 2, 1])
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
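
The forward-path change can be read purely in terms of shapes: after the projection, keys and values have length linformer_dim instead of seq_len, so attention cost grows with seq_len * linformer_dim rather than seq_len squared; in the unshared case the doubled projection output is split down the middle. A plain-TensorFlow sketch of just the unshared branch, with a random matrix standing in for self._lowrank_kv_projection and made-up dimensions:

import tensorflow as tf

batch, seq_len, width, linformer_dim = 2, 21, 16, 7
key_value = tf.random.normal([batch, seq_len, width])

# Stand-in for the unshared low-rank projection kernel.
kernel = tf.random.normal([seq_len, 2 * linformer_dim])

kv = tf.transpose(key_value, [0, 2, 1])        # (batch, width, seq_len)
kv = tf.einsum('...bc,cd->...bd', kv, kernel)  # (batch, width, 2 * linformer_dim)

# First half -> keys, second half -> values, each transposed back to a
# source sequence of length linformer_dim.
key = tf.transpose(kv[:, :, :linformer_dim], [0, 2, 1])
value = tf.transpose(kv[:, :, linformer_dim:], [0, 2, 1])
print(key.shape, value.shape)  # (2, 7, 16) (2, 7, 16)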

official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -800,7 +800,11 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
-  def test_low_rank_attention(self):
+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
     num_attention_heads = 8
     sequence_length = 21
     linformer_dim = 7
@@ -812,6 +816,7 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         inner_activation='relu',
         return_attention_scores=True,
         linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
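
Mirroring the updated test, an end-to-end sketch; the 8 heads, sequence length 21, and linformer_dim 7 are copied from the test above, while the width and inner_dim values are assumptions:

import tensorflow as tf
from official.nlp.modeling.layers import transformer_encoder_block

layer = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,          # assumed; not shown in this diff
    inner_activation='relu',
    linformer_dim=7,
    linformer_shared_kv_projection=False,
)
data = tf.random.uniform([2, 21, 80])  # (batch, sequence_length, width)
output = layer(data)
print(output.shape)  # (2, 21, 80): the output keeps the input shape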

METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.18.0.dev20240917
+Version: 2.18.0.dev20240919
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.

RECORD

@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
 official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=dxUCn9LckIJCpxJ8DRmiAU-4ycCmddXBTifLdihDmiU,24047
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=nbrfktOe0_WNhYYk0IlEJPf5d-9xtBoi2wDdO_FWF_k,32068
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=kiCQ4yGejmwRsJBKpmrwA1As4rFUekNYf9xGS052kyU,24766
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=cIunagl03W1tPkkt1BDVpGEpd-7ZwCqc3sPdzQOmpuc,32269
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.18.0.dev20240917.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.18.0.dev20240917.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.18.0.dev20240917.dist-info/METADATA,sha256=l3aNnMlgUyV26Zw7dWix41njTiJ4a6o7gA-8SiI6Qq4,1432
-tf_models_nightly-2.18.0.dev20240917.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.18.0.dev20240917.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.18.0.dev20240917.dist-info/RECORD,,
+tf_models_nightly-2.18.0.dev20240919.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.18.0.dev20240919.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.18.0.dev20240919.dist-info/METADATA,sha256=z_bABqB1Cm6qU18Rku98ShkjheDYP-JbtDnzitB16cw,1432
+tf_models_nightly-2.18.0.dev20240919.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.18.0.dev20240919.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.18.0.dev20240919.dist-info/RECORD,,