easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from tensorflow.python.keras.layers import Dense
|
|
8
|
+
from tensorflow.python.keras.layers import Dropout
|
|
9
|
+
from tensorflow.python.keras.layers import Embedding
|
|
10
|
+
from tensorflow.python.keras.layers import Layer
|
|
11
|
+
|
|
12
|
+
from easy_rec.python.layers.keras import MultiHeadAttention
|
|
13
|
+
from easy_rec.python.layers.keras.layer_norm import LayerNormalization
|
|
14
|
+
from easy_rec.python.layers.utils import Parameter
|
|
15
|
+
from easy_rec.python.protos import seq_encoder_pb2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TransformerBlock(Layer):
  """Post-LN transformer encoder block: self-attention then a feed-forward net.

  Each sub-layer is applied as ``LayerNorm(residual + Dropout(sublayer(...)))``.

  Config keys read from ``params``:
    hidden_size: model width ``d_model``; also the output width of the FFN.
    num_attention_heads: number of heads; each head gets
      ``hidden_size // num_attention_heads`` dimensions.
    attention_probs_dropout_prob: dropout inside attention (default 0.0).
    hidden_dropout_prob: dropout after each sub-layer (default 0.1).
    intermediate_size: width of the first FFN dense layer (default hidden_size).
    hidden_act: activation of the first FFN dense layer (default 'relu').
  """

  def __init__(self, params, name='transformer_block', reuse=None, **kwargs):
    super(TransformerBlock, self).__init__(name=name, **kwargs)
    model_dim = params.hidden_size
    head_count = params.num_attention_heads

    # Self-attention sub-layer; only the attended values are returned.
    attn_conf = seq_encoder_pb2.MultiHeadAttention()
    attn_conf.num_heads = head_count
    attn_conf.key_dim = model_dim // head_count
    attn_conf.dropout = params.get_or_default('attention_probs_dropout_prob',
                                              0.0)
    attn_conf.return_attention_scores = False
    self.mha = MultiHeadAttention(
        Parameter.make_from_pb(attn_conf), 'multi_head_attn')

    drop_prob = params.get_or_default('hidden_dropout_prob', 0.1)
    inner_units = params.get_or_default('intermediate_size', model_dim)
    inner_act = params.get_or_default('hidden_act', 'relu')
    # Feed-forward sub-layer; the second projection returns to model_dim so the
    # residual addition lines up.
    self.ffn_dense1 = Dense(inner_units, activation=inner_act)
    self.ffn_dense2 = Dense(model_dim)

    if tf.__version__ >= '2.0':
      norm_layer_cls = tf.keras.layers.LayerNormalization
    else:
      norm_layer_cls = LayerNormalization
    self.layer_norm1 = norm_layer_cls(epsilon=1e-6)
    self.layer_norm2 = norm_layer_cls(epsilon=1e-6)
    self.dropout1 = Dropout(drop_prob)
    self.dropout2 = Dropout(drop_prob)

  def call(self, inputs, training=None, **kwargs):
    """Run attention and feed-forward sub-layers with residual connections.

    Args:
      inputs: a pair ``[x, mask]``; ``mask`` is forwarded as-is to the
        attention layer. The residual adds assume the last dimension of ``x``
        equals ``hidden_size``.
      training: whether dropout (and attention dropout) is active.

    Returns:
      The block output, produced by the second layer normalization.
    """
    seq, mask = inputs
    attended = self.mha([seq, seq, seq], mask=mask, training=training)
    hidden = self.layer_norm1(seq + self.dropout1(attended, training=training))
    ffn_out = self.ffn_dense2(self.ffn_dense1(hidden))
    return self.layer_norm2(hidden + self.dropout2(ffn_out, training=training))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Positional Encoding, https://www.tensorflow.org/text/tutorials/transformer
