tf-models-nightly 2.18.0.dev20240916__py2.py3-none-any.whl → 2.18.0.dev20240918__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- official/nlp/modeling/layers/transformer_encoder_block.py +60 -7
- official/nlp/modeling/layers/transformer_encoder_block_test.py +41 -0
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/RECORD +8 -8
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/top_level.txt +0 -0
official/nlp/modeling/layers/transformer_encoder_block.py

@@ -114,6 +114,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
                tgt_block_size=None,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
+               linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.

@@ -191,6 +193,10 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         `block_sparse_attention.MultiHeadAttention`
       sigmoid_attn_bias: This param is only used in
         `block_sparse_attention.MultiHeadAttention`
+      linformer_dim: Applies low-rank factorization on keys/values as in
+        https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, projection layer is shared for
+        keys and values.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
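Taken together, the two new constructor arguments let callers opt in to Linformer-style low-rank attention without touching the rest of the block. A minimal usage sketch, assuming the layer is imported from `official.nlp.modeling.layers.transformer_encoder_block` as in the test added further down; the sizes are illustrative only, not taken from this diff:

```python
import tf_keras
from official.nlp.modeling.layers.transformer_encoder_block import TransformerEncoderBlock

# Illustrative sizes, chosen only to show the shapes involved.
layer = TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    linformer_dim=64,                     # keys/values are projected down to length 64
    linformer_shared_kv_projection=True,  # one projection reused for both keys and values
)
inputs = tf_keras.Input(shape=(512, 768))  # [batch, seq_len, hidden]
outputs = layer(inputs)                    # [batch, 512, 768]
```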
@@ -230,6 +236,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._tgt_block_size = tgt_block_size
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
+    self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -366,16 +374,33 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
-        **common_kwargs
+        **common_kwargs,
+    )
     self._output_dropout = tf_keras.layers.Dropout(
-        rate=self._output_dropout_rate
+        rate=self._output_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf_keras.layers.LayerNormalization(
         name="output_layer_norm",
         axis=-1,
         epsilon=self._norm_epsilon,
-        dtype=tf.float32
-
+        dtype=tf.float32,
+    )
+    if self._linformer_dim is not None:
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
+      self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+          "...bc,cd->...bd",
+          output_shape=(None, low_rank_dim),
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer
+          ),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+          name="lowrank_kv_projection",
+          **common_kwargs,
+      )
     super().build(input_shape)

   def get_config(self):
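The `low_rank_dim` branch above doubles the projection width when keys and values get separate projections. For intuition about the `"...bc,cd->...bd"` equation: the projection is applied after the sequence and feature axes are swapped in `call()`, so it is the sequence axis that gets compressed. A rough shape sketch with plain `tf.einsum`, using a random tensor in place of the learned `EinsumDense` kernel and illustrative sizes:

```python
import tensorflow as tf

batch, seq_len, hidden, linformer_dim = 2, 21, 80, 7
low_rank_dim = linformer_dim  # would be 2 * linformer_dim with unshared projections

key_value = tf.random.normal([batch, seq_len, hidden])   # [B, S, D]
kv_t = tf.transpose(key_value, [0, 2, 1])                # [B, D, S]
kernel = tf.random.normal([seq_len, low_rank_dim])       # stands in for the EinsumDense kernel
projected = tf.einsum("...bc,cd->...bd", kv_t, kernel)   # [B, D, low_rank_dim]
print(projected.shape)                                   # (2, 80, 7)
```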
@@ -425,6 +450,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
         "tgt_block_size": self._tgt_block_size,
         "use_sigmoid_attn": self._use_sigmoid_attn,
         "sigmoid_attn_bias": self._sigmoid_attn_bias,
+        "linformer_dim": self._linformer_dim,
+        "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -480,15 +507,41 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     if key_value is None:
       key_value = input_tensor

+    key = key_value
+    value = key_value
+    if self._linformer_dim is not None:
+      if attention_mask is not None:
+        # Applying mask before the low rank factorization so that padding is
+        # accounted for.
+        query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+        target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+        key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+        key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+        attention_mask = None
+      key_value = tf.transpose(key_value, [0, 2, 1])
+      key_value = self._lowrank_kv_projection(key_value)
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor,
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)

     if self._norm_first:
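Because the projection mixes positions along the sequence axis, the padding mask can no longer be applied inside attention; the new code instead zeroes padded query and key/value positions up front and then drops `attention_mask`. A hedged sketch of the unshared-projection branch, with random tensors standing in for the layer's inputs and learned kernel, to show how the doubled projection is split into keys and values:

```python
import tensorflow as tf

batch, seq_len, hidden, k = 2, 21, 80, 7                 # k plays the role of linformer_dim
key_value = tf.random.normal([batch, seq_len, hidden])   # [B, S, D]
kernel = tf.random.normal([seq_len, 2 * k])              # unshared: one kernel, 2k output columns

kv_t = tf.transpose(key_value, [0, 2, 1])                # [B, D, S]
projected = tf.einsum("...bc,cd->...bd", kv_t, kernel)   # [B, D, 2k]
key = tf.transpose(projected[:, :, :k], [0, 2, 1])       # [B, k, D]
value = tf.transpose(projected[:, :, k:], [0, 2, 1])     # [B, k, D]

# Attention then runs a length-S query against length-k keys/values, so the
# score tensor is [B, num_heads, S, k] rather than [B, num_heads, S, S],
# which is the shape the new unit test checks below.
```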
official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -800,6 +800,47 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         output_tensor[1].shape.as_list(), expected_attention_scores_shape
     )

+  @parameterized.named_parameters(
+      ('unshared_kv_projection', False),
+      ('shared_kv_projection', True),
+  )
+  def test_low_rank_attention(self, shared_kv_projection):
+    num_attention_heads = 8
+    sequence_length = 21
+    linformer_dim = 7
+    width = 80
+
+    test_layer = TransformerEncoderBlock(
+        num_attention_heads=num_attention_heads,
+        inner_dim=2048,
+        inner_activation='relu',
+        return_attention_scores=True,
+        linformer_dim=linformer_dim,
+        linformer_shared_kv_projection=shared_kv_projection,
+    )
+    # Create a 3-dimensional input (the first dimension is implicit).
+    data_tensor = tf_keras.Input(shape=(sequence_length, width))
+    output_tensor = test_layer(data_tensor)
+
+    expected_layer_output_shape = [None, sequence_length, width]
+    expected_attention_scores_shape = [
+        None,
+        num_attention_heads,
+        sequence_length,
+        linformer_dim,
+    ]
+
+    self.assertIsInstance(output_tensor, tuple)
+    self.assertLen(output_tensor, 2)
+    # First is the standard output.
+    self.assertEqual(
+        output_tensor[0].shape.as_list(), expected_layer_output_shape
+    )
+    # Second is the attention scores.
+    self.assertEqual(
+        output_tensor[1].shape.as_list(), expected_attention_scores_shape
+    )
+

 if __name__ == '__main__':
   tf.test.main()
{tf_models_nightly-2.18.0.dev20240916.dist-info → tf_models_nightly-2.18.0.dev20240918.dist-info}/RECORD

@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
 official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=kiCQ4yGejmwRsJBKpmrwA1As4rFUekNYf9xGS052kyU,24766
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=cIunagl03W1tPkkt1BDVpGEpd-7ZwCqc3sPdzQOmpuc,32269
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
-tf_models_nightly-2.18.0.
+tf_models_nightly-2.18.0.dev20240918.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.18.0.dev20240918.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.18.0.dev20240918.dist-info/METADATA,sha256=krnAS7Dd_7oQqXP0my3C8NRwFzAv5zuclh6kFBkRxqw,1432
+tf_models_nightly-2.18.0.dev20240918.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.18.0.dev20240918.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.18.0.dev20240918.dist-info/RECORD,,
Files without changes: the remaining dist-info files (AUTHORS, LICENSE, WHEEL, top_level.txt).