tf-models-nightly 2.20.0.dev20250721__py2.py3-none-any.whl → 2.20.0.dev20250722__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tf-models-nightly might be problematic.

@@ -82,41 +82,44 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  Understanding](https://arxiv.org/abs/1810.04805)
  """
 
- def __init__(self,
- num_attention_heads,
- inner_dim,
- inner_activation,
- output_range=None,
- kernel_initializer="glorot_uniform",
- bias_initializer="zeros",
- kernel_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- kernel_constraint=None,
- bias_constraint=None,
- use_bias=True,
- norm_first=False,
- norm_epsilon=1e-12,
- use_rms_norm=False,
- output_dropout=0.0,
- attention_dropout=0.0,
- inner_dropout=0.0,
- attention_initializer=None,
- attention_axes=None,
- use_query_residual=True,
- key_dim=None,
- value_dim=None,
- output_last_dim=None,
- diff_q_kv_att_layer_norm=False,
- return_attention_scores=False,
- num_kv_heads=None,
- src_block_size=None,
- tgt_block_size=None,
- use_sigmoid_attn=False,
- sigmoid_attn_bias=None,
- linformer_dim=None,
- linformer_shared_kv_projection=True,
- **kwargs):
+ def __init__(
+ self,
+ num_attention_heads,
+ inner_dim,
+ inner_activation,
+ output_range=None,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ bias_constraint=None,
+ use_bias=True,
+ norm_first=False,
+ norm_epsilon=1e-12,
+ use_rms_norm=False,
+ output_dropout=0.0,
+ attention_dropout=0.0,
+ inner_dropout=0.0,
+ attention_initializer=None,
+ attention_axes=None,
+ use_query_residual=True,
+ key_dim=None,
+ value_dim=None,
+ output_last_dim=None,
+ diff_q_kv_att_layer_norm=False,
+ return_attention_scores=False,
+ num_kv_heads=None,
+ src_block_size=None,
+ tgt_block_size=None,
+ use_sigmoid_attn=False,
+ sigmoid_attn_bias=None,
+ linformer_dim=None,
+ linformer_shared_kv_projection=True,
+ lowrank_query_seq_proj_dim=None,
+ **kwargs,
+ ):
  """Initializes `TransformerEncoderBlock`.
 
  Note: If `output_last_dim` is used and `use_query_residual` is `True`, the
@@ -197,6 +200,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  https://arxiv.org/pdf/2006.04768.
  linformer_shared_kv_projection: If set, projection layer is shared for
  keys and values.
+ lowrank_query_seq_proj_dim: If set, applies a projection layer on query
+ sequence to the given dimension. go/constformer-doc
  **kwargs: keyword arguments.
  """
  util.filter_kwargs(kwargs)
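For orientation (not part of the diff): a minimal usage sketch of the new argument, mirroring the test added later in this release. The import path is assumed from the file path shown in this diff, and the sizes are illustrative.

```python
from official.nlp.modeling.layers.transformer_encoder_block import TransformerEncoderBlock
import tf_keras

# Illustrative sizes: a 21-token sequence whose query dimension is squashed
# to a 10-step low-rank ("constformer") representation.
block = TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation='relu',
    lowrank_query_seq_proj_dim=10,
)
inputs = tf_keras.Input(shape=(21, 80))
outputs = block(inputs)  # shape (None, 10, 80), per the new tests below
```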
@@ -238,6 +243,7 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  self._sigmoid_attn_bias = sigmoid_attn_bias
  self._linformer_dim = linformer_dim
  self._linformer_shared_kv_projection = linformer_shared_kv_projection
+ self._lowrank_query_seq_proj_dim = lowrank_query_seq_proj_dim
  if (
  self._src_block_size is not None
  and self._num_kv_heads is not None
@@ -410,6 +416,21 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  name="lowrank_kv_projection",
  **common_kwargs,
  )
+ if self._lowrank_query_seq_proj_dim is not None:
+ self._lowrank_query_seq_projection = tf_keras.layers.EinsumDense(
+ # Squash the sequence-length dimension; keep embedding as is.
+ "...ij,ik->...kj",
+ output_shape=(
+ self._lowrank_query_seq_proj_dim,
+ hidden_size,
+ ),
+ kernel_initializer=tf_utils.clone_initializer(
+ self._kernel_initializer
+ ),
+ bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+ name="constformer_projection",
+ **common_kwargs,
+ )
  super().build(input_shape)
 
  def get_config(self):
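For orientation (not part of the diff): a shape sketch of the einsum used by the new projection layer, replayed with a plain tf.einsum and dummy sizes. It shows how `"...ij,ik->...kj"` contracts the sequence axis against a `[seq, proj_dim]` kernel while leaving the embedding axis untouched.

```python
import tensorflow as tf

batch, seq_len, hidden, proj_dim = 2, 21, 80, 10
query = tf.random.normal([batch, seq_len, hidden])  # "...ij": i = seq, j = hidden
kernel = tf.random.normal([seq_len, proj_dim])      # "ik":    i = seq, k = proj_dim

projected = tf.einsum("...ij,ik->...kj", query, kernel)
print(projected.shape)  # (2, 10, 80): sequence squashed to proj_dim, embedding kept
```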
@@ -461,10 +482,66 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  "sigmoid_attn_bias": self._sigmoid_attn_bias,
  "linformer_dim": self._linformer_dim,
  "linformer_shared_kv_projection": self._linformer_shared_kv_projection,
+ "lowrank_query_seq_proj_dim": self._lowrank_query_seq_proj_dim,
  }
  base_config = super().get_config()
  return dict(list(base_config.items()) + list(config.items()))
 
+ def _apply_lowrank_query_projection(
+ self,
+ query: tf.Tensor,
+ attention_mask: tf.Tensor | None,
+ ):
+ """Applies constformer projection to the source tensor."""
+
+ # Don't project the source tensor if the `lowrank_query_seq_projection`
+ # (constformer) dimension is the same as the input
+ # sequence dimension.
+ if (
+ self._lowrank_query_seq_proj_dim is None
+ or query.shape[1] == self._lowrank_query_seq_proj_dim
+ ):
+ return query
+ # Don't overwrite the attention mask.
+ query = self._apply_query_mask(attention_mask, query)
+ dtype = query.dtype
+ query = self._lowrank_query_seq_projection(query)
+ query = tf.cast(query, dtype)
+ return query
+
+ def _apply_query_mask(
+ self,
+ attention_mask: tf.Tensor | None,
+ query: tf.Tensor,
+ ):
+ """Applying mask before the low rank factorization so that padding is accounted for.
+
+ Applies mask to query only if the dimension of query matches the mask. This
+ is to avoid the projection from happening multiple times while stacking
+ the transformer layers.
+
+ Args:
+ attention_mask: The attention_mask tensor.
+ query: The query tensor.
+
+ Returns:
+ query: The query tensor after applying the mask.
+ """
+ if attention_mask is None:
+ return query
+ if attention_mask.shape[1] != query.shape[1]:
+ # Skip the mask application for query.
+ logging.info(
+ "Skipping mask application on query. Shape mismatch: %s vs %s",
+ attention_mask.shape,
+ query.shape,
+ )
+ return query
+
+ query_mask = tf.cast(attention_mask[:, :, 0], dtype=query.dtype)
+ query = query * tf.expand_dims(query_mask, axis=-1)
+ return query
+
  def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
  """Transformer self-attention encoder block call.
 
@@ -499,6 +576,12 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  if output_range:
  if self._norm_first:
  source_tensor = input_tensor[:, 0:output_range, :]
+ if self._use_query_residual:
+ # `source_tensor` is only used for the residual connection.
+ source_tensor = self._apply_lowrank_query_projection(
+ source_tensor, attention_mask
+ )
+
  input_tensor = self._attention_layer_norm(input_tensor)
  if key_value is not None:
  key_value = self._attention_layer_norm_kv(key_value)
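For orientation (not part of the diff): a shape sketch, with assumed sizes, of why the residual branch is also projected here when `use_query_residual` is set. Once the query sequence is squashed to the constformer dimension, the attention output no longer matches the unprojected `source_tensor`, so the residual add would fail unless the residual branch is projected too.

```python
import tensorflow as tf

batch, seq_len, proj_dim, hidden = 2, 21, 10, 80
source_tensor = tf.zeros([batch, seq_len, hidden])   # residual branch, unprojected
attention_out = tf.zeros([batch, proj_dim, hidden])  # query squashed to proj_dim

# [2, 21, 80] + [2, 10, 80] would be a shape mismatch; projecting the residual
# branch to [batch, proj_dim, hidden] keeps the residual connection well-formed.
projected_source = tf.zeros([batch, proj_dim, hidden])
assert (projected_source + attention_out).shape == attention_out.shape
```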
@@ -508,11 +591,21 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  else:
  if self._norm_first:
  source_tensor = input_tensor
+ if self._use_query_residual:
+ # `source_tensor` is only used for the residual connection.
+ source_tensor = self._apply_lowrank_query_projection(
+ source_tensor, attention_mask
+ )
  input_tensor = self._attention_layer_norm(input_tensor)
  if key_value is not None:
  key_value = self._attention_layer_norm_kv(key_value)
  target_tensor = input_tensor
 
+ # Project the query to the constformer dimension.
+ target_tensor = self._apply_lowrank_query_projection(
+ target_tensor, attention_mask
+ )
+
  if key_value is None:
  key_value = input_tensor
 
@@ -523,7 +616,8 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  # Applying mask before the low rank factorization so that padding is
  # accounted for.
  query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
- target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+ if self._lowrank_query_seq_proj_dim is None:
+ target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
  key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
  key_value = key_value * tf.expand_dims(key_mask, axis=-1)
  attention_mask = None
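For orientation (not part of the diff): a minimal sketch, with assumed shapes, of the masking convention used here and in `_apply_query_mask`. The first key column of the `[batch, q_len, kv_len]` attention mask serves as a per-token query mask and is broadcast over the embedding axis, zeroing padded positions before the low-rank factorization.

```python
import tensorflow as tf

batch, q_len, kv_len, hidden = 2, 21, 21, 80
attention_mask = tf.ones([batch, q_len, kv_len])  # 1 = attend, 0 = padded
query = tf.random.normal([batch, q_len, hidden])

query_mask = tf.cast(attention_mask[:, :, 0], query.dtype)   # [batch, q_len]
masked_query = query * tf.expand_dims(query_mask, axis=-1)   # [batch, q_len, hidden]
```

When `lowrank_query_seq_proj_dim` is set, `_apply_lowrank_query_projection` has already applied this mask to the query, which is why the hunk above now skips the query-mask multiply in that case.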
@@ -534,8 +628,9 @@ class TransformerEncoderBlock(tf_keras.layers.Layer):
  key = key_value
  value = key_value
  else:
- key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
- value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
+ key = tf.transpose(key_value[:, :, : self._linformer_dim], [0, 2, 1])
+ value = tf.transpose(key_value[:, :, self._linformer_dim :], [0, 2, 1])
+
  if self._return_attention_scores:
  attention_output, attention_scores = self._attention_layer(
  query=target_tensor,
@@ -844,6 +844,90 @@ class TransformerArgumentTest(tf.test.TestCase, parameterized.TestCase):
  output_tensor[1].shape.as_list(), expected_attention_scores_shape
  )
 
+ def test_low_rank_attention_with_constformer(self):
+ num_attention_heads = 8
+ sequence_length = 21
+ linformer_dim = 7
+ lowrank_query_seq_proj_dim = 10
+ width = 80
+ shared_kv_projection = False
+
+ test_layer = TransformerEncoderBlock(
+ num_attention_heads=num_attention_heads,
+ inner_dim=2048,
+ inner_activation='relu',
+ return_attention_scores=True,
+ linformer_dim=linformer_dim,
+ linformer_shared_kv_projection=shared_kv_projection,
+ lowrank_query_seq_proj_dim=lowrank_query_seq_proj_dim,
+ )
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf_keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ # The output from constformer has bottlenecked sequence length.
+ expected_layer_output_shape = [None, lowrank_query_seq_proj_dim, width]
+ # Note that attention scores with Constformer don't have the same
+ # interpretation as the original attention scores, since the sequence
+ # length is squashed.
+ expected_attention_scores_shape = [
+ None,
+ num_attention_heads,
+ lowrank_query_seq_proj_dim,
+ linformer_dim,
+ ]
+
+ self.assertIsInstance(output_tensor, tuple)
+ self.assertLen(output_tensor, 2)
+ # First is the standard output.
+ self.assertEqual(
+ output_tensor[0].shape.as_list(), expected_layer_output_shape
+ )
+ # Second is the attention scores.
+ self.assertEqual(
+ output_tensor[1].shape.as_list(), expected_attention_scores_shape
+ )
+
+ def test_low_rank_attention_with_constformer_no_linformer(self):
+ num_attention_heads = 8
+ sequence_length = 21
+ lowrank_query_seq_proj_dim = 10
+ width = 80
+
+ test_layer = TransformerEncoderBlock(
+ num_attention_heads=num_attention_heads,
+ inner_dim=2048,
+ inner_activation='relu',
+ return_attention_scores=True,
+ lowrank_query_seq_proj_dim=lowrank_query_seq_proj_dim,
+ )
+ # Create a 3-dimensional input (the first dimension is implicit).
+ data_tensor = tf_keras.Input(shape=(sequence_length, width))
+ output_tensor = test_layer(data_tensor)
+
+ # The output from constformer has bottlenecked sequence length.
+ expected_layer_output_shape = [None, lowrank_query_seq_proj_dim, width]
+ # Note that attention scores with Constformer don't have the same
+ # interpretation as the original attention scores, since the sequence
+ # length is squashed.
+ expected_attention_scores_shape = [
+ None,
+ num_attention_heads,
+ lowrank_query_seq_proj_dim,
+ sequence_length,
+ ]
+
+ self.assertIsInstance(output_tensor, tuple)
+ self.assertLen(output_tensor, 2)
+ # First is the standard output.
+ self.assertEqual(
+ output_tensor[0].shape.as_list(), expected_layer_output_shape
+ )
+ # Second is the attention scores.
+ self.assertEqual(
+ output_tensor[1].shape.as_list(), expected_attention_scores_shape
+ )
+
 
  if __name__ == '__main__':
  tf.test.main()
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tf-models-nightly
- Version: 2.20.0.dev20250721
+ Version: 2.20.0.dev20250722
  Summary: TensorFlow Official Models
  Home-page: https://github.com/tensorflow/models
  Author: Google Inc.
@@ -363,8 +363,8 @@ official/nlp/modeling/layers/tn_expand_condense_test.py,sha256=QWq1dJqQUPe5n69K3
  official/nlp/modeling/layers/tn_transformer_expand_condense.py,sha256=omzTkCBEk2TOkHEYDEBwve6WsOitX7IIJHzeKXdqDq0,11012
  official/nlp/modeling/layers/tn_transformer_test.py,sha256=pSCONEZRI4J9_6QLTJ3g_ynUYLrRXsJ1c2YMSiOV_60,8893
  official/nlp/modeling/layers/transformer.py,sha256=VjUO-gVj_PnavbT_vSrg5NDKMr0SRSiqSg5ktd42m5M,20087
- official/nlp/modeling/layers/transformer_encoder_block.py,sha256=fsYdA40A5kh8KvrEUyBLmv8UkDkV3eLdQb9mleocoM0,24930
- official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=PA_XJ4epjJLcgXjgVRoWnOvqszn_a2RKfFmbo5ow724,32406
+ official/nlp/modeling/layers/transformer_encoder_block.py,sha256=E-WeoxsjByL-lkPAyEjDbvt1_3ghcIpCXnoLSeCDKFQ,27953
+ official/nlp/modeling/layers/transformer_encoder_block_test.py,sha256=Y1Byz7RgF6puMIU3WRyoWghp5a3rYemgTRFRDohfy2Q,35402
  official/nlp/modeling/layers/transformer_scaffold.py,sha256=qmzhCJvbbFVF9zDqnfO4Zs2JDXwKhK7iEBOhsU6-KpQ,15704
  official/nlp/modeling/layers/transformer_scaffold_test.py,sha256=dRJwesTBKm-mF5mDHrHfVpVNnxa-Wx-fj_4ZHDPTpE0,19920
  official/nlp/modeling/layers/transformer_test.py,sha256=-pk9cdz9UlMpCIkGRkCKsMmjdRGi0seySaaB_2dwmXw,5522
@@ -1248,9 +1248,9 @@ tensorflow_models/tensorflow_models_test.py,sha256=yiAneltAW3NHSj3fUSvHNBjfq0MGZ
  tensorflow_models/nlp/__init__.py,sha256=8uQd4wI6Zc4IJMPjtQifMeWVbPFkTxqYh66wfivCOL4,807
  tensorflow_models/uplift/__init__.py,sha256=NzaweFf4ZmhRb2l_fuV6bP-2N8oSO3xu6xJqVb1UmpY,999
  tensorflow_models/vision/__init__.py,sha256=ks420Ooqzi0hU7HnQpM5rylLaE-YcJdJkBx_umVaXlE,833
- tf_models_nightly-2.20.0.dev20250721.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
- tf_models_nightly-2.20.0.dev20250721.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
- tf_models_nightly-2.20.0.dev20250721.dist-info/METADATA,sha256=xZHsVrpfd06nYob1JYmhxn2mNk9GqGntpwFHCUJIjqs,1432
- tf_models_nightly-2.20.0.dev20250721.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
- tf_models_nightly-2.20.0.dev20250721.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
- tf_models_nightly-2.20.0.dev20250721.dist-info/RECORD,,
+ tf_models_nightly-2.20.0.dev20250722.dist-info/AUTHORS,sha256=1dG3fXVu9jlo7bul8xuix5F5vOnczMk7_yWn4y70uw0,337
+ tf_models_nightly-2.20.0.dev20250722.dist-info/LICENSE,sha256=WxeBS_DejPZQabxtfMOM_xn8qoZNJDQjrT7z2wG1I4U,11512
+ tf_models_nightly-2.20.0.dev20250722.dist-info/METADATA,sha256=g7qc5gIL8nhGzTxMTsPABwW7C6sJwZBvjKjgDtWNxok,1432
+ tf_models_nightly-2.20.0.dev20250722.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+ tf_models_nightly-2.20.0.dev20250722.dist-info/top_level.txt,sha256=gum2FfO5R4cvjl2-QtP-S1aNmsvIZaFFT6VFzU0f4-g,33
+ tf_models_nightly-2.20.0.dev20250722.dist-info/RECORD,,