easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/layers/multihead_cross_attention.py
@@ -0,0 +1,749 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf

from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
from easy_rec.python.utils.activation import gelu
from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
  return output


def attention_layer(from_tensor,
                    to_tensor,
                    size_per_head,
                    num_attention_heads=1,
                    attention_mask=None,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    reuse=None):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention is all you Need".
  If `from_tensor` and `to_tensor` are the same, then this is self-attention.
  Each timestep in `from_tensor` attends to the corresponding sequence in `to_tensor`,
  and returns a fixed-width vector.
  This function first projects `from_tensor` into a "query" tensor and `to_tensor` into "key" and "value" tensors.
  These are (effectively) a list of tensors of length `num_attention_heads`, where each tensor is of shape:
  [batch_size, seq_length, size_per_head].
  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.
  In practice, the multi-headed attention are done with transposes and
  reshapes rather than actual separate tensors.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    size_per_head: int. Size of each attention head.
    num_attention_heads: int. Number of attention heads.
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions in
      the mask that are 0, and will be unchanged for positions that are 1.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
      * from_seq_length, num_attention_heads * size_per_head]. If False, the
      output will be of shape [batch_size, from_seq_length, num_attention_heads
      * size_per_head].
    batch_size: (Optional) int. If the input is 2D, this might be the batch size
      of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    reuse: whether to reuse this layer

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
      true, this will be of shape [batch_size * from_seq_length,
      num_attention_heads * size_per_head]).

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """

  def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                           seq_length, width):
    output_tensor = tf.reshape(
        input_tensor, [batch_size, seq_length, num_attention_heads, width])

    output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
    return output_tensor

  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        'The rank of `from_tensor` must match the rank of `to_tensor`.')

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          'When passing in rank 2 tensors to attention_layer, the values '
          'for `batch_size`, `from_seq_length`, and `to_seq_length` '
          'must all be specified.')

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  from_tensor_2d = reshape_to_matrix(from_tensor)
  to_tensor_2d = reshape_to_matrix(to_tensor)

  # `query_layer` = [B*F, N*H]
  query_layer = tf.layers.dense(
      from_tensor_2d,
      num_attention_heads * size_per_head,
      activation=query_act,
      name='query',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `key_layer` = [B*T, N*H]
  key_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=key_act,
      name='key',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `value_layer` = [B*T, N*H]
  value_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=value_act,
      name='value',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `query_layer` = [B, N, F, H]
  query_layer = transpose_for_scores(query_layer, batch_size,
                                     num_attention_heads, from_seq_length,
                                     size_per_head)

  # `key_layer` = [B, N, T, H]
  key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                   to_seq_length, size_per_head)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  # `attention_scores` = [B, N, F, T]
  attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `value_layer` = [B, T, N, H]
  value_layer = tf.reshape(
      value_layer,
      [batch_size, to_seq_length, num_attention_heads, size_per_head])

  # `value_layer` = [B, N, T, H]
  value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

  # `context_layer` = [B, N, F, H]
  context_layer = tf.matmul(attention_probs, value_layer)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

  if do_return_2d_tensor:
    # `context_layer` = [B*F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size * from_seq_length, num_attention_heads * size_per_head])
  else:
    # `context_layer` = [B, F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size, from_seq_length, num_attention_heads * size_per_head])

  return context_layer


def transformer_encoder(input_tensor,
                        attention_mask=None,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        intermediate_act_fn=gelu,
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        initializer_range=0.02,
                        reuse=None,
                        name='transformer'):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.
  See the original paper:
  https://arxiv.org/abs/1706.03762

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    reuse: whether to reuse this encoder
    name: scope name prefix

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        'The hidden size (%d) is not a multiple of the number of attention '
        'heads (%d)' % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError('The width of the input tensor (%d) != hidden size (%d)' %
                     (input_width, hidden_size))

  # We keep the representation as a 2D tensor to avoid re-shaping it back and
  # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
  # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
  # help the optimizer.
  prev_output = reshape_to_matrix(input_tensor)

  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope('%s_layer_%d' % (name, layer_idx)):
      layer_input = prev_output

      with tf.variable_scope('attention'):
        with tf.variable_scope('self'):
          # [batch_size * from_seq_length, num_attention_heads * size_per_head]
          attention_output = attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              size_per_head=attention_head_size,
              num_attention_heads=num_attention_heads,
              attention_mask=attention_mask,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              do_return_2d_tensor=True,
              batch_size=batch_size,
              from_seq_length=seq_length,
              to_seq_length=seq_length,
              reuse=reuse)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope('output', reuse=reuse):
          attention_output = tf.layers.dense(
              attention_output,
              hidden_size,
              kernel_initializer=create_initializer(initializer_range))
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope('intermediate', reuse=reuse):
        intermediate_output = tf.layers.dense(
            attention_output,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=create_initializer(initializer_range))

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope('output', reuse=reuse):
        layer_output = tf.layers.dense(
            intermediate_output,
            hidden_size,
            kernel_initializer=create_initializer(initializer_range))
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output

  final_output = reshape_from_matrix(prev_output, input_shape)
  return final_output


def cross_attention_block(from_tensor,
                          to_tensor,
                          layer_idx,
                          size_per_head,
                          cross_attention_mask=None,
                          self_attention_mask=None,
                          num_attention_heads=1,
                          intermediate_size=512,
                          hidden_dropout_prob=0.1,
                          attention_probs_dropout_prob=0.1,
                          initializer_range=0.02,
                          name=''):
  """Multi-headed cross attention block.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    layer_idx: int. layer id in the Transformer.
    size_per_head: int. Size of each attention head.
    cross_attention_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length,
      to_seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    self_attention_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length,
      from_seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    name: scope name prefix

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  input_shape = get_shape_list(from_tensor, expected_rank=3)
  batch_size = input_shape[0]
  from_seq_length = input_shape[1]

  input_shape = get_shape_list(to_tensor, expected_rank=3)
  to_seq_length = input_shape[1]

  with tf.variable_scope('%scross_layer_%d' % (name, layer_idx)):
    with tf.variable_scope('attention'):
      with tf.variable_scope('cross'):
        # [batch_size * from_seq_length, num_attention_heads * size_per_head]
        cross_attention_output = attention_layer(
            from_tensor=from_tensor,
            to_tensor=to_tensor,
            size_per_head=size_per_head,
            num_attention_heads=num_attention_heads,
            attention_mask=cross_attention_mask,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=initializer_range,
            do_return_2d_tensor=True,
            batch_size=batch_size,
            from_seq_length=from_seq_length,
            to_seq_length=to_seq_length)

      with tf.variable_scope('self'):
        # [batch_size * from_seq_length, num_attention_heads * size_per_head]
        self_attention_output = attention_layer(
            from_tensor=cross_attention_output,
            to_tensor=cross_attention_output,
            size_per_head=size_per_head,
            num_attention_heads=num_attention_heads,
            attention_mask=self_attention_mask,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=initializer_range,
            do_return_2d_tensor=True,
            batch_size=batch_size,
            from_seq_length=from_seq_length,
            to_seq_length=from_seq_length)

      with tf.variable_scope('output'):
        attention_output = dropout(self_attention_output, hidden_dropout_prob)
        attention_output = layer_norm(attention_output + cross_attention_output)

    # The activation is only applied to the "intermediate" hidden layer.
    with tf.variable_scope('intermediate'):
      intermediate_output = tf.layers.dense(
          attention_output,
          intermediate_size,
          activation=tf.nn.relu,
          kernel_initializer=create_initializer(initializer_range))

    # Down-project back to `hidden_size` then add the residual.
    with tf.variable_scope('output'):
      layer_output = tf.layers.dense(
          intermediate_output,
          num_attention_heads * size_per_head,
          kernel_initializer=create_initializer(initializer_range))
      layer_output = dropout(layer_output, hidden_dropout_prob)
      # [batch_size * from_seq_length, num_attention_heads * size_per_head]
      layer_output = layer_norm(layer_output + attention_output)

  final_output = reshape_from_matrix(
      layer_output,
      [batch_size, from_seq_length, num_attention_heads * size_per_head])
  return final_output  # [batch_size, from_seq_length, num_attention_heads * size_per_head]


