tf-models-nightly 2.18.0.dev20241021__py2.py3-none-any.whl → 2.18.0.dev20241022__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
official/nlp/modeling/layers/block_sparse_attention.py

@@ -48,6 +48,7 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
       tgt_block_size=None,
       use_sigmoid_attn=False,
       sigmoid_attn_bias=None,
+      num_kv_heads=None,
       **kwargs
   ):
     """Initializes the block sparse attention layer.
@@ -61,6 +62,8 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
       use_sigmoid_attn: If enabled, uses sigmoid instead of softmax to compute
         attn probs. https://arxiv.org/pdf/2409.04431
       sigmoid_attn_bias: Bias for sigmoid attn. Suggested value -ln(seq_len).
+      num_kv_heads: Number of key/value heads in the multi-head self attention.
+        Refer to multi_query_attention.py for more details.
       **kwargs: Args passed to the base class.
     """
     super().__init__(**kwargs)
@@ -68,6 +71,11 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
       raise ValueError("src_block_size must be specified.")
     self._src_block_size = src_block_size
     self._tgt_block_size = tgt_block_size or self._src_block_size
+    self._num_kv_heads = num_kv_heads
+    if num_kv_heads is not None and num_kv_heads != 1:
+      raise ValueError(
+          "num_kv_heads must be 1. Grouped-query attention is not supported."
+      )
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     if self._use_sigmoid_attn:
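
Note: per the constructor change above, the block sparse attention layer now takes a num_kv_heads argument, and only the multi-query case (num_kv_heads=1 or None) is accepted. A minimal usage sketch; the argument values below are illustrative, not taken from the package tests:

    # Construct block sparse attention with a single shared key/value head (MQA).
    from official.nlp.modeling.layers import block_sparse_attention

    layer = block_sparse_attention.MultiHeadAttention(
        num_heads=8,
        key_dim=64,
        src_block_size=16,
        tgt_block_size=16,
        num_kv_heads=1,  # any other non-None value raises ValueError per the diff above
    )
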
@@ -117,22 +125,50 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
         name="query",
         **self._get_common_kwargs_for_sublayer(),
     )
-    self._key_dense = tf_keras.layers.EinsumDense(
-        proj_einsum_eqn,
-        output_shape=qk_output_shape,
-        bias_axes=bias_axes if self._use_bias else None,
-        name="key",
-        **self._get_common_kwargs_for_sublayer(),
-    )
-    self._value_dense = tf_keras.layers.EinsumDense(
-        proj_einsum_eqn,
-        output_shape=v_output_shape,
-        bias_axes=bias_axes if self._use_bias else None,
-        name="value",
-        **self._get_common_kwargs_for_sublayer(),
-    )
-    self._dot_product_equation = "BNLsH,BNLtH->BNLts"
-    self._combine_equation = "BNLts,BNLsH->BNLtH"
+    if self._num_kv_heads == 1:
+      self._key_dense = tf_keras.layers.EinsumDense(
+          "BTD,DH->BTH",
+          output_shape=[None, self._key_dim],
+          bias_axes="H" if self._use_bias else None,
+          name="key",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+      self._value_dense = tf_keras.layers.EinsumDense(
+          "BTD,DH->BTH",
+          output_shape=[None, self._value_dim],
+          bias_axes="H" if self._use_bias else None,
+          name="value",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+    else:
+      self._key_dense = tf_keras.layers.EinsumDense(
+          proj_einsum_eqn,
+          output_shape=qk_output_shape,
+          bias_axes=bias_axes if self._use_bias else None,
+          name="key",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+      self._value_dense = tf_keras.layers.EinsumDense(
+          proj_einsum_eqn,
+          output_shape=v_output_shape,
+          bias_axes=bias_axes if self._use_bias else None,
+          name="value",
+          **self._get_common_kwargs_for_sublayer(),
+      )
+    if self._key_shape[-2] == self._tgt_block_size:
+      if self._num_kv_heads == 1:
+        self._dot_product_equation = "BsH,BNLtH->BNLts"
+        self._combine_equation = "BNLts,BsH->BNLtH"
+      else:
+        self._dot_product_equation = "BNsH,BNLtH->BNLts"
+        self._combine_equation = "BNLts,BNsH->BNLtH"
+    else:
+      if self._num_kv_heads == 1:
+        self._dot_product_equation = "BLsH,BNLtH->BNLts"
+        self._combine_equation = "BNLts,BLsH->BNLtH"
+      else:
+        self._dot_product_equation = "BNLsH,BNLtH->BNLts"
+        self._combine_equation = "BNLts,BNLsH->BNLtH"
     if self._output_shape:
       if not isinstance(self._output_shape, collections.abc.Sized):
         output_shape = [self._output_shape]
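
Note: with num_kv_heads == 1, the key/value projections above produce a single head ("BTD,DH->BTH") and the new einsum equations drop the head axis (N) from the key/value operands, so one KV head is broadcast across all query heads. A toy shape check; the dimensions are assumed and numpy stands in for tf.einsum:

    # Verify the MQA einsum shapes for the blocked case ("BLsH,BNLtH->BNLts").
    import numpy as np

    B, N, L, t, s, H = 2, 4, 3, 5, 5, 8  # batch, query heads, blocks, block sizes, head dim
    query = np.zeros((B, N, L, t, H))    # per-head, per-block query projections
    key = np.zeros((B, L, s, H))         # single shared key head, split into blocks
    value = np.zeros((B, L, s, H))       # single shared value head, split into blocks

    scores = np.einsum("BLsH,BNLtH->BNLts", key, query)
    print(scores.shape)   # (2, 4, 3, 5, 5)
    context = np.einsum("BNLts,BLsH->BNLtH", scores, value)
    print(context.shape)  # (2, 4, 3, 5, 8)
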
@@ -153,17 +189,25 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
     """Converts the attention mask to block diagonal."""
     # Uses the same key mask for the entire query sequence since softmax
     # is applied only on the key axis.
-    attention_mask = tf.cast(attention_mask[:, 0, :], dtype=dtype)
     tgt_num_blocks = self._key_shape[-2] // self._tgt_block_size
-    attention_mask = tf.reshape(
-        attention_mask,
-        [
-            -1,
-            tgt_num_blocks,
-            self._tgt_block_size,
-        ],
-    )
-    return tf.einsum("BLQ,BLK->BLQK", attention_mask, attention_mask)
+    if tgt_num_blocks == 1:
+      src_num_blocks = self._query_shape[-2] // self._src_block_size
+      result = tf.reshape(
+          attention_mask,
+          [-1, src_num_blocks, self._src_block_size, self._tgt_block_size],
+      )
+    else:
+      attention_mask = tf.cast(attention_mask[:, 0, :], dtype=dtype)
+      attention_mask = tf.reshape(
+          attention_mask,
+          [
+              -1,
+              tgt_num_blocks,
+              self._tgt_block_size,
+          ],
+      )
+      result = tf.einsum("BLQ,BLK->BLQK", attention_mask, attention_mask)
+    return result
 
   def _masked_softmax(self, attention_scores, attention_mask=None):
     # Normalize the attention scores to probabilities.
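
Note: the multi-block branch above still builds the block-diagonal mask as a per-block outer product of the key mask via einsum; the new single-target-block branch only reshapes the full mask. A small illustration of the outer-product branch with assumed sizes:

    # Turn a per-key mask into a block-diagonal [batch, blocks, Q, K] mask.
    import numpy as np

    mask = np.array([[1, 1, 1, 0]], dtype=np.float32)   # one batch, 4 key positions
    blocks, block_size = 2, 2
    mask = mask.reshape(-1, blocks, block_size)          # [B, L, K]
    block_mask = np.einsum("BLQ,BLK->BLQK", mask, mask)  # [B, L, Q, K]
    print(block_mask[0, 1])  # [[1., 0.], [0., 0.]] -- padding masked within its block
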
@@ -217,7 +261,7 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
     src_num_blocks = self._query_shape[-2] // self._src_block_size
     tgt_num_blocks = self._key_shape[-2] // self._tgt_block_size
 
-    if src_num_blocks != tgt_num_blocks:
+    if src_num_blocks != tgt_num_blocks and tgt_num_blocks != 1:
       raise ValueError(
           "src_num_blocks must be equal to tgt_num_blocks."
       )
@@ -230,20 +274,37 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
         self._src_block_size,
         self._key_dim,
     ])
-    key_blocks = tf.reshape(key, [
-        -1,
-        self._num_heads,
-        tgt_num_blocks,
-        self._tgt_block_size,
-        self._key_dim,
-    ])
-    value_blocks = tf.reshape(value, [
-        -1,
-        self._num_heads,
-        tgt_num_blocks,
-        self._tgt_block_size,
-        self._value_dim,
-    ])
+    if tgt_num_blocks != 1 and self._num_kv_heads != 1:
+      key_blocks = tf.reshape(key, [
+          -1,
+          self._num_heads,
+          tgt_num_blocks,
+          self._tgt_block_size,
+          self._key_dim,
+      ])
+      value_blocks = tf.reshape(value, [
+          -1,
+          self._num_heads,
+          tgt_num_blocks,
+          self._tgt_block_size,
+          self._value_dim,
+      ])
+    elif tgt_num_blocks != 1 and self._num_kv_heads == 1:
+      key_blocks = tf.reshape(key, [
+          -1,
+          tgt_num_blocks,
+          self._tgt_block_size,
+          self._key_dim,
+      ])
+      value_blocks = tf.reshape(value, [
+          -1,
+          tgt_num_blocks,
+          self._tgt_block_size,
+          self._value_dim,
+      ])
+    else:
+      key_blocks = key
+      value_blocks = value
     if attention_mask is not None:
       attention_mask = self._block_diagonal_mask(attention_mask, key.dtype)
       # pytype: disable=attribute-error
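
Note: the branching above only changes how the projected keys/values are laid out before the blocked attention matmul. A hypothetical helper (not part of the package; the name kv_block_shape is assumed) that spells out the layout each branch produces, consistent with the einsum equations introduced earlier:

    def kv_block_shape(batch, num_heads, num_kv_heads, tgt_seq_len, tgt_block_size, key_dim):
      tgt_num_blocks = tgt_seq_len // tgt_block_size
      if tgt_num_blocks != 1 and num_kv_heads != 1:
        # Per-head keys/values, split into target blocks ("BNLsH").
        return (batch, num_heads, tgt_num_blocks, tgt_block_size, key_dim)
      elif tgt_num_blocks != 1 and num_kv_heads == 1:
        # One shared KV head, still split into target blocks ("BLsH").
        return (batch, tgt_num_blocks, tgt_block_size, key_dim)
      else:
        # Single target block: the projected key/value is used as-is.
        if num_kv_heads == 1:
          return (batch, tgt_seq_len, key_dim)          # "BsH"
        return (batch, num_heads, tgt_seq_len, key_dim)  # "BNsH"

    print(kv_block_shape(batch=2, num_heads=8, num_kv_heads=1,
                         tgt_seq_len=64, tgt_block_size=16, key_dim=32))
    # (2, 4, 16, 32)
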

official/nlp/modeling/layers/block_sparse_attention_test.py

@@ -27,9 +27,36 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
       ("key_value_same_proj", None, None, [40, 80]),
+      ("key_value_same_proj_mqa", None, None, [40, 80], False, 1),
+      ("key_value_same_proj_multi_query_blocks", None, None, [40, 80], True),
+      (
+          "key_value_same_proj_multi_query_blocks_mqa",
+          None,
+          None,
+          [40, 80],
+          True,
+          1,
+      ),
       ("key_value_different_proj", 32, 60, [40, 60]),
+      ("key_value_different_proj_mqa", 32, 60, [40, 60], False, 1),
+      ("key_value_different_proj_multi_query_blocks", 32, 60, [40, 60], True),
+      (
+          "key_value_different_proj_multi_query_blocks_mqa",
+          32,
+          60,
+          [40, 60],
+          True,
+          1,
+      ),
   )
-  def test_non_masked_attention(self, value_dim, output_shape, output_dims):
+  def test_non_masked_attention(
+      self,
+      value_dim,
+      output_shape,
+      output_dims,
+      multi_query_blocks=False,
+      num_kv_heads=None,
+  ):
     """Test that the attention layer can be created without a mask tensor."""
     test_layer = block_sparse_attention.MultiHeadAttention(
         num_heads=12,
@@ -37,7 +64,8 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
         value_dim=value_dim,
         output_shape=output_shape,
         src_block_size=10,
-        tgt_block_size=5,
+        tgt_block_size=20 if multi_query_blocks else 5,
+        num_kv_heads=num_kv_heads,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf_keras.Input(shape=(40, 80))
@@ -57,13 +85,24 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
       ("with_bias", True),
+      ("with_bias_mqa", True, False, False, 1),
+      ("with_bias_multi_query_blocks", True, False, True),
+      ("with_bias_multi_query_blocks_mqa", True, False, True, 1),
       ("no_bias", False),
+      ("no_bias_mqa", False, False, False, 1),
+      ("no_bias_multi_query_blocks", False, False, True),
+      ("no_bias_multi_query_blocks_mqa", False, False, True, 1),
       ("with_sigmoid_attn", True, True),
+      ("with_sigmoid_attn_mqa", True, True, False, 1),
+      ("with_sigmoid_attn_multi_query_blocks", True, True, True),
+      ("with_sigmoid_attn_multi_query_blocks_mqa", True, True, True, 1),
   )
   def test_masked_attention(
       self,
       use_bias,
       use_sigmoid_attn=False,
+      multi_query_blocks=False,
+      num_kv_heads=None,
   ):
     """Test with a mask tensor."""
     if use_sigmoid_attn:
@@ -75,9 +114,10 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
         key_dim=2,
         use_bias=use_bias,
         src_block_size=2,
-        tgt_block_size=1,
+        tgt_block_size=2 if multi_query_blocks else 1,
         use_sigmoid_attn=use_sigmoid_attn,
         sigmoid_attn_bias=sigmoid_attn_bias,
+        num_kv_heads=num_kv_heads,
     )
     # Create a 3-dimensional input (the first dimension is implicit).
     batch_size = 3

official/nlp/modeling/layers/transformer_encoder_block.py

@@ -238,22 +238,30 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
     self._linformer_shared_kv_projection = linformer_shared_kv_projection
-    if self._num_kv_heads is not None and self._src_block_size is not None:
+    if (
+        self._src_block_size is not None
+        and self._num_kv_heads is not None
+        and self._num_kv_heads != 1
+    ):
       raise ValueError(
-          "Block sparse attention does not support Multi-query attention."
-          " Specify only one of them."
+          "Block sparse attention only supports Multi-query attention.Please"
+          " set num_kv_heads to 1 to enable MQA with block sparse attention."
       )
     if attention_initializer:
       self._attention_initializer = tf_keras.initializers.get(
-          attention_initializer)
+          attention_initializer
+      )
     else:
       self._attention_initializer = tf_utils.clone_initializer(
-          self._kernel_initializer)
+          self._kernel_initializer
+      )
     self._attention_axes = attention_axes
 
     if self._diff_q_kv_att_layer_norm and not self._norm_first:
-      raise ValueError("Setting `diff_q_and_kv_attention_layer_norm` to True"
-                       "when `norm_first` is False is invalid.")
+      raise ValueError(
+          "Setting `diff_q_and_kv_attention_layer_norm` to True"
+          "when `norm_first` is False is invalid."
+      )
 
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
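
Note: with the relaxed check above, TransformerEncoderBlock now accepts src_block_size together with num_kv_heads, provided num_kv_heads is 1. A hedged construction sketch; inner_dim and inner_activation values are illustrative, and the import path is assumed from the RECORD entries below:

    # Combine block sparse attention with multi-query attention in an encoder block.
    from official.nlp.modeling.layers import transformer_encoder_block

    block = transformer_encoder_block.TransformerEncoderBlock(
        num_attention_heads=8,
        inner_dim=2048,
        inner_activation="relu",
        src_block_size=16,
        tgt_block_size=16,
        num_kv_heads=1,  # any other value alongside src_block_size raises ValueError
    )
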
@@ -303,6 +311,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
           tgt_block_size=self._tgt_block_size,
           use_sigmoid_attn=self._use_sigmoid_attn,
           sigmoid_attn_bias=self._sigmoid_attn_bias,
+          num_kv_heads=self._num_kv_heads,
           name="block_sparse_attention",
       )
       attention_fn = block_sparse_attention.MultiHeadAttention

official/nlp/modeling/layers/transformer_encoder_block_test.py

@@ -755,9 +755,11 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
       ('use_softmax_attn', False),
+      ('use_softmax_attn_mqa', False, 1),
       ('use_sigmoid_attn', True),
+      ('use_sigmoid_attn_mqa', True, 1),
   )
-  def test_block_sparse_attention(self, use_sigmoid_attn):
+  def test_block_sparse_attention(self, use_sigmoid_attn, num_kv_heads=None):
     num_attention_heads = 8
     sequence_length = 21
     width = 80
@@ -771,6 +773,7 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
         return_attention_scores=True,
         src_block_size=src_block_size,
         tgt_block_size=tgt_block_size,
+        num_kv_heads=num_kv_heads,
         use_sigmoid_attn=use_sigmoid_attn,
         sigmoid_attn_bias=-math.log(sequence_length)
         if use_sigmoid_attn
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tf-models-nightly
-Version: 2.18.0.dev20241021
+Version: 2.18.0.dev20241022
 Summary: TensorFlow Official Models
 Home-page: https://github.com/tensorflow/models
 Author: Google Inc.
@@ -305,8 +305,8 @@ official/nlp/modeling/layers/bigbird_attention.py,sha256=dzutgRoQt2DFsYMpMILv_QF
 official/nlp/modeling/layers/bigbird_attention_test.py,sha256=cBYwK5k1rnykZ0gif-n7VaByLIoElA-N0_svCRKASoU,2206
 official/nlp/modeling/layers/block_diag_feedforward.py,sha256=FDEt-J_QjOxwar3eT5yjMs4hR41Ppke1zj7iswsZR4M,7243
 official/nlp/modeling/layers/block_diag_feedforward_test.py,sha256=wcg8In6FIOCxcKqe5rucftjJ_kUWTi9Ei7eEmlVCYpE,4181
-official/nlp/modeling/layers/block_sparse_attention.py,sha256=eY6jkSI-TrnL0JkP_9B-0DCxzppZdK_c8qp6Uw6yiD0,9923
-official/nlp/modeling/layers/block_sparse_attention_test.py,sha256=KSQENNhRG7Y1qDpdW_O3Ws6nPC4se7zv1UcxF2o7blI,15037
+official/nlp/modeling/layers/block_sparse_attention.py,sha256=8Jyinyf5XuU6nuxblGRfNnAviBEZIltMSoNJzAVjAYo,12233
+official/nlp/modeling/layers/block_sparse_attention_test.py,sha256=9YiKtv4YCrKIyUbv27P2xcTXFohoaRxq2K6vsOUi4zU,16447
 official/nlp/modeling/layers/cls_head.py,sha256=0X_gdjnAt6TZVrH_xkDcQCpwLuVz5Pb7d04wEVN_Kn8,16208
 official/nlp/modeling/layers/cls_head_test.py,sha256=01oMmiuyp1lDEXBYa9r3krn6BtH-QuSedGOca9LViEc,8888
 official/nlp/modeling/layers/factorized_embedding.py,sha256=4oFRYJbpoaSxqv8hTWY2JPGPllp-zhniz99IyRtlzV8,2902
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=J52mXzoiuaXfR61kh
 official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=gbGJOrgxJd1SyMGB6ME04FSxuZfHqsi94Xxt23l7368,11032
 official/nlp/modeling/layers/tn_transformer_test.py,sha256=Fh-EDRoAkhO7ccD3w3FsJHC51MnZySv8jBlHYnvKZMc,8893
 official/nlp/modeling/layers/transformer.py,sha256=yofIEOjZpcvDmHbcjBmkZrl5iSe6pLtMsetNbXmxDnY,20087
-official/nlp/modeling/layers/transformer_encoder_block.py,sha256=kiCQ4yGejmwRsJBKpmrwA1As4rFUekNYf9xGS052kyU,24766
-official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=cIunagl03W1tPkkt1BDVpGEpd-7ZwCqc3sPdzQOmpuc,32269
+official/nlp/modeling/layers/transformer_encoder_block.py,sha256=14dgbg6z9xeXl2trEJkxsVyQPguQ9m7U20aDAmOVDQE,24930
+official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=eTIDHGbTZobWIyMswPp9K_tgyzWTLFJ9j1ujXY3EXvY,32406
 official/nlp/modeling/layers/transformer_scaffold.py,sha256=m8TF4geBkm8-VJQiTpzMI6FSJZry6oa2vPO3FXCCClE,15704
 official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=pqUGldhmAKROrd4eoCWmHNtKOdCO6PH_-EigcYnvIpE,19920
 official/nlp/modeling/layers/transformer_test.py,sha256=kC_9NcLbJnBbuTaE_7BW60EF8xG_QUoICj0t0gS7O4Q,5522
@@ -1222,9 +1222,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=nc6A9K53OGqF25xN5St8EiWvdVbda
 tensorflow_models/nlp/__init__.py,sha256=4tA5Pf4qaFwT-fIFOpX7x7FHJpnyJT-5UgOeFYTyMlc,807
 tensorflow_models/uplift/__init__.py,sha256=mqfa55gweOdpKoaQyid4A_4u7xw__FcQeSIF0k_pYmI,999
 tensorflow_models/vision/__init__.py,sha256=zBorY_v5xva1uI-qxhZO3Qh-Dii-Suq6wEYh6hKHDfc,833
-tf_models_nightly-2.18.0.dev20241021.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
-tf_models_nightly-2.18.0.dev20241021.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
-tf_models_nightly-2.18.0.dev20241021.dist-info/METADATA,sha256=RwWDKM5onfUl4FjQtHMFuJAKRJ5avUVrgN1TTVUXxYU,1432
-tf_models_nightly-2.18.0.dev20241021.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
-tf_models_nightly-2.18.0.dev20241021.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
-tf_models_nightly-2.18.0.dev20241021.dist-info/RECORD,,
+tf_models_nightly-2.18.0.dev20241022.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+tf_models_nightly-2.18.0.dev20241022.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+tf_models_nightly-2.18.0.dev20241022.dist-info/METADATA,sha256=Xq17BE4FrMRbkZCvlH9T5EFnFy7FIPjl2th1UY1LuUA,1432
+tf_models_nightly-2.18.0.dev20241022.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+tf_models_nightly-2.18.0.dev20241022.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+tf_models_nightly-2.18.0.dev20241022.dist-info/RECORD,,