def positional_encoding(length, depth):
  """Build the sinusoidal positional-encoding table.

  Uses the TensorFlow transformer-tutorial layout: the first half of the last
  axis holds ``sin(pos * 10000**(-i / (depth / 2)))`` and the second half the
  matching cosines.

  Args:
    length: number of positions (rows) in the table.
    depth: encoding width. Odd values are supported; the surplus final cosine
      column is dropped so the table is exactly ``depth`` wide.

  Returns:
    A float32 tensor of shape ``(length, depth)``.
  """
  # Float division: under Python 2, `depth / 2` and the exponent division
  # below would otherwise be integer divisions and collapse every frequency.
  half = depth / 2.0
  # ceil() guarantees at least `depth` columns after concatenating sin and cos.
  num_freqs = int(np.ceil(half))
  positions = np.arange(length)[:, np.newaxis]  # (length, 1)
  freq_exps = np.arange(num_freqs)[np.newaxis, :] / half  # (1, num_freqs)
  angle_rates = 1 / (10000**freq_exps)  # (1, num_freqs)
  angle_rads = positions * angle_rates  # (length, num_freqs)
  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1)[:, :depth]
  return tf.cast(pos_encoding, dtype=tf.float32)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class PositionalEmbedding(Layer):
  """Token embedding scaled by sqrt(d_model) plus a fixed sinusoidal offset.

  The offset table has ``max_position`` rows, built by ``positional_encoding``.
  """

  def __init__(self, vocab_size, d_model, max_position, name='pos_embedding'):
    super(PositionalEmbedding, self).__init__(name=name)
    self.d_model = d_model
    self.embedding = Embedding(vocab_size, d_model)
    self.pos_encoding = positional_encoding(length=max_position, depth=d_model)

  def call(self, x, training=None):
    """Embed token ids and add the first ``seq_len`` position rows.

    Args:
      x: token ids; axis 1 is treated as the sequence axis (presumably shaped
        ``(batch, seq_len)`` -- confirm with callers).
      training: unused.

    Returns:
      Embedded tokens plus position encodings.

    NOTE(review): if seq_len exceeds ``max_position`` the sliced table is
    shorter than the sequence and the addition would fail to broadcast.
    """
    seq_len = tf.shape(x)[1]
    embedded = self.embedding(x)
    # The sqrt(d_model) factor sets the relative scale of the embedding versus
    # the positional encoding.
    embedded = embedded * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    return embedded + self.pos_encoding[tf.newaxis, :seq_len, :]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TransformerEncoder(Layer):
  """The encoder consists of a stack of encoder layers.

  Token ids are embedded together with sinusoidal positional information,
  passed through dropout and then through ``num_hidden_layers`` stacked
  ``TransformerBlock`` layers.

  Config keys read from ``params``:
    hidden_size: model width.
    vocab_size: size of the token-embedding table (required).
    hidden_dropout_prob: dropout on the input embeddings (default 0.1).
    max_position_embeddings: rows of the position table (default 512).
    num_hidden_layers: number of stacked blocks (default 1).
    output_all_token_embeddings: return every token's encoding when True
      (default); otherwise only the first token's encoding.
  """

  def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
    super(TransformerEncoder, self).__init__(name=name, **kwargs)
    d_model = params.hidden_size
    dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
    max_position = params.get_or_default('max_position_embeddings', 512)
    num_layers = params.get_or_default('num_hidden_layers', 1)
    vocab_size = params.vocab_size
    logging.info('vocab size of TransformerEncoder(%s) is %d', name, vocab_size)
    self.output_all = params.get_or_default('output_all_token_embeddings', True)
    # Token embedding + position encoding layer (attribute name kept as-is).
    self.pos_encoding = PositionalEmbedding(vocab_size, d_model, max_position)
    self.dropout = Dropout(dropout_rate)
    self.enc_layers = [
        TransformerBlock(params, 'layer_%d' % i) for i in range(num_layers)
    ]
    self._vocab_size = vocab_size
    self._max_position = max_position

  @property
  def vocab_size(self):
    """Number of rows in the token-embedding table."""
    return self._vocab_size

  @property
  def max_position(self):
    """Number of rows in the positional-encoding table."""
    return self._max_position

  def call(self, inputs, training=None, **kwargs):
    """Encode a batch of token ids.

    Args:
      inputs: a pair ``[x, mask]``; ``x`` holds token ids (presumably
        ``(batch, seq_len)``), ``mask`` is forwarded to every block.
      training: whether dropout is active.

    Returns:
      All token encodings, or only the first token's encoding
      (``x[:, 0, :]``) when ``output_all_token_embeddings`` is False.
    """
    x, mask = inputs
    x = self.pos_encoding(x)
    x = self.dropout(x, training=training)
    for block in self.enc_layers:
      # Pass `training` by keyword, as every other layer call in this module
      # does, rather than relying on positional mapping in Layer.__call__.
      x = block([x, mask], training=training)
    return x if self.output_all else x[:, 0, :]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class TextEncoder(Layer):
  """BERT-style encoder for one or more raw text (string) features.

  The texts of an example are joined as ``[CLS] s1 [SEP] s2 [SEP] ...`` (with
  ``separator`` between pieces), split on ``separator`` into tokens, mapped to
  ids and encoded by a ``TransformerEncoder``. Only the first ([CLS]) token's
  encoding is returned, because ``output_all_token_embeddings`` is forced to
  False on the transformer config.

  Token ids come from ``vocab_file`` when set (unknown tokens map to
  ``default_token_id``); otherwise tokens are hashed into
  ``transformer.vocab_size`` buckets. Token sequences are truncated to the
  transformer's ``max_position_embeddings``.
  """

  def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
    super(TextEncoder, self).__init__(name=name, **kwargs)
    self.separator = params.get_or_default('separator', ' ')
    self.cls_token = '[CLS]' + self.separator
    self.sep_token = self.separator + '[SEP]' + self.separator
    # NOTE: this writes into `params.transformer` so the encoder returns only
    # the first token's encoding.
    params.transformer.output_all_token_embeddings = False
    trans_params = Parameter.make_from_pb(params.transformer)
    vocab_file = params.get_or_default('vocab_file', None)
    self.vocab = None
    self.default_token_id = params.get_or_default('default_token_id', 0)
    if vocab_file is not None:
      self.vocab = tf.feature_column.categorical_column_with_vocabulary_file(
          'tokens',
          vocabulary_file=vocab_file,
          default_value=self.default_token_id)
      logging.info('vocab file of TextEncoder(%s) is %s', name, vocab_file)
      # Size the embedding table to the vocabulary.
      trans_params.vocab_size = self.vocab.vocabulary_size
    self.encoder = TransformerEncoder(trans_params, name='transformer')

  def call(self, inputs, training=None, **kwargs):
    """Tokenize and encode the text feature(s).

    Args:
      inputs: a string tensor or a list/tuple of string tensors, one per text
        field; each is flattened to one string per example.
      training: whether dropout is active.

    Returns:
      The encoder output for the first ([CLS]) token of every example.
    """
    if not isinstance(inputs, (tuple, list)):
      inputs = [inputs]
    # Flatten each feature to a 1-D (batch,) vector. tf.squeeze would also drop
    # the batch axis when batch_size == 1, producing a scalar that the split
    # and 2-D slicing below cannot handle.
    inputs = [tf.reshape(text, [-1]) for text in inputs]
    batch_shape = tf.shape(inputs[0])
    cls = tf.fill(batch_shape, self.cls_token)
    sep = tf.fill(batch_shape, self.sep_token)
    sentences = [cls]
    for sentence in inputs:
      sentences.append(sentence)
      sentences.append(sep)
    text = tf.strings.join(sentences)
    tokens = tf.strings.split(text, self.separator)
    if hasattr(tokens, 'to_sparse'):
      # TF2's tf.strings.split returns a RaggedTensor; the code below expects
      # the SparseTensor returned by the TF1 API.
      tokens = tokens.to_sparse()
    max_len = self.encoder.max_position
    if self.vocab is not None:
      token_ids = self.vocab._transform_feature({'tokens': tokens})
      token_ids = tf.sparse.to_dense(
          token_ids, default_value=self.default_token_id, name='token_ids')
      # Slicing past the end is a no-op, so this keeps at most max_len tokens.
      token_ids = token_ids[:, :max_len]
      # NOTE(review): out-of-vocabulary tokens also map to default_token_id
      # and are therefore masked out together with the padding.
      mask = tf.not_equal(token_ids, self.default_token_id, name='mask')
    else:
      tokens = tf.sparse.to_dense(tokens, default_value='')
      tokens = tokens[:, :max_len]
      # TF2 only exposes tf.strings.to_hash_bucket_fast; older TF 1.x releases
      # only have the top-level tf.string_to_hash_bucket_fast.
      to_hash_bucket = getattr(tf.strings, 'to_hash_bucket_fast', None)
      if to_hash_bucket is None:
        to_hash_bucket = tf.string_to_hash_bucket_fast
      token_ids = to_hash_bucket(
          tokens, self.encoder.vocab_size, name='token_ids')
      mask = tf.not_equal(tokens, '', name='mask')

    encoding = self.encoder([token_ids, mask], training=training)
    return encoding
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# -*- encoding: utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
if tf.__version__ >= '2.0':
|
|
6
|
+
tf = tf.compat.v1
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LayerNormalization(tf.layers.Layer):
  """Layer normalization for BTC format: supports L2(default) and L1 modes.

  Normalization is over the last axis, followed by a learned per-channel scale
  and bias of length ``hidden_size``.

  Constructor args:
    hidden_size: size of the last axis.
    params: optional dict with
      'type': 'layernorm_L2' (default) divides the centered input by the
        standard deviation; any other value selects the L1 mode, which divides
        by the mean absolute deviation.
      'epsilon': constant added to the denominator (default 1e-6).
  """

  def __init__(self, hidden_size, params=None):
    super(LayerNormalization, self).__init__()
    # `None` sentinel instead of a shared mutable `{}` default argument.
    if params is None:
      params = {}
    self.hidden_size = hidden_size
    self.norm_type = params.get('type', 'layernorm_L2')
    self.epsilon = params.get('epsilon', 1e-6)

  def build(self, _):
    # Scale starts at 1 and bias at 0, so the layer initially only normalizes.
    self.scale = tf.get_variable(
        'layer_norm_scale', [self.hidden_size],
        initializer=tf.keras.initializers.Ones(),
        dtype=tf.float32)
    self.bias = tf.get_variable(
        'layer_norm_bias', [self.hidden_size],
        initializer=tf.keras.initializers.Zeros(),
        dtype=tf.float32)
    self.built = True

  def call(self, x):
    if self.norm_type == 'layernorm_L2':
      epsilon = self.epsilon
      dtype = x.dtype
      # Compute in float32 and cast back to the input dtype.
      x = tf.cast(x=x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
      norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
      result = norm_x * self.scale + self.bias
      return tf.cast(x=result, dtype=dtype)

    else:
      dtype = x.dtype
      # Only float16 inputs are upcast in this mode.
      if dtype == tf.float16:
        x = tf.cast(x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      x = x - mean
      # Mean absolute deviation (kept under the historical name `variance`).
      variance = tf.reduce_mean(tf.abs(x), axis=[-1], keepdims=True)
      norm_x = tf.div(x, variance + self.epsilon)
      y = norm_x * self.scale + self.bias
      if dtype == tf.float16:
        y = tf.saturate_cast(y, dtype)
      return y
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
|
|
7
|
+
from easy_rec.python.layers import dnn
|
|
8
|
+
|
|
9
|
+
if tf.__version__ >= '2.0':
|
|
10
|
+
tf = tf.compat.v1
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MMOE:
  """Multi-gate Mixture-of-Experts layer.

  A set of expert DNNs is applied to the shared input; each task mixes the
  expert outputs with its own softmax gate.
  """

  def __init__(self,
               expert_dnn_config,
               l2_reg,
               num_task,
               num_expert=None,
               name='mmoe',
               is_training=False):
    """Initializes a `MMOE` layer.

    Args:
      expert_dnn_config: an instance or a list/tuple of
        easy_rec.python.protos.dnn_pb2.DNN. If it is a list/tuple of configs,
        the param `num_expert` is ignored and its length is the number of
        experts; if it is a single config, it is shared by `num_expert` experts.
      l2_reg: l2 regularizer, used by the expert DNNs and the gate kernels.
      num_task: number of tasks, i.e. number of gates and outputs.
      num_expert: number of experts; required (> 0) when expert_dnn_config
        is a single config.
      name: scope prefix, so that the parameters could be separated from
        other layers.
      is_training: train phase or not, passed to the expert DNNs.

    Raises:
      AssertionError: if an empty config list is given, or if a single config
        is given without a positive `num_expert`.
    """
    if isinstance(expert_dnn_config, (list, tuple)):
      assert len(expert_dnn_config) > 0, \
          'expert_dnn_config must contain at least one expert config'
      self._expert_dnn_configs = list(expert_dnn_config)
      self._num_expert = len(self._expert_dnn_configs)
    else:
      assert num_expert is not None and num_expert > 0, \
          'param `num_expert` must be larger than zero, when expert_dnn_config is not a list'
      self._expert_dnn_configs = [expert_dnn_config] * num_expert
      self._num_expert = num_expert
    logging.info('num_expert: %s', self._num_expert)

    self._num_task = num_task
    self._l2_reg = l2_reg
    self._name = name
    self._is_training = is_training

  @property
  def num_expert(self):
    """Number of expert DNNs."""
    return self._num_expert

  def gate(self, unit, deep_fea, name):
    """Softmax gate: a dense layer with `unit` outputs, normalized over axis 1."""
    fea = tf.layers.dense(
        inputs=deep_fea,
        units=unit,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn' % name)
    fea = tf.nn.softmax(fea, axis=1)
    return fea

  def __call__(self, deep_fea):
    """Compute one gated mixture of expert outputs per task.

    Args:
      deep_fea: shared input tensor (presumably (batch, dim) -- confirm).

    Returns:
      A list of `num_task` tensors; each is the gate-weighted sum of the
      expert outputs over the expert axis.
    """
    expert_fea_list = []
    for expert_id in range(self._num_expert):
      expert_dnn_config = self._expert_dnn_configs[expert_id]
      expert_dnn = dnn.DNN(
          expert_dnn_config,
          self._l2_reg,
          name='%s/expert_%d' % (self._name, expert_id),
          is_training=self._is_training)
      expert_fea = expert_dnn(deep_fea)
      expert_fea_list.append(expert_fea)
    # Stack experts on axis 1 so the per-task gate can weight them.
    experts_fea = tf.stack(expert_fea_list, axis=1)

    task_input_list = []
    for task_id in range(self._num_task):
      gate = self.gate(
          self._num_expert, deep_fea, name='%s/gate_%d' % (self._name, task_id))
      # Trailing axis so the gate weights broadcast over each expert's output.
      gate = tf.expand_dims(gate, -1)
      task_input = tf.multiply(experts_fea, gate)
      task_input = tf.reduce_sum(task_input, axis=1)
      task_input_list.append(task_input)
    return task_input_list
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
if tf.__version__ >= '2.0':
|
|
6
|
+
tf = tf.compat.v1
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MultiHeadAttention:
  """Multi-head attention built from bias-free `tf.layers.dense` projections.

  Query/key/value inputs are projected to `head_num * head_size` units, split
  into `head_num` heads, combined with scaled dot-product attention and the
  heads are concatenated back. With `use_res=True` a projected residual of the
  query input is added and the sum passed through ReLU.
  """

  def __init__(self, head_num, head_size, l2_reg, use_res=False, name=''):
    """Initializes a `MultiHeadAttention` Layer.

    Args:
      head_num: The number of heads.
      head_size: The dimension of a head.
      l2_reg: l2 regularizer, used as `kernel_regularizer` of every dense layer.
      use_res: Whether to add a projected residual connection (followed by
        ReLU) before output.
      name: scope of the MultiHeadAttention, so that the parameters can be
        separated from other MultiHeadAttention layers.
    """
    self._head_num = head_num
    self._head_size = head_size
    self._l2_reg = l2_reg
    self._use_res = use_res
    self._name = name

  def _dim(self, x, axis):
    """Returns the size of `x` along `axis`: static when known, else dynamic."""
    dim = x.shape[axis]
    # TF1 shapes yield a `Dimension` (with `.value`); TF2 shapes yield int/None.
    dim = getattr(dim, 'value', dim)
    if dim is None:
      # Unknown at graph-construction time (e.g. variable sequence length).
      dim = tf.shape(x)[axis]
    return dim

  def _split_heads(self, x):
    """Reshapes [bs, n, head_num * head_size] to [bs, head_num, n, head_size]."""
    reshaped = tf.reshape(
        x, shape=[-1, self._dim(x, 1), self._head_num, self._head_size])
    return tf.transpose(reshaped, perm=[0, 2, 1, 3])

  def _split_multihead_qkv(self, q, k, v):
    """Split multiple heads.

    Args:
      q: Query matrix of shape [bs, q_num, head_num * head_size].
      k: Key matrix of shape [bs, kv_num, head_num * head_size].
      v: Value matrix of shape [bs, kv_num, head_num * head_size].

    Returns:
      q: Query matrix of shape [bs, head_num, q_num, head_size].
      k: Key matrix of shape [bs, head_num, kv_num, head_size].
      v: Value matrix of shape [bs, head_num, kv_num, head_size].
    """
    return self._split_heads(q), self._split_heads(k), self._split_heads(v)

  def _scaled_dot_product_attention(self, q, k, v):
    """Calculate scaled dot product attention by q, k and v.

    Args:
      q: Query matrix of shape [bs, head_num, q_num, head_size].
      k: Key matrix of shape [bs, head_num, kv_num, head_size].
      v: Value matrix of shape [bs, head_num, kv_num, head_size].

    Returns:
      out: Attention output of shape [bs, head_num, q_num, head_size].
    """
    # Scale logits by 1 / sqrt(head_size) so their magnitude does not grow
    # with the head dimension (which would saturate the softmax).
    product = tf.linalg.matmul(
        a=q, b=k, transpose_b=True) * (self._head_size**-0.5)
    # Softmax over the last axis, i.e. over the key positions.
    weights = tf.nn.softmax(product)
    return tf.linalg.matmul(weights, v)

  def _project(self, x, scope):
    """Bias-free dense projection of `x` to `head_num * head_size` units."""
    return tf.layers.dense(
        x,
        self._head_num * self._head_size,
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/%s/dnn' % (self._name, scope))

  def _compute_qkv(self, q, k, v):
    """Calculate q, k and v matrices.

    Args:
      q: Query matrix of shape [bs, q_num, d_model].
      k: Key matrix of shape [bs, kv_num, d_model].
      v: Value matrix of shape [bs, kv_num, d_model].

    Returns:
      q: Query matrix of shape [bs, q_num, head_size * head_num].
      k: Key matrix of shape [bs, kv_num, head_size * head_num].
      v: Value matrix of shape [bs, kv_num, head_size * head_num].
    """
    return (self._project(q, 'query'), self._project(k, 'key'),
            self._project(v, 'value'))

  def _combine_heads(self, multi_head_tensor):
    """Combine the results of multiple heads.

    Args:
      multi_head_tensor: Result matrix of shape [bs, head_num, q_num, head_size].

    Returns:
      out: Result matrix of shape [bs, q_num, head_num * head_size].
    """
    x = tf.transpose(multi_head_tensor, perm=[0, 2, 1, 3])
    return tf.reshape(
        x, shape=[-1, self._dim(x, 1), self._head_num * self._head_size])

  def _multi_head_attention(self, attention_input):
    """Build multiple heads attention layer.

    Args:
      attention_input: Either a single tensor of shape [bs, feature_num,
        d_model] (self-attention), or a list/tuple of length 1 (same as a
        single tensor) or 3 ([query, key, value]).

    Returns:
      out: The output of multi head attention layer, has a shape of
        [bs, q_num, head_num * head_size].
    """
    if isinstance(attention_input, (list, tuple)):
      assert len(attention_input) == 3 or len(attention_input) == 1, \
          'If the input of multi_head_attention is a list, the length must be 1 or 3.'
      if len(attention_input) == 3:
        ori_q, ori_k, ori_v = attention_input
      else:
        ori_q = ori_k = ori_v = attention_input[0]
    else:
      ori_q = ori_k = ori_v = attention_input

    q, k, v = self._compute_qkv(ori_q, ori_k, ori_v)
    q, k, v = self._split_multihead_qkv(q, k, v)
    multi_head_tensor = self._scaled_dot_product_attention(q, k, v)
    out = self._combine_heads(multi_head_tensor)

    if not self._use_res:
      return out
    # The attention output has one row per *query* position, so the residual
    # must be taken from the query input (a value-side residual breaks when
    # the query and key/value lengths differ). For self-attention q == v.
    residual = tf.layers.dense(
        ori_q,
        self._head_num * self._head_size,
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn' % (self._name))
    return tf.nn.relu(out + residual)

  def __call__(self, deep_fea):
    """Applies multi-head attention to `deep_fea` (see `_multi_head_attention`)."""
    return self._multi_head_attention(deep_fea)
|