def cross_attention_tower(left_tensor,
                          right_tensor,
                          num_hidden_layers=1,
                          num_attention_heads=12,
                          left_size_per_head=64,
                          right_size_per_head=64,
                          left_intermediate_size=0,
                          right_intermediate_size=0,
                          left_input_mask=None,
                          right_input_mask=None,
                          hidden_dropout_prob=0.1,
                          attention_probs_dropout_prob=0.1,
                          initializer_range=0.02,
                          name=''):
  """Multi-headed, multi layer cross attention block.

  Args:
    left_tensor: float Tensor of shape [batch_size, left_seq_length,
      from_width].
    right_tensor: float Tensor of shape [batch_size, right_seq_length, to_width].
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    left_size_per_head: int. Size of each attention head of left tower.
    right_size_per_head: int. Size of each attention head of right tower.
    left_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer of left tower. Less or equal to 0 means `num_attention_heads
      * left_size_per_head`
    right_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer of right tower. Less or equal to 0 means `num_attention_heads
      * right_size_per_head`
    left_input_mask: the mask for `left_tensor`
    right_input_mask: the mask for `right_tensor`
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    name: scope name prefix

  Returns:
    tuple of float Tensors of shape ([batch_size, left_seq_length, hidden_size],
    [batch_size, right_seq_length, hidden_size]),
    where hidden_size = num_attention_heads * size_per_head

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if left_intermediate_size <= 0:
    left_intermediate_size = num_attention_heads * left_size_per_head
  if right_intermediate_size <= 0:
    right_intermediate_size = num_attention_heads * right_size_per_head

  left_attention_mask = None
  if left_input_mask is not None:
    left_attention_mask = create_attention_mask_from_input_mask(
        left_tensor, left_attention_mask)

  left_2_right_attention_mask = None
  if right_input_mask is not None:
    left_2_right_attention_mask = create_attention_mask_from_input_mask(
        left_tensor, right_input_mask)

  right_attention_mask = None
  if right_input_mask is not None:
    right_attention_mask = create_attention_mask_from_input_mask(
        right_tensor, right_input_mask)

  right_2_left_attention_mask = None
  if left_input_mask is not None:
    right_2_left_attention_mask = create_attention_mask_from_input_mask(
        right_tensor, left_input_mask)

  prev_left_output = left_tensor
  prev_right_output = right_tensor
  for layer_idx in range(num_hidden_layers):
    left_output = cross_attention_block(
        prev_left_output,
        prev_right_output,
        layer_idx,
        num_attention_heads=num_attention_heads,
        size_per_head=left_size_per_head,
        intermediate_size=left_intermediate_size,
        hidden_dropout_prob=hidden_dropout_prob,
        cross_attention_mask=left_2_right_attention_mask,
        self_attention_mask=left_attention_mask,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=initializer_range,
        name='%sleft_to_right_' % name)
    right_output = cross_attention_block(
        prev_right_output,
        prev_left_output,
        layer_idx,
        num_attention_heads=num_attention_heads,
        size_per_head=right_size_per_head,
        intermediate_size=right_intermediate_size,
        hidden_dropout_prob=hidden_dropout_prob,
        cross_attention_mask=right_2_left_attention_mask,
        self_attention_mask=right_attention_mask,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=initializer_range,
        name='%sright_to_left_' % name)
    prev_left_output = left_output
    prev_right_output = right_output
  return prev_left_output, prev_right_output


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return tf_layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)


def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError('Input tensor must have at least rank 2. Shape = %s' %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding)
  # tokens so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=tf.stack([batch_size, from_seq_length, 1]), dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask


def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name='token_type_embeddings',
                            reuse_token_type=None,
                            use_position_embeddings=True,
                            position_embedding_name='position_embeddings',
                            reuse_position_embedding=None,
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    reuse_token_type: bool. Whether to reuse token type embedding variable.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    reuse_position_embedding: bool. Whether to reuse position embedding variable.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError('`token_type_ids` must be specified if'
                       '`use_token_type` is True.')
    with tf.variable_scope('token_type', reuse=reuse_token_type):
      token_type_table = tf.get_variable(
          name=token_type_embedding_name,
          shape=[token_type_vocab_size, width],
          initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      with tf.variable_scope(
          'position_embedding', reuse=reuse_position_embedding):
        full_position_embeddings = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, width],
            initializer=create_initializer(initializer_range))
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
      # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
      # perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`), so
      # we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor