tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250109__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tf-models-nightly might be problematic.

Files changed (32)
  1. official/projects/detr/__init__.py +14 -0
  2. official/projects/detr/configs/__init__.py +14 -0
  3. official/projects/detr/configs/detr.py +277 -0
  4. official/projects/detr/configs/detr_test.py +51 -0
  5. official/projects/detr/dataloaders/__init__.py +14 -0
  6. official/projects/detr/dataloaders/coco.py +157 -0
  7. official/projects/detr/dataloaders/coco_test.py +111 -0
  8. official/projects/detr/dataloaders/detr_input.py +175 -0
  9. official/projects/detr/experiments/__init__.py +14 -0
  10. official/projects/detr/modeling/__init__.py +14 -0
  11. official/projects/detr/modeling/detr.py +345 -0
  12. official/projects/detr/modeling/detr_test.py +70 -0
  13. official/projects/detr/modeling/transformer.py +849 -0
  14. official/projects/detr/modeling/transformer_test.py +263 -0
  15. official/projects/detr/ops/__init__.py +14 -0
  16. official/projects/detr/ops/matchers.py +489 -0
  17. official/projects/detr/ops/matchers_test.py +95 -0
  18. official/projects/detr/optimization.py +151 -0
  19. official/projects/detr/serving/__init__.py +14 -0
  20. official/projects/detr/serving/export_module.py +103 -0
  21. official/projects/detr/serving/export_module_test.py +98 -0
  22. official/projects/detr/serving/export_saved_model.py +109 -0
  23. official/projects/detr/tasks/__init__.py +14 -0
  24. official/projects/detr/tasks/detection.py +421 -0
  25. official/projects/detr/tasks/detection_test.py +203 -0
  26. official/projects/detr/train.py +70 -0
  27. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/METADATA +1 -1
  28. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/RECORD +32 -6
  29. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/AUTHORS +0 -0
  30. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/LICENSE +0 -0
  31. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/WHEEL +0 -0
  32. {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,849 @@
1
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Specialized Transformers for DETR.
16
+
17
+ The position embeddings are added to the query and key for every self- and
18
+ cross-attention layer.
19
+ """
20
+
21
+ import tensorflow as tf, tf_keras
22
+
23
+ from official.modeling import tf_utils
24
+ from official.nlp.modeling import layers
25
+ from official.nlp.modeling import models
26
+
27
+
28
+ class TransformerEncoder(tf_keras.layers.Layer):
29
+ """Transformer encoder.
30
+
31
+ Transformer encoder is made up of N identical layers. Each layer is composed
32
+ of the sublayers:
33
+ 1. Self-attention layer
34
+ 2. Feedforward network (which is 2 fully-connected layers)
35
+ """
36
+
37
+ def __init__(self,
38
+ num_layers=6,
39
+ num_attention_heads=8,
40
+ intermediate_size=2048,
41
+ activation="relu",
42
+ dropout_rate=0.0,
43
+ attention_dropout_rate=0.0,
44
+ use_bias=False,
45
+ norm_first=True,
46
+ norm_epsilon=1e-6,
47
+ intermediate_dropout=0.0,
48
+ **kwargs):
49
+ """Initialize a Transformer encoder.
50
+
51
+ Args:
52
+ num_layers: Number of layers.
53
+ num_attention_heads: Number of attention heads.
54
+ intermediate_size: Size of the intermediate (Feedforward) layer.
55
+ activation: Activation for the intermediate layer.
56
+ dropout_rate: Dropout probability.
57
+ attention_dropout_rate: Dropout probability for attention layers.
58
+ use_bias: Whether to enable use_bias in attention layer. If set False,
59
+ use_bias in attention layer is disabled.
60
+ norm_first: Whether to normalize inputs to attention and intermediate
61
+ dense layers. If set False, output of attention and intermediate dense
62
+ layers is normalized.
63
+ norm_epsilon: Epsilon value to initialize normalization layers.
64
+ intermediate_dropout: Dropout probability for intermediate_dropout_layer.
65
+ **kwargs: keyword arguments passed to tf_keras.layers.Layer.
66
+ """
67
+
68
+ super(TransformerEncoder, self).__init__(**kwargs)
69
+ self.num_layers = num_layers
70
+ self.num_attention_heads = num_attention_heads
71
+ self._intermediate_size = intermediate_size
72
+ self._activation = activation
73
+ self._dropout_rate = dropout_rate
74
+ self._attention_dropout_rate = attention_dropout_rate
75
+ self._use_bias = use_bias
76
+ self._norm_first = norm_first
77
+ self._norm_epsilon = norm_epsilon
78
+ self._intermediate_dropout = intermediate_dropout
79
+
80
+ def build(self, input_shape):
81
+ """Implements build() for the layer."""
82
+ self.encoder_layers = []
83
+ for i in range(self.num_layers):
84
+ self.encoder_layers.append(
85
+ TransformerEncoderBlock(
86
+ num_attention_heads=self.num_attention_heads,
87
+ inner_dim=self._intermediate_size,
88
+ inner_activation=self._activation,
89
+ output_dropout=self._dropout_rate,
90
+ attention_dropout=self._attention_dropout_rate,
91
+ use_bias=self._use_bias,
92
+ norm_first=self._norm_first,
93
+ norm_epsilon=self._norm_epsilon,
94
+ inner_dropout=self._intermediate_dropout,
95
+ attention_initializer=tf_utils.clone_initializer(
96
+ models.seq2seq_transformer.attention_initializer(
97
+ input_shape[2])),
98
+ name=("layer_%d" % i)))
99
+ self.output_normalization = tf_keras.layers.LayerNormalization(
100
+ epsilon=self._norm_epsilon, dtype="float32")
101
+ super(TransformerEncoder, self).build(input_shape)
102
+
103
+ def get_config(self):
104
+ config = {
105
+ "num_layers": self.num_layers,
106
+ "num_attention_heads": self.num_attention_heads,
107
+ "intermediate_size": self._intermediate_size,
108
+ "activation": self._activation,
109
+ "dropout_rate": self._dropout_rate,
110
+ "attention_dropout_rate": self._attention_dropout_rate,
111
+ "use_bias": self._use_bias,
112
+ "norm_first": self._norm_first,
113
+ "norm_epsilon": self._norm_epsilon,
114
+ "intermediate_dropout": self._intermediate_dropout
115
+ }
116
+ base_config = super(TransformerEncoder, self).get_config()
117
+ return dict(list(base_config.items()) + list(config.items()))
118
+
119
+ def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
120
+ """Return the output of the encoder.
121
+
122
+ Args:
123
+ encoder_inputs: A tensor with shape `(batch_size, input_length,
124
+ hidden_size)`.
125
+ attention_mask: A mask for the encoder self-attention layer with shape
126
+ `(batch_size, input_length, input_length)`.
127
+ pos_embed: Position embedding to add to every encoder layer.
128
+
129
+ Returns:
130
+ Output of encoder which is a `float32` tensor with shape
131
+ `(batch_size, input_length, hidden_size)`.
132
+ """
133
+ for layer_idx in range(self.num_layers):
134
+ encoder_inputs = self.encoder_layers[layer_idx](
135
+ [encoder_inputs, attention_mask, pos_embed])
136
+
137
+ output_tensor = encoder_inputs
138
+ output_tensor = self.output_normalization(output_tensor)
139
+
140
+ return output_tensor
141
+
142
+
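
For orientation, a minimal usage sketch of TransformerEncoder (editor's illustration, not part of the diff; the batch/sequence sizes, random tensors, and all-ones attention mask are hypothetical):

    import tensorflow as tf
    from official.projects.detr.modeling.transformer import TransformerEncoder

    batch, length, hidden = 2, 49, 256     # toy sizes, illustration only
    encoder = TransformerEncoder(
        num_layers=2, num_attention_heads=8, intermediate_size=512)
    features = tf.random.uniform((batch, length, hidden))
    # In DETR the position embedding comes from a fixed 2D sine encoding of the
    # feature map; any (batch, length, hidden) tensor works for this sketch.
    pos_embed = tf.random.uniform((batch, length, hidden))
    attention_mask = tf.ones((batch, length, length))   # no padding in this toy batch
    outputs = encoder(features, attention_mask=attention_mask, pos_embed=pos_embed)
    # outputs.shape == (2, 49, 256); pos_embed is re-added to the attention
    # query and key inside every encoder layer.
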
143
+ class TransformerEncoderBlock(tf_keras.layers.Layer):
144
+ """TransformerEncoderBlock layer.
145
+
146
+ This layer implements the Transformer Encoder from
147
+ "Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
148
+ which combines a `tf_keras.layers.MultiHeadAttention` layer with a
149
+ two-layer feedforward network. The only difference: position embedding is
150
+ added to the query and key of self-attention.
151
+
152
+ References:
153
+ [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
154
+ [BERT: Pre-training of Deep Bidirectional Transformers for Language
155
+ Understanding](https://arxiv.org/abs/1810.04805)
156
+ """
157
+
158
+ def __init__(self,
159
+ num_attention_heads,
160
+ inner_dim,
161
+ inner_activation,
162
+ output_range=None,
163
+ kernel_initializer="glorot_uniform",
164
+ bias_initializer="zeros",
165
+ kernel_regularizer=None,
166
+ bias_regularizer=None,
167
+ activity_regularizer=None,
168
+ kernel_constraint=None,
169
+ bias_constraint=None,
170
+ use_bias=True,
171
+ norm_first=False,
172
+ norm_epsilon=1e-12,
173
+ output_dropout=0.0,
174
+ attention_dropout=0.0,
175
+ inner_dropout=0.0,
176
+ attention_initializer=None,
177
+ attention_axes=None,
178
+ **kwargs):
179
+ """Initializes `TransformerEncoderBlock`.
180
+
181
+ Args:
182
+ num_attention_heads: Number of attention heads.
183
+ inner_dim: The output dimension of the first Dense layer in a two-layer
184
+ feedforward network.
185
+ inner_activation: The activation for the first Dense layer in a two-layer
186
+ feedforward network.
187
+ output_range: the sequence output range, [0, output_range) for slicing the
188
+ target sequence. `None` means the target sequence is not sliced.
189
+ kernel_initializer: Initializer for dense layer kernels.
190
+ bias_initializer: Initializer for dense layer biases.
191
+ kernel_regularizer: Regularizer for dense layer kernels.
192
+ bias_regularizer: Regularizer for dense layer biases.
193
+ activity_regularizer: Regularizer for dense layer activity.
194
+ kernel_constraint: Constraint for dense layer kernels.
195
+ bias_constraint: Constraint for dense layer biases.
196
+ use_bias: Whether to enable use_bias in attention layer. If set False,
197
+ use_bias in attention layer is disabled.
198
+ norm_first: Whether to normalize inputs to attention and intermediate
199
+ dense layers. If set False, output of attention and intermediate dense
200
+ layers is normalized.
201
+ norm_epsilon: Epsilon value to initialize normalization layers.
202
+ output_dropout: Dropout probability for the post-attention and output
203
+ dropout.
204
+ attention_dropout: Dropout probability for within the attention layer.
205
+ inner_dropout: Dropout probability for the first Dense layer in a
206
+ two-layer feedforward network.
207
+ attention_initializer: Initializer for kernels of attention layers. If set
208
+ `None`, attention layers use kernel_initializer as initializer for
209
+ kernel.
210
+ attention_axes: axes over which the attention is applied. `None` means
211
+ attention over all axes except batch, heads, and features.
212
+ **kwargs: keyword arguments passed to tf_keras.layers.Layer.
213
+ """
214
+ super().__init__(**kwargs)
215
+
216
+ self._num_heads = num_attention_heads
217
+ self._inner_dim = inner_dim
218
+ self._inner_activation = inner_activation
219
+ self._attention_dropout = attention_dropout
220
+ self._attention_dropout_rate = attention_dropout
221
+ self._output_dropout = output_dropout
222
+ self._output_dropout_rate = output_dropout
223
+ self._output_range = output_range
224
+ self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
225
+ self._bias_initializer = tf_keras.initializers.get(bias_initializer)
226
+ self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
227
+ self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
228
+ self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
229
+ self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
230
+ self._bias_constraint = tf_keras.constraints.get(bias_constraint)
231
+ self._use_bias = use_bias
232
+ self._norm_first = norm_first
233
+ self._norm_epsilon = norm_epsilon
234
+ self._inner_dropout = inner_dropout
235
+ if attention_initializer:
236
+ self._attention_initializer = tf_keras.initializers.get(
237
+ attention_initializer)
238
+ else:
239
+ self._attention_initializer = tf_utils.clone_initializer(
240
+ self._kernel_initializer)
241
+ self._attention_axes = attention_axes
242
+
243
+ def build(self, input_shape):
244
+ if isinstance(input_shape, tf.TensorShape):
245
+ input_tensor_shape = input_shape
246
+ elif isinstance(input_shape, (list, tuple)):
247
+ input_tensor_shape = tf.TensorShape(input_shape[0])
248
+ else:
249
+ raise ValueError(
250
+ "The type of input shape argument is not supported, got: %s" %
251
+ type(input_shape))
252
+ einsum_equation = "abc,cd->abd"
253
+ if len(input_tensor_shape.as_list()) > 3:
254
+ einsum_equation = "...bc,cd->...bd"
255
+ hidden_size = input_tensor_shape[-1]
256
+ if hidden_size % self._num_heads != 0:
257
+ raise ValueError(
258
+ "The input size (%d) is not a multiple of the number of attention "
259
+ "heads (%d)" % (hidden_size, self._num_heads))
260
+ self._attention_head_size = int(hidden_size // self._num_heads)
261
+ common_kwargs = dict(
262
+ bias_initializer=self._bias_initializer,
263
+ kernel_regularizer=self._kernel_regularizer,
264
+ bias_regularizer=self._bias_regularizer,
265
+ activity_regularizer=self._activity_regularizer,
266
+ kernel_constraint=self._kernel_constraint,
267
+ bias_constraint=self._bias_constraint)
268
+ self._attention_layer = tf_keras.layers.MultiHeadAttention(
269
+ num_heads=self._num_heads,
270
+ key_dim=self._attention_head_size,
271
+ dropout=self._attention_dropout,
272
+ use_bias=self._use_bias,
273
+ kernel_initializer=self._attention_initializer,
274
+ attention_axes=self._attention_axes,
275
+ name="self_attention",
276
+ **common_kwargs)
277
+ self._attention_dropout = tf_keras.layers.Dropout(rate=self._output_dropout)
278
+ # Use float32 in layernorm for numeric stability.
279
+ # It is probably safe in mixed_float16, but we haven't validated this yet.
280
+ self._attention_layer_norm = (
281
+ tf_keras.layers.LayerNormalization(
282
+ name="self_attention_layer_norm",
283
+ axis=-1,
284
+ epsilon=self._norm_epsilon,
285
+ dtype=tf.float32))
286
+ self._intermediate_dense = tf_keras.layers.EinsumDense(
287
+ einsum_equation,
288
+ output_shape=(None, self._inner_dim),
289
+ bias_axes="d",
290
+ kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
291
+ name="intermediate",
292
+ **common_kwargs)
293
+ policy = tf_keras.mixed_precision.global_policy()
294
+ if policy.name == "mixed_bfloat16":
295
+ # bfloat16 causes BERT with the LAMB optimizer to not converge
296
+ # as well, so we use float32.
297
+ # TODO(b/154538392): Investigate this.
298
+ policy = tf.float32
299
+ self._intermediate_activation_layer = tf_keras.layers.Activation(
300
+ self._inner_activation, dtype=policy)
301
+ self._inner_dropout_layer = tf_keras.layers.Dropout(
302
+ rate=self._inner_dropout)
303
+ self._output_dense = tf_keras.layers.EinsumDense(
304
+ einsum_equation,
305
+ output_shape=(None, hidden_size),
306
+ bias_axes="d",
307
+ name="output",
308
+ kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
309
+ **common_kwargs)
310
+ self._output_dropout = tf_keras.layers.Dropout(rate=self._output_dropout)
311
+ # Use float32 in layernorm for numeric stability.
312
+ self._output_layer_norm = tf_keras.layers.LayerNormalization(
313
+ name="output_layer_norm",
314
+ axis=-1,
315
+ epsilon=self._norm_epsilon,
316
+ dtype=tf.float32)
317
+
318
+ super(TransformerEncoderBlock, self).build(input_shape)
319
+
320
+ def get_config(self):
321
+ config = {
322
+ "num_attention_heads": self._num_heads,
323
+ "inner_dim": self._inner_dim,
324
+ "inner_activation": self._inner_activation,
325
+ "output_dropout": self._output_dropout_rate,
326
+ "attention_dropout": self._attention_dropout_rate,
327
+ "output_range": self._output_range,
328
+ "kernel_initializer": tf_utils.serialize_initializer(
329
+ self._kernel_initializer, use_legacy_format=True
330
+ ),
331
+ "bias_initializer": tf_utils.serialize_initializer(
332
+ self._bias_initializer, use_legacy_format=True
333
+ ),
334
+ "kernel_regularizer": tf_utils.serialize_regularizer(
335
+ self._kernel_regularizer, use_legacy_format=True
336
+ ),
337
+ "bias_regularizer": tf_utils.serialize_regularizer(
338
+ self._bias_regularizer, use_legacy_format=True
339
+ ),
340
+ "activity_regularizer": tf_utils.serialize_regularizer(
341
+ self._activity_regularizer, use_legacy_format=True
342
+ ),
343
+ "kernel_constraint": tf_utils.serialize_constraint(
344
+ self._kernel_constraint, use_legacy_format=True
345
+ ),
346
+ "bias_constraint": tf_utils.serialize_constraint(
347
+ self._bias_constraint, use_legacy_format=True
348
+ ),
349
+ "use_bias": self._use_bias,
350
+ "norm_first": self._norm_first,
351
+ "norm_epsilon": self._norm_epsilon,
352
+ "inner_dropout": self._inner_dropout,
353
+ "attention_initializer": tf_utils.serialize_initializer(
354
+ self._attention_initializer, use_legacy_format=True
355
+ ),
356
+ "attention_axes": self._attention_axes,
357
+ }
358
+ base_config = super(TransformerEncoderBlock, self).get_config()
359
+ return dict(list(base_config.items()) + list(config.items()))
360
+
361
+ def call(self, inputs):
362
+ """Transformer self-attention encoder block call.
363
+
364
+ Args:
365
+ inputs: a single tensor or a list of tensors. `input tensor` as the single
366
+ sequence of embeddings. [`input tensor`, `attention mask`] to have the
367
+ additional attention mask. [`input tensor`, `attention mask`, `query
368
+ embed`] to have an additional position embedding to add.
369
+
370
+ Returns:
371
+ An output tensor with the same dimensions as input/query tensor.
372
+ """
373
+ input_tensor, attention_mask, pos_embed = inputs
374
+
375
+ key_value = None
376
+
377
+ if self._output_range:
378
+ if self._norm_first:
379
+ source_tensor = input_tensor[:, 0:self._output_range, :]
380
+ input_tensor = self._attention_layer_norm(input_tensor)
381
+ if key_value is not None:
382
+ key_value = self._attention_layer_norm(key_value)
383
+ target_tensor = input_tensor[:, 0:self._output_range, :]
384
+ if attention_mask is not None:
385
+ attention_mask = attention_mask[:, 0:self._output_range, :]
386
+ else:
387
+ if self._norm_first:
388
+ source_tensor = input_tensor
389
+ input_tensor = self._attention_layer_norm(input_tensor)
390
+ if key_value is not None:
391
+ key_value = self._attention_layer_norm(key_value)
392
+ target_tensor = input_tensor
393
+
394
+ if key_value is None:
395
+ key_value = input_tensor
396
+ attention_output = self._attention_layer(
397
+ query=target_tensor + pos_embed,
398
+ key=key_value + pos_embed,
399
+ value=key_value,
400
+ attention_mask=attention_mask)
401
+ attention_output = self._attention_dropout(attention_output)
402
+ if self._norm_first:
403
+ attention_output = source_tensor + attention_output
404
+ else:
405
+ attention_output = self._attention_layer_norm(target_tensor +
406
+ attention_output)
407
+ if self._norm_first:
408
+ source_attention_output = attention_output
409
+ attention_output = self._output_layer_norm(attention_output)
410
+ inner_output = self._intermediate_dense(attention_output)
411
+ inner_output = self._intermediate_activation_layer(inner_output)
412
+ inner_output = self._inner_dropout_layer(inner_output)
413
+ layer_output = self._output_dense(inner_output)
414
+ layer_output = self._output_dropout(layer_output)
415
+
416
+ if self._norm_first:
417
+ return source_attention_output + layer_output
418
+
419
+ # During mixed precision training, layer norm output is always fp32 for now.
420
+ # Casts fp32 for the subsequent add.
421
+ layer_output = tf.cast(layer_output, tf.float32)
422
+ return self._output_layer_norm(layer_output + attention_output)
423
+
424
+
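
A single TransformerEncoderBlock can also be driven directly; it expects a three-element list [input tensor, attention mask, position embedding]. A short sketch reusing the hypothetical features, pos_embed, batch and length from the encoder sketch above, with a padding-style mask built from an assumed validity indicator:

    from official.projects.detr.modeling.transformer import TransformerEncoderBlock

    # valid[b, i] == 1.0 where token i of example b is real, 0.0 where it is padding.
    valid = tf.ones((batch, length))
    # (batch, length, length) mask: every query may attend only to valid keys.
    attention_mask = tf.tile(valid[:, tf.newaxis, :], [1, length, 1])
    block = TransformerEncoderBlock(
        num_attention_heads=8, inner_dim=512, inner_activation="relu",
        norm_first=True)
    block_out = block([features, attention_mask, pos_embed])  # same shape as features
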
425
+ class TransformerDecoder(tf_keras.layers.Layer):
426
+ """Transformer decoder.
427
+
428
+ Like the encoder, the decoder is made up of N identical layers.
429
+ Each layer is composed of the sublayers:
430
+ 1. Self-attention layer
431
+ 2. Multi-headed attention layer combining encoder outputs with results from
432
+ the previous self-attention layer.
433
+ 3. Feedforward network (2 fully-connected layers)
434
+ """
435
+
436
+ def __init__(self,
437
+ num_layers=6,
438
+ num_attention_heads=8,
439
+ intermediate_size=2048,
440
+ activation="relu",
441
+ dropout_rate=0.0,
442
+ attention_dropout_rate=0.0,
443
+ use_bias=False,
444
+ norm_first=True,
445
+ norm_epsilon=1e-6,
446
+ intermediate_dropout=0.0,
447
+ **kwargs):
448
+ """Initialize a Transformer decoder.
449
+
450
+ Args:
451
+ num_layers: Number of layers.
452
+ num_attention_heads: Number of attention heads.
453
+ intermediate_size: Size of the intermediate (Feedforward) layer.
454
+ activation: Activation for the intermediate layer.
455
+ dropout_rate: Dropout probability.
456
+ attention_dropout_rate: Dropout probability for attention layers.
457
+ use_bias: Whether to enable use_bias in attention layer. If set `False`,
458
+ use_bias in attention layer is disabled.
459
+ norm_first: Whether to normalize inputs to attention and intermediate
460
+ dense layers. If set `False`, output of attention and intermediate dense
461
+ layers is normalized.
462
+ norm_epsilon: Epsilon value to initialize normalization layers.
463
+ intermediate_dropout: Dropout probability for intermediate_dropout_layer.
464
+ **kwargs: keyword arguments passed to tf_keras.layers.Layer.
465
+ """
466
+ super(TransformerDecoder, self).__init__(**kwargs)
467
+ self.num_layers = num_layers
468
+ self.num_attention_heads = num_attention_heads
469
+ self._intermediate_size = intermediate_size
470
+ self._activation = activation
471
+ self._dropout_rate = dropout_rate
472
+ self._attention_dropout_rate = attention_dropout_rate
473
+ self._use_bias = use_bias
474
+ self._norm_first = norm_first
475
+ self._norm_epsilon = norm_epsilon
476
+ self._intermediate_dropout = intermediate_dropout
477
+
478
+ def build(self, input_shape):
479
+ """Implements build() for the layer."""
480
+ self.decoder_layers = []
481
+ for i in range(self.num_layers):
482
+ self.decoder_layers.append(
483
+ TransformerDecoderBlock(
484
+ num_attention_heads=self.num_attention_heads,
485
+ intermediate_size=self._intermediate_size,
486
+ intermediate_activation=self._activation,
487
+ dropout_rate=self._dropout_rate,
488
+ attention_dropout_rate=self._attention_dropout_rate,
489
+ use_bias=self._use_bias,
490
+ norm_first=self._norm_first,
491
+ norm_epsilon=self._norm_epsilon,
492
+ intermediate_dropout=self._intermediate_dropout,
493
+ attention_initializer=tf_utils.clone_initializer(
494
+ models.seq2seq_transformer.attention_initializer(
495
+ input_shape[2])),
496
+ name=("layer_%d" % i)))
497
+ self.output_normalization = tf_keras.layers.LayerNormalization(
498
+ epsilon=self._norm_epsilon, dtype="float32")
499
+ super(TransformerDecoder, self).build(input_shape)
500
+
501
+ def get_config(self):
502
+ config = {
503
+ "num_layers": self.num_layers,
504
+ "num_attention_heads": self.num_attention_heads,
505
+ "intermediate_size": self._intermediate_size,
506
+ "activation": self._activation,
507
+ "dropout_rate": self._dropout_rate,
508
+ "attention_dropout_rate": self._attention_dropout_rate,
509
+ "use_bias": self._use_bias,
510
+ "norm_first": self._norm_first,
511
+ "norm_epsilon": self._norm_epsilon,
512
+ "intermediate_dropout": self._intermediate_dropout
513
+ }
514
+ base_config = super(TransformerDecoder, self).get_config()
515
+ return dict(list(base_config.items()) + list(config.items()))
516
+
517
+ def call(self,
518
+ target,
519
+ memory,
520
+ self_attention_mask=None,
521
+ cross_attention_mask=None,
522
+ cache=None,
523
+ decode_loop_step=None,
524
+ return_all_decoder_outputs=False,
525
+ input_pos_embed=None,
526
+ memory_pos_embed=None):
527
+ """Return the output of the decoder layer stacks.
528
+
529
+ Args:
530
+ target: A tensor with shape `(batch_size, target_length, hidden_size)`.
531
+ memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
532
+ self_attention_mask: A tensor with shape `(batch_size, target_length,
533
+ target_length)`, the mask for decoder self-attention layer.
534
+ cross_attention_mask: A tensor with shape `(batch_size, target_length,
535
+ input_length)` which is the mask for encoder-decoder attention layer.
536
+ cache: (Used for fast decoding) A nested dictionary storing previous
537
+ decoder self-attention values. The items are:
538
+ {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
539
+ "v": A tensor with shape `(batch_size, i, value_channels)`},
540
+ ...}
541
+ decode_loop_step: An integer, the step number of the decoding loop. Used
542
+ only for autoregressive inference on TPU.
543
+ return_all_decoder_outputs: Return all decoder layer outputs. Note that
544
+ the outputs are layer normed. This is useful when introducing per layer
545
+ auxiliary loss.
546
+ input_pos_embed: A tensor that is added to the query and key of the
547
+ self-attention layer.
548
+ memory_pos_embed: A tensor that is added to the query and key of the
549
+ cross-attention layer.
550
+
551
+ Returns:
552
+ Output of decoder.
553
+ float32 tensor with shape `(batch_size, target_length, hidden_size)`.
554
+ """
555
+
556
+ output_tensor = target
557
+ decoder_outputs = []
558
+ for layer_idx in range(self.num_layers):
559
+ transformer_inputs = [
560
+ output_tensor, memory, cross_attention_mask, self_attention_mask,
561
+ input_pos_embed, memory_pos_embed
562
+ ]
563
+ # Gets the cache for decoding.
564
+ if cache is None:
565
+ output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
566
+ else:
567
+ cache_layer_idx = str(layer_idx)
568
+ output_tensor, cache[cache_layer_idx] = self.decoder_layers[layer_idx](
569
+ transformer_inputs,
570
+ cache=cache[cache_layer_idx],
571
+ decode_loop_step=decode_loop_step)
572
+ if return_all_decoder_outputs:
573
+ decoder_outputs.append(self.output_normalization(output_tensor))
574
+
575
+ if return_all_decoder_outputs:
576
+ return decoder_outputs
577
+ else:
578
+ return self.output_normalization(output_tensor)
579
+
580
+
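
A hedged sketch of how TransformerDecoder might be driven in a DETR-style setup (editor's illustration; outputs, pos_embed, batch, length and hidden are the hypothetical tensors from the encoder sketch, and the learned object-query embedding is a stand-in):

    from official.projects.detr.modeling.transformer import TransformerDecoder

    num_queries = 100
    decoder = TransformerDecoder(
        num_layers=2, num_attention_heads=8, intermediate_size=512)
    # DETR decodes a fixed set of object queries: the target content starts at zero
    # and the learned query embedding enters through input_pos_embed at every layer.
    target = tf.zeros((batch, num_queries, hidden))
    query_embed = tf.random.uniform((batch, num_queries, hidden))
    self_attention_mask = tf.ones((batch, num_queries, num_queries))
    cross_attention_mask = tf.ones((batch, num_queries, length))
    decoded = decoder(
        target,
        memory=outputs,                    # encoder output
        self_attention_mask=self_attention_mask,
        cross_attention_mask=cross_attention_mask,
        return_all_decoder_outputs=True,   # per-layer outputs for auxiliary losses
        input_pos_embed=query_embed,
        memory_pos_embed=pos_embed)
    # decoded is a list of num_layers tensors, each (batch, num_queries, hidden).
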
581
+ class TransformerDecoderBlock(tf_keras.layers.Layer):
582
+ """Single transformer layer for decoder.
583
+
584
+ It has three sub-layers:
585
+ (1) a multi-head self-attention mechanism.
586
+ (2) an encoder-decoder attention.
587
+ (3) a positionwise fully connected feed-forward network.
588
+ """
589
+
590
+ def __init__(self,
591
+ num_attention_heads,
592
+ intermediate_size,
593
+ intermediate_activation,
594
+ dropout_rate=0.0,
595
+ attention_dropout_rate=0.0,
596
+ kernel_initializer="glorot_uniform",
597
+ bias_initializer="zeros",
598
+ kernel_regularizer=None,
599
+ bias_regularizer=None,
600
+ activity_regularizer=None,
601
+ kernel_constraint=None,
602
+ bias_constraint=None,
603
+ use_bias=True,
604
+ norm_first=False,
605
+ norm_epsilon=1e-12,
606
+ intermediate_dropout=0.0,
607
+ attention_initializer=None,
608
+ **kwargs):
609
+ """Initialize a Transformer decoder block.
610
+
611
+ Args:
612
+ num_attention_heads: Number of attention heads.
613
+ intermediate_size: Size of the intermediate layer.
614
+ intermediate_activation: Activation for the intermediate layer.
615
+ dropout_rate: Dropout probability for the post-attention and output
616
+ dropout.
617
+ attention_dropout_rate: Dropout probability for within the attention
618
+ layer.
619
+ kernel_initializer: Initializer for dense layer kernels.
620
+ bias_initializer: Initializer for dense layer biases.
621
+ kernel_regularizer: Regularizer for dense layer kernels.
622
+ bias_regularizer: Regularizer for dense layer biases.
623
+ activity_regularizer: Regularizer for dense layer activity.
624
+ kernel_constraint: Constraint for dense layer kernels.
625
+ bias_constraint: Constraint for dense layer biases.
626
+ use_bias: Whether to enable use_bias in attention layer. If set False,
627
+ use_bias in attention layer is disabled.
628
+ norm_first: Whether to normalize inputs to attention and intermediate
629
+ dense layers. If set False, output of attention and intermediate dense
630
+ layers is normalized.
631
+ norm_epsilon: Epsilon value to initialize normalization layers.
632
+ intermediate_dropout: Dropout probability for intermediate_dropout_layer.
633
+ attention_initializer: Initializer for kernels of attention layers. If set
634
+ `None`, attention layers use kernel_initializer as initializer for
635
+ kernel.
636
+ **kwargs: keyword arguments passed to tf_keras.layers.Layer.
637
+ """
638
+ super().__init__(**kwargs)
639
+ self.num_attention_heads = num_attention_heads
640
+ self.intermediate_size = intermediate_size
641
+ self.intermediate_activation = tf_keras.activations.get(
642
+ intermediate_activation)
643
+ self.dropout_rate = dropout_rate
644
+ self.attention_dropout_rate = attention_dropout_rate
645
+ self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
646
+ self._bias_initializer = tf_keras.initializers.get(bias_initializer)
647
+ self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
648
+ self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
649
+ self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
650
+ self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
651
+ self._bias_constraint = tf_keras.constraints.get(bias_constraint)
652
+ self._use_bias = use_bias
653
+ self._norm_first = norm_first
654
+ self._norm_epsilon = norm_epsilon
655
+ self._intermediate_dropout = intermediate_dropout
656
+ if attention_initializer:
657
+ self._attention_initializer = tf_keras.initializers.get(
658
+ attention_initializer)
659
+ else:
660
+ self._attention_initializer = tf_utils.clone_initializer(
661
+ self._kernel_initializer)
662
+ self._cross_attention_cls = layers.attention.MultiHeadAttention
663
+
664
+ def build(self, input_shape):
665
+ target_tensor_shape = tf.TensorShape(input_shape[0])
666
+ if len(target_tensor_shape.as_list()) != 3:
667
+ raise ValueError("TransformerLayer expects a three-dimensional input of "
668
+ "shape [batch, sequence, width].")
669
+ hidden_size = target_tensor_shape[2]
670
+ if hidden_size % self.num_attention_heads != 0:
671
+ raise ValueError(
672
+ "The hidden size (%d) is not a multiple of the number of attention "
673
+ "heads (%d)" % (hidden_size, self.num_attention_heads))
674
+ self.attention_head_size = int(hidden_size) // self.num_attention_heads
675
+ common_kwargs = dict(
676
+ bias_initializer=self._bias_initializer,
677
+ kernel_regularizer=self._kernel_regularizer,
678
+ bias_regularizer=self._bias_regularizer,
679
+ activity_regularizer=self._activity_regularizer,
680
+ kernel_constraint=self._kernel_constraint,
681
+ bias_constraint=self._bias_constraint)
682
+ # Self attention.
683
+ self.self_attention = layers.attention.CachedAttention(
684
+ num_heads=self.num_attention_heads,
685
+ key_dim=self.attention_head_size,
686
+ dropout=self.attention_dropout_rate,
687
+ use_bias=self._use_bias,
688
+ kernel_initializer=self._attention_initializer,
689
+ name="self_attention",
690
+ **common_kwargs)
691
+ self.self_attention_output_dense = tf_keras.layers.EinsumDense(
692
+ "abc,cd->abd",
693
+ output_shape=(None, hidden_size),
694
+ bias_axes="d",
695
+ kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
696
+ name="output",
697
+ **common_kwargs)
698
+ self.self_attention_dropout = tf_keras.layers.Dropout(
699
+ rate=self.dropout_rate)
700
+ self.self_attention_layer_norm = (
701
+ tf_keras.layers.LayerNormalization(
702
+ name="self_attention_layer_norm",
703
+ axis=-1,
704
+ epsilon=self._norm_epsilon,
705
+ dtype="float32"))
706
+ # Encoder-decoder attention.
707
+ self.encdec_attention = self._cross_attention_cls(
708
+ num_heads=self.num_attention_heads,
709
+ key_dim=self.attention_head_size,
710
+ dropout=self.attention_dropout_rate,
711
+ output_shape=hidden_size,
712
+ use_bias=self._use_bias,
713
+ kernel_initializer=self._attention_initializer,
714
+ name="attention/encdec",
715
+ **common_kwargs)
716
+
717
+ self.encdec_attention_dropout = tf_keras.layers.Dropout(
718
+ rate=self.dropout_rate)
719
+ self.encdec_attention_layer_norm = (
720
+ tf_keras.layers.LayerNormalization(
721
+ name="attention/encdec_output_layer_norm",
722
+ axis=-1,
723
+ epsilon=self._norm_epsilon,
724
+ dtype="float32"))
725
+
726
+ # Feed-forward projection.
727
+ self.intermediate_dense = tf_keras.layers.EinsumDense(
728
+ "abc,cd->abd",
729
+ output_shape=(None, self.intermediate_size),
730
+ bias_axes="d",
731
+ kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
732
+ name="intermediate",
733
+ **common_kwargs)
734
+ self.intermediate_activation_layer = tf_keras.layers.Activation(
735
+ self.intermediate_activation)
736
+ self._intermediate_dropout_layer = tf_keras.layers.Dropout(
737
+ rate=self._intermediate_dropout)
738
+ self.output_dense = tf_keras.layers.EinsumDense(
739
+ "abc,cd->abd",
740
+ output_shape=(None, hidden_size),
741
+ bias_axes="d",
742
+ kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
743
+ name="output",
744
+ **common_kwargs)
745
+ self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
746
+ self.output_layer_norm = tf_keras.layers.LayerNormalization(
747
+ name="output_layer_norm",
748
+ axis=-1,
749
+ epsilon=self._norm_epsilon,
750
+ dtype="float32")
751
+ super().build(input_shape)
752
+
753
+ def get_config(self):
754
+ config = {
755
+ "num_attention_heads": self.num_attention_heads,
756
+ "intermediate_size": self.intermediate_size,
757
+ "intermediate_activation": tf_utils.serialize_activation(
758
+ self.intermediate_activation, use_legacy_format=True
759
+ ),
760
+ "dropout_rate": self.dropout_rate,
761
+ "attention_dropout_rate": self.attention_dropout_rate,
762
+ "kernel_initializer": tf_utils.serialize_initializer(
763
+ self._kernel_initializer, use_legacy_format=True
764
+ ),
765
+ "bias_initializer": tf_utils.serialize_initializer(
766
+ self._bias_initializer, use_legacy_format=True
767
+ ),
768
+ "kernel_regularizer": tf_utils.serialize_regularizer(
769
+ self._kernel_regularizer, use_legacy_format=True
770
+ ),
771
+ "bias_regularizer": tf_utils.serialize_regularizer(
772
+ self._bias_regularizer, use_legacy_format=True
773
+ ),
774
+ "activity_regularizer": tf_utils.serialize_regularizer(
775
+ self._activity_regularizer, use_legacy_format=True
776
+ ),
777
+ "kernel_constraint": tf_utils.serialize_constraint(
778
+ self._kernel_constraint, use_legacy_format=True
779
+ ),
780
+ "bias_constraint": tf_utils.serialize_constraint(
781
+ self._bias_constraint, use_legacy_format=True
782
+ ),
783
+ "use_bias": self._use_bias,
784
+ "norm_first": self._norm_first,
785
+ "norm_epsilon": self._norm_epsilon,
786
+ "intermediate_dropout": self._intermediate_dropout,
787
+ "attention_initializer": tf_utils.serialize_initializer(
788
+ self._attention_initializer, use_legacy_format=True
789
+ ),
790
+ }
791
+ base_config = super().get_config()
792
+ return dict(list(base_config.items()) + list(config.items()))
793
+
794
+ def common_layers_with_encoder(self):
795
+ """Gets layer objects that can make a Transformer encoder block."""
796
+ return [
797
+ self.self_attention, self.self_attention_layer_norm,
798
+ self.intermediate_dense, self.output_dense, self.output_layer_norm
799
+ ]
800
+
801
+ def call(self, inputs, cache=None, decode_loop_step=None):
802
+ input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed, memory_pos_embed = inputs
803
+ source_tensor = input_tensor
804
+ if self._norm_first:
805
+ input_tensor = self.self_attention_layer_norm(input_tensor)
806
+ self_attention_output, cache = self.self_attention(
807
+ query=input_tensor + input_pos_embed,
808
+ key=input_tensor + input_pos_embed,
809
+ value=input_tensor,
810
+ attention_mask=self_attention_mask,
811
+ cache=cache,
812
+ decode_loop_step=decode_loop_step)
813
+ self_attention_output = self.self_attention_dropout(self_attention_output)
814
+ if self._norm_first:
815
+ self_attention_output = source_tensor + self_attention_output
816
+ else:
817
+ self_attention_output = self.self_attention_layer_norm(
818
+ input_tensor + self_attention_output)
819
+ if self._norm_first:
820
+ source_self_attention_output = self_attention_output
821
+ self_attention_output = self.encdec_attention_layer_norm(
822
+ self_attention_output)
823
+ cross_attn_inputs = dict(
824
+ query=self_attention_output + input_pos_embed,
825
+ key=memory + memory_pos_embed,
826
+ value=memory,
827
+ attention_mask=attention_mask)
828
+ attention_output = self.encdec_attention(**cross_attn_inputs)
829
+ attention_output = self.encdec_attention_dropout(attention_output)
830
+ if self._norm_first:
831
+ attention_output = source_self_attention_output + attention_output
832
+ else:
833
+ attention_output = self.encdec_attention_layer_norm(
834
+ self_attention_output + attention_output)
835
+ if self._norm_first:
836
+ source_attention_output = attention_output
837
+ attention_output = self.output_layer_norm(attention_output)
838
+
839
+ intermediate_output = self.intermediate_dense(attention_output)
840
+ intermediate_output = self.intermediate_activation_layer(
841
+ intermediate_output)
842
+ intermediate_output = self._intermediate_dropout_layer(intermediate_output)
843
+ layer_output = self.output_dense(intermediate_output)
844
+ layer_output = self.output_dropout(layer_output)
845
+ if self._norm_first:
846
+ layer_output = source_attention_output + layer_output
847
+ else:
848
+ layer_output = self.output_layer_norm(layer_output + attention_output)
849
+ return layer_output, cache
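
Finally, a single TransformerDecoderBlock takes a six-element input list in the order used by TransformerDecoder above and returns a (output, cache) pair. A short sketch reusing the hypothetical tensors from the previous snippets:

    from official.projects.detr.modeling.transformer import TransformerDecoderBlock

    block = TransformerDecoderBlock(
        num_attention_heads=8, intermediate_size=512,
        intermediate_activation="relu", norm_first=True)
    block_inputs = [
        target,                # decoder features
        outputs,               # encoder memory
        cross_attention_mask,  # mask for encoder-decoder attention
        self_attention_mask,   # mask for decoder self-attention
        query_embed,           # input_pos_embed: added to self-attention query/key and cross-attention query
        pos_embed,             # memory_pos_embed: added to the cross-attention key
    ]
    block_out, cache = block(block_inputs)  # cache stays None outside autoregressive decoding
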