tf-models-nightly 2.19.0.dev20250108__py2.py3-none-any.whl → 2.19.0.dev20250109__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tf-models-nightly might be problematic. Click here for more details.
- official/projects/detr/__init__.py +14 -0
- official/projects/detr/configs/__init__.py +14 -0
- official/projects/detr/configs/detr.py +277 -0
- official/projects/detr/configs/detr_test.py +51 -0
- official/projects/detr/dataloaders/__init__.py +14 -0
- official/projects/detr/dataloaders/coco.py +157 -0
- official/projects/detr/dataloaders/coco_test.py +111 -0
- official/projects/detr/dataloaders/detr_input.py +175 -0
- official/projects/detr/experiments/__init__.py +14 -0
- official/projects/detr/modeling/__init__.py +14 -0
- official/projects/detr/modeling/detr.py +345 -0
- official/projects/detr/modeling/detr_test.py +70 -0
- official/projects/detr/modeling/transformer.py +849 -0
- official/projects/detr/modeling/transformer_test.py +263 -0
- official/projects/detr/ops/__init__.py +14 -0
- official/projects/detr/ops/matchers.py +489 -0
- official/projects/detr/ops/matchers_test.py +95 -0
- official/projects/detr/optimization.py +151 -0
- official/projects/detr/serving/__init__.py +14 -0
- official/projects/detr/serving/export_module.py +103 -0
- official/projects/detr/serving/export_module_test.py +98 -0
- official/projects/detr/serving/export_saved_model.py +109 -0
- official/projects/detr/tasks/__init__.py +14 -0
- official/projects/detr/tasks/detection.py +421 -0
- official/projects/detr/tasks/detection_test.py +203 -0
- official/projects/detr/train.py +70 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/METADATA +1 -1
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/RECORD +32 -6
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/AUTHORS +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/LICENSE +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/WHEEL +0 -0
- {tf_models_nightly-2.19.0.dev20250108.dist-info → tf_models_nightly-2.19.0.dev20250109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,849 @@
|
|
1
|
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Specialized Transformers for DETR.
|
16
|
+
|
17
|
+
the position embeddings are added to the query and key for every self- and
|
18
|
+
cross-attention layer.
|
19
|
+
"""
|
20
|
+
|
21
|
+
import tensorflow as tf, tf_keras
|
22
|
+
|
23
|
+
from official.modeling import tf_utils
|
24
|
+
from official.nlp.modeling import layers
|
25
|
+
from official.nlp.modeling import models
|
26
|
+
|
27
|
+
|
28
|
+
class TransformerEncoder(tf_keras.layers.Layer):
  """Stack of DETR transformer encoder blocks.

  The encoder runs `num_layers` `TransformerEncoderBlock`s one after another.
  Each block is made of two sub-layers:
    1. A self-attention sub-layer (the position embedding is added to its
       query and key).
    2. A two-layer feed-forward network.
  The output of the last block goes through a final float32 layer
  normalization.
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Stores the encoder configuration; sub-layers are created in `build`.

    Args:
      num_layers: How many encoder blocks to stack.
      num_attention_heads: Number of attention heads in every block.
      intermediate_size: Width of the feed-forward (intermediate) layer.
      activation: Activation of the intermediate layer.
      dropout_rate: Dropout probability applied after attention and after the
        feed-forward output.
      attention_dropout_rate: Dropout probability inside the attention layers.
      use_bias: Whether the attention layers use bias terms.
      norm_first: If True, inputs of the attention and feed-forward sub-layers
        are normalized (pre-norm); otherwise their outputs are (post-norm).
      norm_epsilon: Epsilon of the normalization layers.
      intermediate_dropout: Dropout probability after the intermediate
        activation.
      **kwargs: Keyword arguments forwarded to `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Creates every encoder block plus the final normalization layer.

    Args:
      input_shape: Shape of `encoder_inputs`; index 2 is the hidden size used
        to derive the attention kernel initializer.
    """
    hidden_size = input_shape[2]
    self.encoder_layers = [
        TransformerEncoderBlock(
            num_attention_heads=self.num_attention_heads,
            inner_dim=self._intermediate_size,
            inner_activation=self._activation,
            output_dropout=self._dropout_rate,
            attention_dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            inner_dropout=self._intermediate_dropout,
            attention_initializer=tf_utils.clone_initializer(
                models.seq2seq_transformer.attention_initializer(
                    hidden_size)),
            name="layer_%d" % index)
        for index in range(self.num_layers)
    ]
    self.output_normalization = tf_keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the constructor arguments merged over the base layer config."""
    config = dict(
        num_layers=self.num_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self._intermediate_size,
        activation=self._activation,
        dropout_rate=self._dropout_rate,
        attention_dropout_rate=self._attention_dropout_rate,
        use_bias=self._use_bias,
        norm_first=self._norm_first,
        norm_epsilon=self._norm_epsilon,
        intermediate_dropout=self._intermediate_dropout,
    )
    return {**super().get_config(), **config}

  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
    """Runs the encoder stack.

    Args:
      encoder_inputs: Tensor of shape `(batch_size, input_length,
        hidden_size)`.
      attention_mask: Mask for the self-attention of every block, shaped
        `(batch_size, input_length, input_length)`.
      pos_embed: Position embedding passed to every block, where it is added
        to the query and key of the self-attention.

    Returns:
      The float32 output of the final normalization, shaped
      `(batch_size, input_length, hidden_size)`.
    """
    hidden = encoder_inputs
    for block in self.encoder_layers:
      hidden = block([hidden, attention_mask, pos_embed])
    return self.output_normalization(hidden)
|
141
|
+
|
142
|
+
|
143
|
+
class TransformerEncoderBlock(tf_keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
  which combines a `tf_keras.layers.MultiHeadAttention` layer with a
  two-layer feedforward network. The only difference: position embedding is
  added to the query and key (never the value) of self-attention.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
     Understanding](https://arxiv.org/abs/1810.04805)
  """

  def __init__(self,
               num_attention_heads,
               inner_dim,
               inner_activation,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    """Initializes `TransformerEncoderBlock`.

    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a two-layer
        feedforward network.
      output_range: the sequence output range, [0, output_range) for slicing the
        target sequence. `None` means the target sequence is not sliced.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability for within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      attention_initializer: Initializer for kernels of attention layers. If set
        `None`, attention layers use kernel_initializer as initializer for
        kernel.
      attention_axes: axes over which the attention is applied. `None` means
        attention over all axes, but batch, heads, and features.
      **kwargs: keyword arguments passed to `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)

    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    # `_attention_dropout` and `_output_dropout` are replaced by Dropout layers
    # in `build()`; the `*_rate` copies keep the float rates so `build()` and
    # `get_config()` never read a layer object where a rate is expected.
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout
    self._output_dropout = output_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    if attention_initializer:
      self._attention_initializer = tf_keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._attention_axes = attention_axes

  def build(self, input_shape):
    """Creates the attention, feed-forward and normalization sub-layers.

    Args:
      input_shape: A `tf.TensorShape` for the input tensor, or a list/tuple
        whose first element is the input tensor's shape.

    Raises:
      ValueError: if `input_shape` has an unsupported type, or the hidden size
        is not a multiple of `num_attention_heads`.
    """
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._attention_layer = tf_keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    # Post-attention dropout uses the *output* dropout rate, as documented.
    self._attention_dropout = tf_keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._intermediate_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="intermediate",
        **common_kwargs)
    policy = tf_keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf_keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf_keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        **common_kwargs)
    self._output_dropout = tf_keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    super(TransformerEncoderBlock, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments merged over the base layer config."""
    config = {
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
        "attention_axes": self._attention_axes,
    }
    base_config = super(TransformerEncoderBlock, self).get_config()
    return {**base_config, **config}

  def _split_block_inputs(self, inputs):
    """Normalizes `inputs` into `(input_tensor, attention_mask, pos_embed)`.

    Args:
      inputs: A single tensor, or a list/tuple of 1 to 3 elements in the order
        `[input_tensor, attention_mask, pos_embed]`.

    Returns:
      A 3-tuple; missing trailing entries are `None`.

    Raises:
      ValueError: if a list/tuple with a length other than 1, 2 or 3 is given.
    """
    if not isinstance(inputs, (list, tuple)):
      return inputs, None, None
    if len(inputs) not in (1, 2, 3):
      raise ValueError(
          "Expected [input_tensor, attention_mask, pos_embed] with 1 to 3 "
          "elements, got %d." % len(inputs))
    padded = list(inputs) + [None] * (3 - len(inputs))
    return padded[0], padded[1], padded[2]

  def call(self, inputs):
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors. `input tensor` as the single
        sequence of embeddings. [`input tensor`, `attention mask`] to have the
        additional attention mask. [`input tensor`, `attention mask`, `pos
        embed`] to have an additional position embedding, which is added to the
        query and key (not the value) of the self-attention. `attention mask`
        and `pos embed` may be `None`.

    Returns:
      An output tensor with the same dimensions as the input tensor, except
      that the sequence axis is cut to `output_range` when it is set.

    Raises:
      ValueError: if `inputs` is a list/tuple of unsupported length.
    """
    input_tensor, attention_mask, pos_embed = self._split_block_inputs(inputs)

    if self._norm_first:
      # Pre-norm: the residual branch uses the un-normalized input.
      source_tensor = input_tensor
      if self._output_range:
        source_tensor = source_tensor[:, 0:self._output_range, :]
      input_tensor = self._attention_layer_norm(input_tensor)

    # Keys/values span the whole sequence; only the queries (and the residual)
    # are restricted to the first `output_range` positions.
    key_value = input_tensor
    target_tensor = input_tensor
    query_pos_embed = pos_embed
    if self._output_range:
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
      if pos_embed is not None and pos_embed.shape.rank == 3:
        # Slice the query's position embedding like the query itself so the
        # addition below does not hit a sequence-length mismatch.
        query_pos_embed = pos_embed[:, 0:self._output_range, :]

    query = target_tensor
    key = key_value
    if pos_embed is not None:
      query = query + query_pos_embed
      key = key + pos_embed

    attention_output = self._attention_layer(
        query=query,
        key=key,
        value=key_value,
        attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(target_tensor +
                                                    attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + layer_output

    # During mixed precision training, layer norm output is always fp32 for now.
    # Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)
|
423
|
+
|
424
|
+
|
425
|
+
class TransformerDecoder(tf_keras.layers.Layer):
  """Stack of DETR transformer decoder blocks.

  The decoder runs `num_layers` `TransformerDecoderBlock`s one after another.
  Each block is composed of:
    1. A self-attention sub-layer over the decoder inputs.
    2. A cross-attention sub-layer combining the encoder outputs (`memory`)
       with the results of the self-attention sub-layer.
    3. A two-layer feed-forward network.
  A single float32 layer normalization (`output_normalization`) is applied to
  whatever the decoder returns.
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Stores the decoder configuration; sub-layers are created in `build`.

    Args:
      num_layers: How many decoder blocks to stack.
      num_attention_heads: Number of attention heads in every block.
      intermediate_size: Width of the feed-forward (intermediate) layer.
      activation: Activation of the intermediate layer.
      dropout_rate: Dropout probability for the post-attention and output
        dropout.
      attention_dropout_rate: Dropout probability inside the attention layers.
      use_bias: Whether the attention layers use bias terms.
      norm_first: If True, inputs of the attention and feed-forward sub-layers
        are normalized (pre-norm); otherwise their outputs are (post-norm).
      norm_epsilon: Epsilon of the normalization layers.
      intermediate_dropout: Dropout probability after the intermediate
        activation.
      **kwargs: Keyword arguments forwarded to `tf_keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Creates every decoder block plus the output normalization layer.

    Args:
      input_shape: Shape of the first call argument (`target`); index 2 is the
        hidden size used to derive the attention kernel initializer.
    """
    hidden_size = input_shape[2]
    self.decoder_layers = [
        TransformerDecoderBlock(
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self._intermediate_size,
            intermediate_activation=self._activation,
            dropout_rate=self._dropout_rate,
            attention_dropout_rate=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            intermediate_dropout=self._intermediate_dropout,
            attention_initializer=tf_utils.clone_initializer(
                models.seq2seq_transformer.attention_initializer(
                    hidden_size)),
            name="layer_%d" % index)
        for index in range(self.num_layers)
    ]
    self.output_normalization = tf_keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the constructor arguments merged over the base layer config."""
    config = dict(
        num_layers=self.num_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self._intermediate_size,
        activation=self._activation,
        dropout_rate=self._dropout_rate,
        attention_dropout_rate=self._attention_dropout_rate,
        use_bias=self._use_bias,
        norm_first=self._norm_first,
        norm_epsilon=self._norm_epsilon,
        intermediate_dropout=self._intermediate_dropout,
    )
    return {**super().get_config(), **config}

  def call(self,
           target,
           memory,
           self_attention_mask=None,
           cross_attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_all_decoder_outputs=False,
           input_pos_embed=None,
           memory_pos_embed=None):
    """Runs the decoder stack.

    Args:
      target: Tensor of shape `(batch_size, target_length, hidden_size)`.
      memory: Tensor of shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: Tensor of shape `(batch_size, target_length,
        target_length)`, the mask for the decoder self-attention.
      cross_attention_mask: Tensor of shape `(batch_size, target_length,
        input_length)`, the mask for the encoder-decoder attention.
      cache: Optional dict of per-layer decoding caches keyed by the string
        form of the layer index (`"0"`, `"1"`, ...). Each entry is handed to
        its block and replaced in place with the cache the block returns.
      decode_loop_step: Integer step of the decoding loop, forwarded to each
        block only when `cache` is given.
      return_all_decoder_outputs: If True, return the normalized output of
        every layer (useful for per-layer auxiliary losses) instead of only
        the last one.
      input_pos_embed: Tensor forwarded to every block as the position
        embedding for its self-attention.
      memory_pos_embed: Tensor forwarded to every block as the position
        embedding for its cross-attention.

    Returns:
      The float32 normalized output of the last layer, shaped
      `(batch_size, target_length, hidden_size)`, or a list of such tensors
      (one per layer) when `return_all_decoder_outputs` is True.
    """
    hidden = target
    per_layer_outputs = []
    for index, decoder_block in enumerate(self.decoder_layers):
      block_inputs = [
          hidden, memory, cross_attention_mask, self_attention_mask,
          input_pos_embed, memory_pos_embed
      ]
      if cache is None:
        hidden, _ = decoder_block(block_inputs)
      else:
        cache_key = str(index)
        hidden, cache[cache_key] = decoder_block(
            block_inputs,
            cache=cache[cache_key],
            decode_loop_step=decode_loop_step)
      if return_all_decoder_outputs:
        per_layer_outputs.append(self.output_normalization(hidden))

    if return_all_decoder_outputs:
      return per_layer_outputs
    return self.output_normalization(hidden)
|
579
|
+
|
580
|
+
|
581
|
+
class TransformerDecoderBlock(tf_keras.layers.Layer):
|
582
|
+
"""Single transformer layer for decoder.
|
583
|
+
|
584
|
+
It has three sub-layers:
|
585
|
+
(1) a multi-head self-attention mechanism.
|
586
|
+
(2) a encoder-decoder attention.
|
587
|
+
(3) a positionwise fully connected feed-forward network.
|
588
|
+
"""
|
589
|
+
|
590
|
+
def __init__(self,
|
591
|
+
num_attention_heads,
|
592
|
+
intermediate_size,
|
593
|
+
intermediate_activation,
|
594
|
+
dropout_rate=0.0,
|
595
|
+
attention_dropout_rate=0.0,
|
596
|
+
kernel_initializer="glorot_uniform",
|
597
|
+
bias_initializer="zeros",
|
598
|
+
kernel_regularizer=None,
|
599
|
+
bias_regularizer=None,
|
600
|
+
activity_regularizer=None,
|
601
|
+
kernel_constraint=None,
|
602
|
+
bias_constraint=None,
|
603
|
+
use_bias=True,
|
604
|
+
norm_first=False,
|
605
|
+
norm_epsilon=1e-12,
|
606
|
+
intermediate_dropout=0.0,
|
607
|
+
attention_initializer=None,
|
608
|
+
**kwargs):
|
609
|
+
"""Initialize a Transformer decoder block.
|
610
|
+
|
611
|
+
Args:
|
612
|
+
num_attention_heads: Number of attention heads.
|
613
|
+
intermediate_size: Size of the intermediate layer.
|
614
|
+
intermediate_activation: Activation for the intermediate layer.
|
615
|
+
dropout_rate: Dropout probability for the post-attention and output
|
616
|
+
dropout.
|
617
|
+
attention_dropout_rate: Dropout probability for within the attention
|
618
|
+
layer.
|
619
|
+
kernel_initializer: Initializer for dense layer kernels.
|
620
|
+
bias_initializer: Initializer for dense layer biases.
|
621
|
+
kernel_regularizer: Regularizer for dense layer kernels.
|
622
|
+
bias_regularizer: Regularizer for dense layer biases.
|
623
|
+
activity_regularizer: Regularizer for dense layer activity.
|
624
|
+
kernel_constraint: Constraint for dense layer kernels.
|
625
|
+
bias_constraint: Constraint for dense layer kernels.
|
626
|
+
use_bias: Whether to enable use_bias in attention layer. If set False,
|
627
|
+
use_bias in attention layer is disabled.
|
628
|
+
norm_first: Whether to normalize inputs to attention and intermediate
|
629
|
+
dense layers. If set False, output of attention and intermediate dense
|
630
|
+
layers is normalized.
|
631
|
+
norm_epsilon: Epsilon value to initialize normalization layers.
|
632
|
+
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
|
633
|
+
attention_initializer: Initializer for kernels of attention layers. If set
|
634
|
+
`None`, attention layers use kernel_initializer as initializer for
|
635
|
+
kernel.
|
636
|
+
**kwargs: key word arguemnts passed to tf_keras.layers.Layer.
|
637
|
+
"""
|
638
|
+
super().__init__(**kwargs)
|
639
|
+
self.num_attention_heads = num_attention_heads
|
640
|
+
self.intermediate_size = intermediate_size
|
641
|
+
self.intermediate_activation = tf_keras.activations.get(
|
642
|
+
intermediate_activation)
|
643
|
+
self.dropout_rate = dropout_rate
|
644
|
+
self.attention_dropout_rate = attention_dropout_rate
|
645
|
+
self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
|
646
|
+
self._bias_initializer = tf_keras.initializers.get(bias_initializer)
|
647
|
+
self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
|
648
|
+
self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
|
649
|
+
self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
|
650
|
+
self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
|
651
|
+
self._bias_constraint = tf_keras.constraints.get(bias_constraint)
|
652
|
+
self._use_bias = use_bias
|
653
|
+
self._norm_first = norm_first
|
654
|
+
self._norm_epsilon = norm_epsilon
|
655
|
+
self._intermediate_dropout = intermediate_dropout
|
656
|
+
if attention_initializer:
|
657
|
+
self._attention_initializer = tf_keras.initializers.get(
|
658
|
+
attention_initializer)
|
659
|
+
else:
|
660
|
+
self._attention_initializer = tf_utils.clone_initializer(
|
661
|
+
self._kernel_initializer)
|
662
|
+
self._cross_attention_cls = layers.attention.MultiHeadAttention
|
663
|
+
|
664
|
+
def build(self, input_shape):
  """Creates the attention, normalization and feed-forward sublayers.

  Args:
    input_shape: Shapes of the layer inputs. Only `input_shape[0]`, the shape
      of the target tensor, is inspected. It must be rank 3
      (`[batch, sequence, width]`), and `width` must be statically known and a
      multiple of `num_attention_heads`.

  Raises:
    ValueError: If the target shape is not rank 3 (including unknown rank),
      if its last dimension is unknown, or if that dimension is not divisible
      by `num_attention_heads`.
  """
  target_tensor_shape = tf.TensorShape(input_shape[0])
  # `rank` is None for an unknown-rank shape. Checking it directly avoids the
  # less descriptive error `as_list()` raises in that case.
  if target_tensor_shape.rank != 3:
    raise ValueError("TransformerLayer expects a three-dimensional input of "
                     "shape [batch, sequence, width].")
  hidden_size = tf.compat.dimension_value(target_tensor_shape[2])
  if hidden_size is None:
    # Without this check, the modulo and `%d` formatting below would fail
    # with an opaque TypeError.
    raise ValueError(
        "The hidden size (last dimension of the target input) must be "
        "statically known.")
  if hidden_size % self.num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, self.num_attention_heads))
  self.attention_head_size = int(hidden_size) // self.num_attention_heads
  # Keyword arguments shared by the attention and dense sublayers.
  common_kwargs = dict(
      bias_initializer=self._bias_initializer,
      kernel_regularizer=self._kernel_regularizer,
      bias_regularizer=self._bias_regularizer,
      activity_regularizer=self._activity_regularizer,
      kernel_constraint=self._kernel_constraint,
      bias_constraint=self._bias_constraint)
  # Sublayers are assigned in a fixed order. Keras tracks them in assignment
  # order, which determines the order of `self.weights`, so keep it stable.
  # Self attention.
  self.self_attention = layers.attention.CachedAttention(
      num_heads=self.num_attention_heads,
      key_dim=self.attention_head_size,
      dropout=self.attention_dropout_rate,
      use_bias=self._use_bias,
      kernel_initializer=self._attention_initializer,
      name="self_attention",
      **common_kwargs)
  # NOTE(review): `self_attention_output_dense` is never used by `call`;
  # presumably kept for attribute compatibility. Confirm before removing.
  self.self_attention_output_dense = tf_keras.layers.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, hidden_size),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
      name="output",
      **common_kwargs)
  self.self_attention_dropout = tf_keras.layers.Dropout(
      rate=self.dropout_rate)
  # Layer norms are created with dtype float32, presumably for numerical
  # stability under mixed precision.
  self.self_attention_layer_norm = (
      tf_keras.layers.LayerNormalization(
          name="self_attention_layer_norm",
          axis=-1,
          epsilon=self._norm_epsilon,
          dtype="float32"))
  # Encoder-decoder attention.
  self.encdec_attention = self._cross_attention_cls(
      num_heads=self.num_attention_heads,
      key_dim=self.attention_head_size,
      dropout=self.attention_dropout_rate,
      output_shape=hidden_size,
      use_bias=self._use_bias,
      kernel_initializer=self._attention_initializer,
      name="attention/encdec",
      **common_kwargs)

  self.encdec_attention_dropout = tf_keras.layers.Dropout(
      rate=self.dropout_rate)
  self.encdec_attention_layer_norm = (
      tf_keras.layers.LayerNormalization(
          name="attention/encdec_output_layer_norm",
          axis=-1,
          epsilon=self._norm_epsilon,
          dtype="float32"))

  # Feed-forward projection.
  self.intermediate_dense = tf_keras.layers.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, self.intermediate_size),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
      name="intermediate",
      **common_kwargs)
  self.intermediate_activation_layer = tf_keras.layers.Activation(
      self.intermediate_activation)
  self._intermediate_dropout_layer = tf_keras.layers.Dropout(
      rate=self._intermediate_dropout)
  self.output_dense = tf_keras.layers.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, hidden_size),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
      name="output",
      **common_kwargs)
  self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
  self.output_layer_norm = tf_keras.layers.LayerNormalization(
      name="output_layer_norm",
      axis=-1,
      epsilon=self._norm_epsilon,
      dtype="float32")
  super().build(input_shape)
|
752
|
+
|
753
|
+
def get_config(self):
  """Returns the configuration of this layer for Keras serialization.

  Records the constructor arguments stored by `__init__`. The activation,
  initializers, regularizers and constraints are serialized with
  `use_legacy_format=True`. The result is laid over the base `Layer` config,
  and keys defined here win on collision.

  Returns:
    A config dict.
  """

  def _legacy(serializer, obj):
    # Every tf_utils serializer below is invoked in the legacy format.
    return serializer(obj, use_legacy_format=True)

  own_config = {
      "num_attention_heads": self.num_attention_heads,
      "intermediate_size": self.intermediate_size,
      "intermediate_activation": _legacy(
          tf_utils.serialize_activation, self.intermediate_activation),
      "dropout_rate": self.dropout_rate,
      "attention_dropout_rate": self.attention_dropout_rate,
      "kernel_initializer": _legacy(
          tf_utils.serialize_initializer, self._kernel_initializer),
      "bias_initializer": _legacy(
          tf_utils.serialize_initializer, self._bias_initializer),
      "kernel_regularizer": _legacy(
          tf_utils.serialize_regularizer, self._kernel_regularizer),
      "bias_regularizer": _legacy(
          tf_utils.serialize_regularizer, self._bias_regularizer),
      "activity_regularizer": _legacy(
          tf_utils.serialize_regularizer, self._activity_regularizer),
      "kernel_constraint": _legacy(
          tf_utils.serialize_constraint, self._kernel_constraint),
      "bias_constraint": _legacy(
          tf_utils.serialize_constraint, self._bias_constraint),
      "use_bias": self._use_bias,
      "norm_first": self._norm_first,
      "norm_epsilon": self._norm_epsilon,
      "intermediate_dropout": self._intermediate_dropout,
      "attention_initializer": _legacy(
          tf_utils.serialize_initializer, self._attention_initializer),
  }
  # Base keys keep their positions; same-named keys take this layer's values.
  return {**super().get_config(), **own_config}
|
793
|
+
|
794
|
+
def common_layers_with_encoder(self):
  """Returns the sublayers that can make up a Transformer encoder block.

  The list holds, in order: the self-attention layer, its layer norm, the
  intermediate (first feed-forward) dense layer, the output dense layer and
  the output layer norm.
  """
  shared_layers = (
      self.self_attention,
      self.self_attention_layer_norm,
      self.intermediate_dense,
      self.output_dense,
      self.output_layer_norm,
  )
  return list(shared_layers)
|
800
|
+
|
801
|
+
def call(self, inputs, cache=None, decode_loop_step=None):
  """Applies self-attention, cross-attention and feed-forward sublayers.

  Each sublayer is wrapped in a residual connection:
  - With `norm_first`, the sublayer input is normalized and the residual sum
    is returned as-is.
  - Otherwise, the sublayer sees its input unnormalized and the residual sum
    is normalized.

  Args:
    inputs: A sequence of exactly six items, in this order: the target
      tensor, the memory, the cross-attention mask, the self-attention mask,
      the target positional embedding and the memory positional embedding.
    cache: Optional cache passed through to the self-attention layer.
    decode_loop_step: Optional decode step passed through to the
      self-attention layer.

  Returns:
    A tuple `(layer_output, cache)`, where `cache` is the value returned by
    the self-attention layer.
  """
  (target, memory, attention_mask, self_attention_mask, input_pos_embed,
   memory_pos_embed) = inputs
  pre_norm = self._norm_first

  def _sublayer_input(norm_layer, tensor):
    """Normalizes `tensor` in pre-norm mode; passes it through otherwise."""
    return norm_layer(tensor) if pre_norm else tensor

  def _residual(norm_layer, shortcut, update):
    """Adds the residual, normalizing the sum only in post-norm mode."""
    summed = shortcut + update
    return summed if pre_norm else norm_layer(summed)

  # 1) Self-attention: the positional embedding is added to query and key,
  #    not to value.
  attn_in = _sublayer_input(self.self_attention_layer_norm, target)
  attn_in_with_pos = attn_in + input_pos_embed
  self_attention_output, cache = self.self_attention(
      query=attn_in_with_pos,
      key=attn_in_with_pos,
      value=attn_in,
      attention_mask=self_attention_mask,
      cache=cache,
      decode_loop_step=decode_loop_step)
  hidden = _residual(self.self_attention_layer_norm, target,
                     self.self_attention_dropout(self_attention_output))

  # 2) Cross-attention over the memory: positional embeddings go on the query
  #    and the key; the raw memory is the value.
  cross_in = _sublayer_input(self.encdec_attention_layer_norm, hidden)
  cross_output = self.encdec_attention(
      query=cross_in + input_pos_embed,
      key=memory + memory_pos_embed,
      value=memory,
      attention_mask=attention_mask)
  hidden = _residual(self.encdec_attention_layer_norm, hidden,
                     self.encdec_attention_dropout(cross_output))

  # 3) Feed-forward network.
  ffn = _sublayer_input(self.output_layer_norm, hidden)
  ffn = self.intermediate_dense(ffn)
  ffn = self.intermediate_activation_layer(ffn)
  ffn = self._intermediate_dropout_layer(ffn)
  ffn = self.output_dense(ffn)
  ffn = self.output_dropout(ffn)
  layer_output = _residual(self.output_layer_norm, hidden, ffn)
  return layer_output, cache
|