tf-models-nightly 2.18.0.dev20240917__py2.py3-none-any.whl → 2.18.0.dev20240919__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/nlp/modeling/layers/transformer_encoder_block.py +29 -8
- official/nlp/modeling/layers/transformer_encoder_block_test.py +6 -1
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/RECORD +8 -8
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/top_level.txt +0 -0
official/nlp/modeling/layers/transformer_encoder_block.py

@@ -115,6 +115,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
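Taken together with the test changes further down, the new flag is passed like any other constructor argument. A minimal usage sketch (the import path and the surrounding argument values are assumptions mirroring the library's usual usage, not part of this diff):

```python
import tf_keras
from official.nlp.modeling.layers import transformer_encoder_block

# Sketch: enable Linformer-style low-rank attention and opt out of the
# shared key/value projection introduced in this release.
block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=7,                       # low-rank sequence dimension k
    linformer_shared_kv_projection=False,  # new flag; defaults to True
)
outputs = block(tf_keras.Input(shape=(21, 80)))
```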
@@ -194,6 +195,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, projection layer is shared for
+        keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
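For background, the Linformer paper linked in the docstring cuts self-attention cost from O(n²) to O(n·k) by projecting the length-n key/value sequence down to a fixed k (`linformer_dim` here). A rough shape-level sketch of that idea, independent of this layer's internals (all sizes illustrative):

```python
import tensorflow as tf

n, d, k = 21, 80, 7                  # sequence length, width, linformer_dim
kv = tf.random.uniform([1, n, d])    # keys/values: [batch, n, d]
E = tf.random.uniform([n, k])        # learned projection over the sequence axis

# Project the sequence axis down: [1, n, d] -> [1, k, d].
low_rank_kv = tf.einsum('bnd,nk->bkd', kv, E)
assert low_rank_kv.shape == (1, k, d)
# Scores against a length-n query are now [n, k] instead of [n, n].
```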
@@ -234,6 +237,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
           dtype=tf.float32,
       )
     if self._linformer_dim is not None:
-
-
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None,
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
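The build-time change sizes the projection to `2 * linformer_dim` when keys and values get separate projections, packing both into a single `EinsumDense`. A standalone sketch of what that einsum does to the shapes (tensor sizes are illustrative, not taken from the diff):

```python
import tensorflow as tf
import tf_keras

seq_len, width, linformer_dim = 21, 80, 7
shared = False
low_rank_dim = linformer_dim if shared else 2 * linformer_dim

# Same einsum equation as in the diff: the last axis (c) is contracted, so
# the layer must see [batch, width, seq_len] to project the *sequence* axis.
proj = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd", output_shape=(None, low_rank_dim))

kv = tf.random.uniform([2, seq_len, width])
kv_t = tf.transpose(kv, [0, 2, 1])   # [2, width, seq_len]
low_rank = proj(kv_t)                # [2, width, low_rank_dim]
print(low_rank.shape)                # (2, 80, 14) in the unshared case
```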
@@ -444,6 +450,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
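With both keys added to `get_config` (note that `linformer_dim` itself was previously missing from the serialized config), a configured layer should survive a config round trip. A quick hedged check, assuming the same import path as above:

```python
from official.nlp.modeling.layers import transformer_encoder_block

block = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=32,
    inner_activation='relu',
    linformer_dim=7,
    linformer_shared_kv_projection=False,
)
config = block.get_config()
assert config['linformer_shared_kv_projection'] is False

# Rebuild an equivalent layer purely from the serialized config.
clone = transformer_encoder_block.TransformerEncoderBlock.from_config(config)
```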
@@ -499,6 +507,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor
 
+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor,
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
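The shared branch reuses one [batch, width, k] projection for both keys and values, while the unshared branch slices a [batch, width, 2k] projection in half before undoing the earlier transpose. A shape-only sketch of the unshared split (sizes are illustrative):

```python
import tensorflow as tf

batch, width, k = 2, 80, 7

# Stand-in for the low-rank projection output in the unshared case:
# key and value projections packed along the last axis.
kv = tf.random.uniform([batch, width, 2 * k])

key = tf.transpose(kv[:, :, :k], [0, 2, 1])    # -> [2, 7, 80]
value = tf.transpose(kv[:, :, k:], [0, 2, 1])  # -> [2, 7, 80]
# Both now look like length-k sequences of width-80 features, so the
# attention layer consumes them exactly like ordinary keys/values.
```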
official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -800,7 +800,11 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )
 
-
+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
     num_attention_heads = 8
     sequence_length = 21
     linformer_dim = 7
@@ -812,6 +816,7 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         inner_activation='relu',
         return_attention_scores=True,
         linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
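The parameterized test now exercises both projection modes. A condensed version of what it builds (argument values copied from the hunks above; `inner_dim` and the input width are assumptions):

```python
import tf_keras
from official.nlp.modeling.layers import transformer_encoder_block

test_layer = transformer_encoder_block.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    return_attention_scores=True,
    linformer_dim=7,
    linformer_shared_kv_projection=False,  # the 'unshared_kv_projection' case
)
data_tensor = tf_keras.Input(shape=(21, 80))
output_tensor = test_layer(data_tensor)
# output_tensor[1] holds the attention scores; with low-rank keys/values
# their last axis should be linformer_dim (7) rather than the full 21,
# assuming the usual [batch, heads, query_len, key_len] score layout.
```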
{tf_models_nightly-2.18.0.dev20240917.dist-info → tf_models_nightly-2.18.0.dev20240919.dist-info}/RECORD

@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
 official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=kiCQ4yGejmwRsJBKpmrwA1As4rFUekNYf9xGS052kyU,24766
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=cIunagl03W1tPkkt1BDVpGEpd-7ZwCqc3sPdzQOmpuc,32269
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
+tf_models_nightly-2.18.0.dev20240919.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.18.0.dev20240919.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.18.0.dev20240919.dist-info/METADATA,sha256=z_bABqB1Cm6qU18Rku98ShkjheDYP-JbtDnzitB16cw,1432
+tf_models_nightly-2.18.0.dev20240919.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.18.0.dev20240919.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.18.0.dev20240919.dist-info/RECORD,,
The remaining dist-info files (AUTHORS, LICENSE, WHEEL, top_level.txt) are unchanged between the two wheels.
|