easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Attention layers that can be used in sequence DNN/CNN models.
|
|
4
|
+
|
|
5
|
+
This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
|
|
6
|
+
Attention is formed by three tensors: Query, Key and Value.
|
|
7
|
+
"""
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow.python.keras.layers import Layer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Attention(Layer):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are a list with 2 or 3 elements:
  1. A `query` tensor of shape `(batch_size, Tq, dim)`.
  2. A `value` tensor of shape `(batch_size, Tv, dim)`.
  3. An optional `key` tensor of shape `(batch_size, Tv, dim)`. If none
     supplied, `value` will be used as a `key`.

  The calculation follows the steps:
  1. Calculate attention scores using `query` and `key` with shape
     `(batch_size, Tq, Tv)`.
  2. Use scores to calculate a softmax distribution with shape
     `(batch_size, Tq, Tv)`.
  3. Use the softmax distribution to create a linear combination of `value`
     with shape `(batch_size, Tq, dim)`.

  Args:
    params: Configuration object read through `get_or_default(key, default)`.
      The following keys are consulted:
      - `use_scale` (default `False`): if true, create a scalar variable to
        scale the attention scores.
      - `scale_by_dim` (default `False`): in `"dot"` mode, and only when no
        learned scale exists, divide the scores by the square root of the
        last dimension of `key`.
      - `score_mode` (default `"dot"`): one of `{"dot", "concat"}`. `"dot"`
        is the dot product between the query and key vectors. `"concat"` is
        the hyperbolic tangent of the sum of the broadcast `query` and `key`
        vectors, reduced over the feature axis.
      - `dropout` (default `0.0`): float between 0 and 1, fraction of the
        attention weights to drop during training.
      - `seed` (default `None`): random seed used for dropout.
      - `return_attention_scores` (default `False`): if true, `call` also
        returns the attention weights (after masking and softmax).
      - `use_causal_mask` (default `False`): if true, position `i` cannot
        attend to positions `j > i`.
    name: Layer name.
    reuse: Accepted for interface compatibility; not used by this layer.
    **kwargs: Extra arguments forwarded to the Keras `Layer` constructor.

  Call Args:
    inputs: List of the following tensors:
      - `query`: Query tensor of shape `(batch_size, Tq, dim)`.
      - `value`: Value tensor of shape `(batch_size, Tv, dim)`.
      - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If
        not given, will use `value` for both `key` and `value`, which is
        the most common case.
    mask: List of the following tensors:
      - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.
        If given, the output will be zero at the positions where
        `mask==False`.
      - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.
        If given, will apply the mask such that values at positions
        where `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:
    Attention outputs of shape `(batch_size, Tq, dim)`.
    (Optional) Attention scores after masking and softmax with shape
      `(batch_size, Tq, Tv)`.
  """

  def __init__(self, params, name='attention', reuse=None, **kwargs):
    super(Attention, self).__init__(name=name, **kwargs)
    self.use_scale = params.get_or_default('use_scale', False)
    self.scale_by_dim = params.get_or_default('scale_by_dim', False)
    self.score_mode = params.get_or_default('score_mode', 'dot')
    if self.score_mode not in ['dot', 'concat']:
      raise ValueError('Invalid value for argument score_mode. '
                       "Expected one of {'dot', 'concat'}. "
                       'Received: score_mode=%s' % self.score_mode)
    self.dropout = params.get_or_default('dropout', 0.0)
    self.seed = params.get_or_default('seed', None)
    # Both weights are created lazily in `build`, and only when needed.
    self.scale = None
    self.concat_score_weight = None
    self._return_attention_scores = params.get_or_default(
        'return_attention_scores', False)
    self.use_causal_mask = params.get_or_default('use_causal_mask', False)

  @property
  def return_attention_scores(self):
    """Whether `call` returns the attention weights alongside the output."""
    return self._return_attention_scores

  def build(self, input_shape):
    """Creates the optional scalar weights after validating `input_shape`."""
    self._validate_inputs(input_shape)
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=(),
          initializer='ones',
          dtype=self.dtype,
          trainable=True,
      )
    if self.score_mode == 'concat':
      self.concat_score_weight = self.add_weight(
          name='concat_score_weight',
          shape=(),
          initializer='ones',
          dtype=self.dtype,
          trainable=True,
      )
    super(Attention, self).build(input_shape)  # Be sure to call this somewhere!

  def _calculate_scores(self, query, key):
    """Calculates attention scores from `query` and `key`.

    Args:
      query: Query tensor of shape `(batch_size, Tq, dim)`.
      key: Key tensor of shape `(batch_size, Tv, dim)`.

    Returns:
      Tensor of shape `(batch_size, Tq, Tv)`.
    """
    if self.score_mode == 'dot':
      scores = tf.matmul(query, tf.transpose(key, [0, 2, 1]))
      if self.scale is not None:
        scores *= self.scale
      elif self.scale_by_dim:
        # Cast to the dtype of `scores` (not a hard-coded float32) so the
        # division also works when the layer runs in another float type.
        dk = tf.cast(tf.shape(key)[-1], scores.dtype)
        scores /= tf.math.sqrt(dk)
    elif self.score_mode == 'concat':
      # Reshape tensors to enable broadcasting.
      # Reshape into [batch_size, Tq, 1, dim].
      q_reshaped = tf.expand_dims(query, axis=-2)
      # Reshape into [batch_size, 1, Tv, dim].
      k_reshaped = tf.expand_dims(key, axis=-3)
      if self.scale is not None:
        scores = self.concat_score_weight * tf.reduce_sum(
            tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1)
      else:
        scores = self.concat_score_weight * tf.reduce_sum(
            tf.tanh(q_reshaped + k_reshaped), axis=-1)
    return scores

  def _apply_scores(self, scores, value, scores_mask=None, training=False):
    """Applies attention scores to the given value tensor.

    The method applies `scores_mask`, calculates
    `attention_distribution = softmax(scores)`, optionally applies dropout
    to the distribution, then returns
    `matmul(attention_distribution, value)` together with the distribution.

    Args:
      scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.
      value: Value tensor of shape `(batch_size, Tv, dim)`.
      scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)`
        or `(batch_size, Tq, Tv)`. A rank-2 mask of shape
        `(batch_size, Tv)` is also accepted and is expanded to
        `(batch_size, 1, Tv)`. If given, scores at positions where
        `scores_mask==False` do not contribute to the result. It must
        contain at least one `True` value in each line along the last
        dimension.
      training: Python boolean indicating whether the layer should behave
        in training mode (adding dropout) or in inference mode
        (no dropout).

    Returns:
      Tensor of shape `(batch_size, Tq, dim)`.
      Attention scores after masking and softmax with shape
        `(batch_size, Tq, Tv)`.
    """
    if scores_mask is not None:
      padding_mask = tf.logical_not(scores_mask)
      # Without a causal mask the value mask arrives here un-expanded with
      # shape (batch_size, Tv). Give it a singleton query axis so that it
      # broadcasts over Tq instead of being aligned against (Tq, Tv).
      if padding_mask.shape.ndims == 2:
        padding_mask = tf.expand_dims(padding_mask, axis=-2)
      # Bias so padding positions do not contribute to attention
      # distribution. Note 65504. is the max float16 value.
      max_value = 65504.0 if scores.dtype == 'float16' else 1.0e9
      scores -= max_value * tf.cast(padding_mask, dtype=scores.dtype)

    weights = tf.nn.softmax(scores, axis=-1)
    if training and self.dropout > 0:
      # The second positional argument of `tf.nn.dropout` is `keep_prob` in
      # TF1 but `rate` in TF2, so pass the keep probability by keyword to
      # the v1 API (falling back to `tf` itself where `compat.v1` is absent).
      tf_v1 = getattr(tf.compat, 'v1', tf)
      weights = tf_v1.nn.dropout(
          weights, keep_prob=1.0 - self.dropout, seed=self.seed)
    return tf.matmul(weights, value), weights

  def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
    """Builds the mask applied to the scores before the softmax.

    Args:
      scores: Scores tensor of shape `(batch_size, Tq, Tv)`; only its last
        two dimensions are read, to size the causal mask.
      v_mask: Optional boolean value mask of shape `(batch_size, Tv)`.
      use_causal_mask: Whether to forbid attending to later positions.

    Returns:
      With a causal mask: a boolean tensor of shape `(1, Tq, Tv)`, combined
      with `v_mask` (expanded to `(batch_size, 1, Tv)`) when it is given.
      Otherwise `v_mask` unchanged, which may be `None`.
    """
    if use_causal_mask:
      # Creates a lower triangular mask, so position i cannot attend to
      # positions j > i. This prevents the flow of information from the
      # future into the past.
      score_shape = tf.shape(scores)
      # causal_mask_shape = [1, Tq, Tv].
      mask_shape = (1, score_shape[-2], score_shape[-1])
      ones_mask = tf.ones(shape=mask_shape, dtype='int32')
      row_index = tf.cumsum(ones_mask, axis=-2)
      col_index = tf.cumsum(ones_mask, axis=-1)
      causal_mask = tf.greater_equal(row_index, col_index)

      if v_mask is not None:
        # Mask of shape [batch_size, 1, Tv].
        v_mask = tf.expand_dims(v_mask, axis=-2)
        return tf.logical_and(v_mask, causal_mask)
      return causal_mask
    else:
      # If not using causal mask, return the value mask as is,
      # or None if the value mask is not provided.
      return v_mask

  def call(self, inputs, mask=None, training=False, **kwargs):
    """Computes attention of `query` over `value` (and optional `key`).

    Args:
      inputs: List `[query, value]` or `[query, value, key]`.
      mask: Optional list `[query_mask, value_mask]`.
      training: Whether dropout is applied to the attention weights.
      **kwargs: Ignored.

    Returns:
      The attention output, or a pair `(output, attention_scores)` when the
      layer was configured with `return_attention_scores`.
    """
    self._validate_inputs(inputs=inputs, mask=mask)
    q = inputs[0]
    v = inputs[1]
    k = inputs[2] if len(inputs) > 2 else v
    q_mask = mask[0] if mask else None
    v_mask = mask[1] if mask else None
    scores = self._calculate_scores(query=q, key=k)
    scores_mask = self._calculate_score_mask(scores, v_mask,
                                             self.use_causal_mask)
    result, attention_scores = self._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask, training=training)
    if q_mask is not None:
      # Mask of shape [batch_size, Tq, 1].
      q_mask = tf.expand_dims(q_mask, axis=-1)
      result *= tf.cast(q_mask, dtype=result.dtype)
    if self._return_attention_scores:
      return result, attention_scores
    return result

  def compute_mask(self, inputs, mask=None):
    """Propagates the query mask, or `None` when there is none."""
    self._validate_inputs(inputs=inputs, mask=mask)
    if mask is None or mask[0] is None:
      return None
    return tf.convert_to_tensor(mask[0])

  def compute_output_shape(self, input_shape):
    """Returns shape of value tensor dim, but for query tensor length."""
    # Concatenate into one flat shape: the leading dims of the query shape
    # followed by the last dim of the value shape.
    return tuple(input_shape[0][:-1]) + (input_shape[1][-1],)

  def _validate_inputs(self, inputs, mask=None):
    """Validates arguments of the call method.

    Raises:
      ValueError: if `inputs` is not a list of length 2 or 3, or if `mask`
        is given and is not a list of length 2 or 3.
    """
    class_name = self.__class__.__name__
    if not isinstance(inputs, list):
      raise ValueError('{class_name} layer must be called on a list of inputs, '
                       'namely [query, value] or [query, value, key]. '
                       'Received: inputs={inputs}.'.format(
                           class_name=class_name, inputs=inputs))
    if len(inputs) < 2 or len(inputs) > 3:
      raise ValueError('%s layer accepts inputs list of length 2 or 3, '
                       'namely [query, value] or [query, value, key]. '
                       'Received length: %d.' % (class_name, len(inputs)))
    if mask is not None:
      if not isinstance(mask, list):
        raise ValueError(
            '{class_name} layer mask must be a list, '
            'namely [query_mask, value_mask]. Received: mask={mask}.'.format(
                class_name=class_name, mask=mask))
      if len(mask) < 2 or len(mask) > 3:
        raise ValueError(
            '{class_name} layer accepts mask list of length 2 or 3. '
            'Received: inputs={inputs}, mask={mask}.'.format(
                class_name=class_name, inputs=inputs, mask=mask))

  def get_config(self):
    """Returns the base layer config extended with this layer's options."""
    base_config = super(Attention, self).get_config()
    config = {
        'use_scale': self.use_scale,
        'scale_by_dim': self.scale_by_dim,
        'score_mode': self.score_mode,
        'dropout': self.dropout,
        'seed': self.seed,
        'return_attention_scores': self._return_attention_scores,
        'use_causal_mask': self.use_causal_mask,
    }
    return dict(list(base_config.items()) + list(config.items()))
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
|
|
7
|
+
from easy_rec.python.loss import contrastive_loss
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AuxiliaryLoss(tf.keras.layers.Layer):
  """Compute auxiliary loss, usually use for contrastive learning.

  The layer computes one of the losses from
  `easy_rec.python.loss.contrastive_loss` on a pair of inputs, scales it by
  `loss_weight`, and records it in the `loss_dict` passed to `call`.

  Args:
    params: Configuration object. `loss_type` is required; supported values
      are `'l2_loss'`, `'info_nce'` and `'nce_loss'`. `loss_weight`
      (default `1.0`) and `temperature` (default `0.1`) are optional.
    name: Layer name; also used as the prefix of the key written into
      `loss_dict`.
    reuse: Accepted for interface compatibility; not used by this layer.
    **kwargs: Extra arguments forwarded to the Keras `Layer` constructor.
  """

  def __init__(self, params, name='auxiliary_loss', reuse=None, **kwargs):
    super(AuxiliaryLoss, self).__init__(name=name, **kwargs)
    params.check_required('loss_type')
    self.loss_type = params.get_or_default('loss_type', None)
    self.loss_weight = params.get_or_default('loss_weight', 1.0)
    logging.info('init layer `%s` with loss type: %s and weight: %f',
                 self.name, self.loss_type, self.loss_weight)
    self.temperature = params.get_or_default('temperature', 0.1)

  def call(self, inputs, training=None, **kwargs):
    """Computes the weighted auxiliary loss and stores it in `loss_dict`.

    Args:
      inputs: A pair of values unpacked as the two arguments of the
        selected loss function.
      training: Unused.
      **kwargs: Must contain `loss_dict`, a mutable mapping that receives
        the loss under the key `<layer name>_<loss name>`.

    Returns:
      The weighted loss, or `0` when `loss_type` is `None` or is not one of
      the supported types (a warning is logged in both cases).

    Raises:
      KeyError: if `loss_type` is set and `loss_dict` is not in `kwargs`.
    """
    if self.loss_type is None:
      logging.warning('loss_type is None in auxiliary loss layer')
      return 0

    loss_dict = kwargs['loss_dict']

    if self.loss_type == 'l2_loss':
      x1, x2 = inputs
      loss = contrastive_loss.l2_loss(x1, x2)
      loss_name = '%s_l2_loss' % self.name
    elif self.loss_type == 'info_nce':
      query, positive = inputs
      loss = contrastive_loss.info_nce_loss(
          query, positive, temperature=self.temperature)
      loss_name = '%s_info_nce_loss' % self.name
    elif self.loss_type == 'nce_loss':
      x1, x2 = inputs
      loss = contrastive_loss.nce_loss(x1, x2, temperature=self.temperature)
      loss_name = '%s_nce_loss' % self.name
    else:
      # Keep the historical "contribute nothing" result, but make the
      # misconfiguration visible instead of dropping the loss silently.
      logging.warning(
          'unsupported loss_type `%s` in auxiliary loss layer `%s`, '
          'no loss is added', self.loss_type, self.name)
      return 0

    # Skip the multiplication for the default weight.
    loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
    loss_dict[loss_name] = loss_value
    return loss_value
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Convenience blocks for building models."""
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from tensorflow.python.keras.initializers import Constant
|
|
8
|
+
from tensorflow.python.keras.layers import Dense
|
|
9
|
+
from tensorflow.python.keras.layers import Dropout
|
|
10
|
+
from tensorflow.python.keras.layers import Lambda
|
|
11
|
+
from tensorflow.python.keras.layers import Layer
|
|
12
|
+
|
|
13
|
+
from easy_rec.python.layers.keras.activation import activation_layer
|
|
14
|
+
from easy_rec.python.layers.utils import Parameter
|
|
15
|
+
from easy_rec.python.utils.shape_utils import pad_or_truncate_sequence
|
|
16
|
+
from easy_rec.python.utils.tf_utils import add_elements_to_collection
|
|
17
|
+
|
|
18
|
+
# On TensorFlow 2.x and later, rebind `tf` to the `tf.compat.v1` namespace.
# Compare the numeric major version: a plain string comparison is
# lexicographic and would order '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MLP(Layer):
  """Sequential multi-layer perceptron (MLP) block.

  Every entry of `hidden_units` becomes a dense layer, combined with an
  activation and, depending on the configuration, batch normalization and
  dropout. The last entry uses its own settings (`use_final_bn`,
  `use_final_bias`, `final_activation`).

  Parameters read from `params`:
    hidden_units: (required) sequential list of layer sizes.
    use_bn / use_final_bn: batch norm switches, both default to True.
    use_bias / use_final_bias: bias switches, both default to False.
    dropout_ratio: per-layer dropout rates, layers without an entry get 0.0.
    activation: activation of all but the last layer, default `relu`.
    final_activation: activation of the last layer, default None.
    initializer: kernel initializer, default `he_uniform`.
    use_bn_after_activation: place batch norm after the activation instead
      of before it, default False.
    add_to_outputs: also export the block output, default False.
  """

  def __init__(self, params, name='mlp', reuse=None, **kwargs):
    super(MLP, self).__init__(name=name, **kwargs)
    # Remember the configured name: it is the key used in `call` when the
    # output is added to the prediction dict.
    self.layer_name = name
    params.check_required('hidden_units')

    hidden_bn = params.get_or_default('use_bn', True)
    final_bn = params.get_or_default('use_final_bn', True)
    hidden_bias = params.get_or_default('use_bias', False)
    final_bias = params.get_or_default('use_final_bias', False)
    dropout_rates = list(params.get_or_default('dropout_ratio', []))
    hidden_act = params.get_or_default('activation', 'relu')
    kernel_init = params.get_or_default('initializer', 'he_uniform')
    final_act = params.get_or_default('final_activation', None)
    bn_after_act = params.get_or_default('use_bn_after_activation', False)
    layer_sizes = list(params.hidden_units)

    summary = (
        'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
        ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r')
    logging.info(summary %
                 (name, layer_sizes, dropout_rates, hidden_act, hidden_bn,
                  final_bn, final_act, hidden_bias, kernel_init, bn_after_act))
    assert len(
        layer_sizes) > 0, 'MLP(%s) takes at least one hidden units' % name
    self.reuse = reuse
    self.add_to_outputs = params.get_or_default('add_to_outputs', False)

    self._sub_layers = []
    num_rates = len(dropout_rates)
    final_idx = len(layer_sizes) - 1

    # All layers but the last one share the "hidden" settings.
    for idx in range(final_idx):
      rate = dropout_rates[idx] if idx < num_rates else 0.0
      self.add_rich_layer(layer_sizes[idx], hidden_bn, rate, hidden_act,
                          kernel_init, hidden_bias, bn_after_act,
                          'layer_%d' % idx, params.l2_regularizer)

    # The last layer has its own batch norm / bias / activation settings.
    rate = dropout_rates[final_idx] if num_rates > final_idx else 0.0
    self.add_rich_layer(layer_sizes[-1], final_bn, rate, final_act,
                        kernel_init, final_bias, bn_after_act,
                        'layer_%d' % final_idx, params.l2_regularizer)

  def add_rich_layer(self,
                     num_units,
                     use_bn,
                     dropout_rate,
                     activation,
                     initializer,
                     use_bias,
                     use_bn_after_activation,
                     name,
                     l2_reg=None):
    """Append one dense layer and its companions to the sub layer list.

    The appended order is dense -> bn -> activation when batch norm is
    enabled and placed before the activation, dense -> activation -> bn when
    it is placed after, and dense -> activation without batch norm. A dropout
    layer follows when `0.0 < dropout_rate < 1.0`.

    Raises:
      ValueError: if `dropout_rate` is greater than or equal to 1.0.
    """
    activation_op = activation_layer(activation, name='%s/act' % name)
    dense_op = Dense(
        units=num_units,
        use_bias=use_bias,
        kernel_initializer=initializer,
        kernel_regularizer=l2_reg,
        name='%s/dense' % name)

    if not use_bn:
      stack = [dense_op, activation_op]
    elif use_bn_after_activation:
      norm_op = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
      stack = [dense_op, activation_op, norm_op]
    else:
      norm_op = tf.keras.layers.BatchNormalization(
          name='%s/bn' % name, trainable=True)
      stack = [dense_op, norm_op, activation_op]
    self._sub_layers.extend(stack)

    if dropout_rate >= 1.0:
      raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
    if dropout_rate > 0.0:
      self._sub_layers.append(
          Dropout(dropout_rate, name='%s/dropout' % name))

  def call(self, x, training=None, **kwargs):
    """Performs the forward computation of the block."""
    needs_training_flag = ('Dropout', 'BatchNormalization', 'Dice')
    has_update_ops = ('BatchNormalization', 'Dice')
    for sub_layer in self._sub_layers:
      kind = sub_layer.__class__.__name__
      if kind not in needs_training_flag:
        x = sub_layer(x)
        continue
      x = sub_layer(x, training=training)
      if kind in has_update_ops and training:
        # Register the layer's update ops in the UPDATE_OPS collection.
        add_elements_to_collection(sub_layer.updates, tf.GraphKeys.UPDATE_OPS)

    if self.add_to_outputs and 'prediction_dict' in kwargs:
      prediction_dict = kwargs['prediction_dict']
      prediction_dict[self.layer_name] = tf.squeeze(x, axis=1)
      logging.info('add `%s` to model outputs' % self.layer_name)
    return x
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Highway(Layer):
  """Highway block.

  Each of the `num_layers` layers mixes a transformed value with the carried
  value through a sigmoid gate:
    value = gate * act(transform(value)) + (1 - gate) * value

  Parameters read from `params`:
    emb_size: if set and different from the last input dimension, the input
      is first projected to this size; default None.
    num_layers: number of gated layers, default 1.
    activation: activation applied to the transform, default `relu`.
    dropout_rate: dropout applied to the transformed value, default 0.0.
    init_gate_bias: constant used to initialize the gate bias, default -3.0.
  """

  def __init__(self, params, name='highway', reuse=None, **kwargs):
    super(Highway, self).__init__(name=name, **kwargs)
    self.emb_size = params.get_or_default('emb_size', None)
    self.num_layers = params.get_or_default('num_layers', 1)
    self.activation = params.get_or_default('activation', 'relu')
    self.dropout_rate = params.get_or_default('dropout_rate', 0.0)
    self.init_gate_bias = params.get_or_default('init_gate_bias', -3.0)
    self.act_layer = activation_layer(self.activation)
    # Dropout is only instantiated for a positive rate.
    if self.dropout_rate > 0.0:
      self.dropout_layer = Dropout(self.dropout_rate)
    else:
      self.dropout_layer = None
    # The projection is decided in `build`, once the input size is known.
    self.project_layer = None
    self.gate_bias_initializer = Constant(self.init_gate_bias)
    self.gates = []  # T
    self.transforms = []  # H
    self.multiply_layer = tf.keras.layers.Multiply()
    self.add_layer = tf.keras.layers.Add()

  def build(self, input_shape):
    """Create the projection (if needed), the gates and the transforms."""
    width = input_shape[-1]
    if self.emb_size is not None and width != self.emb_size:
      self.project_layer = Dense(self.emb_size, name='input_projection')
      width = self.emb_size
    # Carry gate C = 1 - T.
    self.carry_gate = Lambda(lambda x: 1.0 - x, output_shape=(width,))
    for depth in range(self.num_layers):
      self.gates.append(
          Dense(
              units=width,
              bias_initializer=self.gate_bias_initializer,
              activation='sigmoid',
              name='gate_%d' % depth))
      self.transforms.append(Dense(units=width))

  def call(self, inputs, training=None, **kwargs):
    """Run the input through the (optional) projection and all layers."""
    if self.project_layer is None:
      value = inputs
    else:
      value = self.project_layer(inputs)

    for depth in range(self.num_layers):
      gate_value = self.gates[depth](value)
      candidate = self.act_layer(self.transforms[depth](value))
      if self.dropout_layer is not None:
        candidate = self.dropout_layer(candidate, training=training)
      gated_candidate = self.multiply_layer([gate_value, candidate])
      gated_carry = self.multiply_layer([self.carry_gate(gate_value), value])
      value = self.add_layer([gated_candidate, gated_carry])
    return value
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class Gate(Layer):
  """Weighted sum gate.

  One element of the input list (selected by the `weight_index` parameter,
  default 0) holds the weights; the remaining elements are summed, the j-th
  of them scaled by `weights[:, j, None]`. If the `mlp` field is configured,
  the sum is passed through an MLP named `top_mlp`.
  """

  def __init__(self, params, name='gate', reuse=None, **kwargs):
    super(Gate, self).__init__(name=name, **kwargs)
    self.weight_index = params.get_or_default('weight_index', 0)
    self.top_mlp = None
    if params.has_field('mlp'):
      top_cfg = Parameter.make_from_pb(params.mlp)
      # The nested MLP uses the regularizer configured for this layer.
      top_cfg.l2_regularizer = params.l2_regularizer
      self.top_mlp = MLP(top_cfg, name='top_mlp')

  def call(self, inputs, training=None, **kwargs):
    """Return the weighted sum of all inputs except the weight input."""
    assert len(
        inputs
    ) > 1, 'input of Gate layer must be a list containing at least 2 elements'
    weights = inputs[self.weight_index]
    operands = [
        item for pos, item in enumerate(inputs) if pos != self.weight_index
    ]
    # Column `slot` of the weights scales the `slot`-th remaining input.
    for slot, operand in enumerate(operands):
      if slot:
        output += weights[:, slot, None] * operand
      else:
        output = weights[:, slot, None] * operand
    if self.top_mlp is None:
      return output
    return self.top_mlp(output, training=training)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class TextCNN(Layer):
  """Text CNN Model.

  One Conv1D layer is created per (filter size, number of filters) pair of
  the config. Each convolution output is max-pooled over the sequence, the
  pooled outputs are concatenated and, if the `mlp` field is configured,
  passed through an MLP.

  References
  - [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)
  """

  def __init__(self, params, name='text_cnn', reuse=None, **kwargs):
    """Build the convolution, pooling and optional MLP layers.

    Args:
      params: parameter object providing `get_pb_config()` and
        `l2_regularizer`; the pb config supplies `pad_sequence_length`,
        `filter_sizes`, `num_filters`, `activation` and optionally `mlp`.
      name: name of the layer.
      reuse: forwarded to the optional MLP.
      **kwargs: extra args passed to the Keras Layer base class.
    """
    super(TextCNN, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.pad_seq_length = self.config.pad_sequence_length
    if self.pad_seq_length <= 0:
      logging.warning(
          'run text cnn with pad_sequence_length <= 0, the predict of model may be unstable'
      )
    filter_sizes = self.config.filter_sizes
    num_filters = self.config.num_filters
    if len(filter_sizes) != len(num_filters):
      # zip() below stops at the shorter list; report the mismatch instead of
      # silently dropping the extra entries.
      logging.warning(
          'text cnn `%s`: filter_sizes has %d elements but num_filters has %d,'
          ' the extra elements are ignored' %
          (name, len(filter_sizes), len(num_filters)))
    self.conv_layers = []
    self.pool_layer = tf.keras.layers.GlobalMaxPool1D()
    self.concat_layer = tf.keras.layers.Concatenate(axis=-1)
    for size, filters in zip(filter_sizes, num_filters):
      conv = tf.keras.layers.Conv1D(
          filters=int(filters),
          kernel_size=int(size),
          activation=self.config.activation)
      self.conv_layers.append(conv)
    if self.config.HasField('mlp'):
      p = Parameter.make_from_pb(self.config.mlp)
      p.l2_regularizer = params.l2_regularizer
      self.mlp = MLP(p, name='mlp', reuse=reuse)
    else:
      self.mlp = None

  def call(self, inputs, training=None, **kwargs):
    """Apply the convolutions to a sequence embedding.

    Args:
      inputs: list or tuple with at least two elements; the first two are the
        sequence embedding, a 3D tensor with shape
        `(batch_size, steps, input_dim)`, and the sequence length.
      training: forwarded to the optional MLP.

    Returns:
      The concatenated max-pooled convolution outputs, passed through the MLP
      when one is configured.
    """
    assert isinstance(inputs, (list, tuple))
    assert len(inputs) >= 2
    seq_emb, seq_len = inputs[:2]

    if self.pad_seq_length > 0:
      # Pad or truncate to the configured length; the adjusted length is not
      # needed afterwards.
      seq_emb, _ = pad_or_truncate_sequence(seq_emb, seq_len,
                                            self.pad_seq_length)
    pooled_outputs = [
        self.pool_layer(conv_layer(seq_emb)) for conv_layer in self.conv_layers
    ]
    net = self.concat_layer(pooled_outputs)
    if self.mlp is None:
      return net
    return self.mlp(net, training=training)
|