easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
from tensorflow.python.keras.layers import Dense
|
|
7
|
+
from tensorflow.python.keras.layers import Layer
|
|
8
|
+
|
|
9
|
+
from easy_rec.python.layers.keras.attention import Attention
|
|
10
|
+
from easy_rec.python.layers.keras.blocks import MLP
|
|
11
|
+
from easy_rec.python.layers.utils import Parameter
|
|
12
|
+
from easy_rec.python.protos import seq_encoder_pb2
|
|
13
|
+
|
|
14
|
+
# Route every later `tf.*` call in this module through the TF1-compatible API
# when running on TensorFlow 2.x or newer. The major version is compared as an
# integer because a plain string comparison mis-orders versions
# (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MMoE(Layer):
  """Multi-gate Mixture-of-Experts layer.

  Each task owns a softmax gate that yields one weight per expert; the input
  handed to a task is the gate-weighted sum of the expert outputs.

  `call` supports two input modes:

  * `params` has an `expert_mlp` field: `inputs` is fed to `num_expert`
    internally created MLP experts and is also used as the gate input.
  * otherwise: `inputs` is a sequence whose first `num_expert` elements are
    the expert outputs and whose element at index `num_expert` is the gate
    input.
  """

  def __init__(self, params, name='MMoE', reuse=None, **kwargs):
    """Create the gates and, if configured, the expert MLPs.

    Args:
      params: layer parameters; `num_expert` and `num_task` are required,
        `expert_mlp` is optional, `l2_regularizer` is applied to the gate
        kernels and passed on to the expert MLPs.
      name: layer name.
      reuse: forwarded to the expert MLPs.
      **kwargs: forwarded to the base `Layer` constructor.
    """
    super(MMoE, self).__init__(name=name, **kwargs)
    params.check_required(['num_expert', 'num_task'])
    self._reuse = reuse
    self._num_expert = params.num_expert
    self._num_task = params.num_task
    if params.has_field('expert_mlp'):
      expert_params = Parameter.make_from_pb(params.expert_mlp)
      expert_params.l2_regularizer = params.l2_regularizer
      self._has_experts = True
      self._experts = [
          MLP(expert_params, 'expert_%d' % i, reuse=reuse)
          for i in range(self._num_expert)
      ]
    else:
      self._has_experts = False

    # One softmax gate per task, each producing `num_expert` mixing weights.
    self._gates = []
    for task_id in range(self._num_task):
      dense = Dense(
          self._num_expert,
          activation='softmax',
          name='gate_%d' % task_id,
          kernel_regularizer=params.l2_regularizer)
      self._gates.append(dense)

  def call(self, inputs, training=None, **kwargs):
    """Mix the expert outputs once per task.

    Args:
      inputs: a tensor when the layer owns its expert MLPs; otherwise a
        sequence of `num_expert` expert outputs followed by the gate input.
      training: forwarded to the expert MLPs.
      **kwargs: unused.

    Returns:
      A list with one mixed tensor per task, or `inputs` unchanged when
      `num_expert` is 0.
    """
    if self._num_expert == 0:
      logging.warning('num_expert of MMoE layer `%s` is 0', self.name)
      return inputs
    if self._has_experts:
      expert_fea_list = [
          expert(inputs, training=training) for expert in self._experts
      ]
      gate_input = inputs
    else:
      # Without built-in MLP experts, the first `num_expert` elements are the
      # expert outputs and the extra element after them is the gate input.
      # Only the expert outputs may be stacked: including the gate input would
      # make the expert axis one longer than the gate's output.
      expert_fea_list = inputs[:self._num_expert]
      gate_input = inputs[self._num_expert]
    experts_fea = tf.stack(expert_fea_list, axis=1)

    task_input_list = []
    for task_id in range(self._num_task):
      gate = self._gates[task_id](gate_input)
      # Add a trailing axis so the gate weights broadcast over the features.
      gate = tf.expand_dims(gate, -1)
      task_input = tf.multiply(experts_fea, gate)
      task_input = tf.reduce_sum(task_input, axis=1)
      task_input_list.append(task_input)
    return task_input_list
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class AITMTower(Layer):
  """Adaptive Information Transfer Multi-task (AITM) Tower.

  Given a list of tower tensors, every tower is projected to a query, a key
  and a value by its own Dense layers; the projections are stacked along
  axis 1, passed through an attention layer, and the slice at index 0 of
  axis 1 is returned. The first list element is treated as the current tower;
  the remaining elements may be detached from the gradient and passed through
  an optional transfer MLP first.
  """

  def __init__(self, params, name='AITMTower', reuse=None, **kwargs):
    """Read the configuration and create the optional transfer MLP.

    Args:
      params: layer parameters; optional `project_dim` (default None),
        `stop_gradient` (default True) and `transfer_mlp`.
      name: layer name.
      reuse: accepted for interface compatibility; not used here.
      **kwargs: forwarded to the base `Layer` constructor.
    """
    super(AITMTower, self).__init__(name=name, **kwargs)
    self.project_dim = params.get_or_default('project_dim', None)
    self.stop_gradient = params.get_or_default('stop_gradient', True)

    if params.has_field('transfer_mlp'):
      transfer_cfg = Parameter.make_from_pb(params.transfer_mlp)
      transfer_cfg.l2_regularizer = params.l2_regularizer
      self.transfer = MLP(transfer_cfg, name='transfer')
    else:
      self.transfer = None

    # Per-tower projection layers and the attention layer are made in `build`.
    self.queries = []
    self.keys = []
    self.values = []
    self.attention = None

  def build(self, input_shape):
    """Create one query/key/value Dense per input tower plus the attention.

    Nothing is created when `input_shape` is not a tuple or list, because
    `call` passes such inputs straight through.
    """
    if isinstance(input_shape, (tuple, list)):
      # Fall back to the last dimension of the first tower when no explicit
      # (truthy) projection size is configured.
      proj_dim = self.project_dim or int(input_shape[0][-1])
      for idx, _ in enumerate(input_shape):
        self.queries.append(Dense(proj_dim, name='query_%d' % idx))
        self.keys.append(Dense(proj_dim, name='key_%d' % idx))
        self.values.append(Dense(proj_dim, name='value_%d' % idx))
      attention_pb = seq_encoder_pb2.Attention()
      attention_pb.scale_by_dim = True
      self.attention = Attention(Parameter.make_from_pb(attention_pb))
    super(AITMTower, self).build(input_shape)

  def call(self, inputs, training=None, **kwargs):
    """Fuse the towers with attention and return the slice for index 0.

    Args:
      inputs: a list/tuple of tower tensors with the current tower first; any
        other value is returned unchanged.
      training: forwarded to the transfer MLP.
      **kwargs: unused.

    Returns:
      `attn[:, 0, :]` of the attention output, or `inputs` itself when it is
      not a list/tuple.
    """
    if not isinstance(inputs, (tuple, list)):
      return inputs

    q_list, k_list, v_list = [], [], []
    for idx, tower in enumerate(inputs):
      feature = tower
      if idx > 0:
        # Every tower after the first: optionally block gradients, then
        # optionally apply the shared transfer MLP.
        if self.stop_gradient:
          feature = tf.stop_gradient(feature)
        if self.transfer is not None:
          feature = self.transfer(feature, training=training)
      q_list.append(self.queries[idx](feature))
      k_list.append(self.keys[idx](feature))
      v_list.append(self.values[idx](feature))

    stacked_query = tf.stack(q_list, axis=1)
    stacked_key = tf.stack(k_list, axis=1)
    stacked_value = tf.stack(v_list, axis=1)
    # The attention layer is fed in the order [query, value, key].
    attended = self.attention([stacked_query, stacked_value, stacked_key])
    return attended[:, 0, :]
|
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from tensorflow.python.framework import ops
|
|
9
|
+
from tensorflow.python.keras.layers import Layer
|
|
10
|
+
|
|
11
|
+
from easy_rec.python.compat.array_ops import repeat
|
|
12
|
+
from easy_rec.python.utils.activation import get_activation
|
|
13
|
+
from easy_rec.python.utils.tf_utils import get_ps_num_from_tf_config
|
|
14
|
+
|
|
15
|
+
# Locate the bundled custom-op libraries: the `ops` directory sits two levels
# above the directory that contains this file.
curr_dir, _ = os.path.split(__file__)
parent_dir = os.path.dirname(curr_dir)
ops_idr = os.path.dirname(parent_dir)
ops_dir = os.path.join(ops_idr, 'ops')

# Choose the sub-directory matching the running TensorFlow build. When no rule
# matches, `ops_dir` keeps pointing at the bare `ops` directory.
_tf_version = tf.__version__
if 'PAI' in _tf_version:
  _ops_subdir = '1.12_pai'
elif _tf_version.startswith('1.12'):
  _ops_subdir = '1.12'
elif _tf_version.startswith('1.15'):
  _ops_subdir = 'DeepRec' if 'IS_ON_PAI' in os.environ else '1.15'
elif _tf_version.startswith('2.12'):
  _ops_subdir = '2.12'
else:
  _ops_subdir = None
if _ops_subdir is not None:
  ops_dir = os.path.join(ops_dir, _ops_subdir)

logging.info('ops_dir is %s' % ops_dir)
custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
# Loading is best effort: on any failure `custom_ops` is set to None and a
# warning is logged instead of failing the import of this module.
try:
  custom_ops = tf.load_op_library(custom_op_path)
  logging.info('load custom op from %s succeed' % custom_op_path)
except Exception as ex:
  logging.warning('load custom op from %s failed: %s' %
                  (custom_op_path, str(ex)))
  custom_ops = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NLinear(Layer):
  """N independent linear layers, one per token (feature) embedding.

  A single dense layer applied to a three-dimensional input of shape
  ``(batch_size, n_tokens, d_embedding)`` applies the same linear
  transformation to each of the ``n_tokens`` token (feature) embeddings.

  `NLinear` instead allocates one linear layer per token (``n_tokens`` layers
  in total), so the i-th transformation is applied only to the i-th token
  embedding, as illustrated in the following pseudocode::

      layers = [Dense(d_out) for _ in range(n_tokens)]
      x = tf.random.normal([batch_size, n_tokens, d_in])
      result = tf.stack([layers[i](x[:, i]) for i in range(n_tokens)], 1)

  Examples:
      .. testcode::

          batch_size = 2
          n_features = 3
          d_embedding_in = 4
          d_embedding_out = 5
          x = tf.random.normal([batch_size, n_features, d_embedding_in])
          m = NLinear(n_features, d_embedding_in, d_embedding_out)
          assert m(x).shape == (batch_size, n_features, d_embedding_out)
  """

  def __init__(self,
               n_tokens,
               d_in,
               d_out,
               bias=True,
               name='nd_linear',
               **kwargs):
    """Create the per-token weights and the optional per-token bias.

    Args:
      n_tokens: the number of tokens (features)
      d_in: the input dimension
      d_out: the output dimension
      bias: whether a zero-initialized bias is added per token
      name: layer name
      **kwargs: forwarded to the base `Layer` constructor
    """
    super(NLinear, self).__init__(name=name, **kwargs)
    weight_shape = [1, n_tokens, d_in, d_out]
    self.weight = self.add_weight('weights', weight_shape, dtype=tf.float32)
    if not bias:
      self.bias = None
    else:
      self.bias = self.add_weight(
          'bias', [1, n_tokens, d_out],
          dtype=tf.float32,
          initializer=tf.constant_initializer(0.0))

  def call(self, x, **kwargs):
    """Apply the i-th linear map to the i-th token of `x`.

    Args:
      x: tensor of rank 3 whose last dimension equals `d_in`.
      **kwargs: unused.

    Returns:
      The transformed tensor; the original shape comments give it as
      [B, N, D_out].

    Raises:
      ValueError: if `x` is not rank 3 or its last dimension differs from the
        input dimension of the weights.
    """
    if x.shape.ndims != 3:
      raise ValueError(
          'The input must have three dimensions (batch_size, n_tokens, d_embedding)'
      )
    actual_dim = x.shape[2]
    expected_dim = self.weight.shape[2]
    if actual_dim != expected_dim:
      raise ValueError('invalid input embedding dimension %d, expect %d' %
                       (int(actual_dim), int(expected_dim)))

    # Broadcast x[..., None] against the weights ([B, N, D, D_out]), then sum
    # over the input dimension to get [B, N, D_out].
    projected = tf.reduce_sum(x[..., None] * self.weight, axis=-2)
    if self.bias is None:
      return projected
    return projected + self.bias
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class PeriodicEmbedding(Layer):
  """Periodic embeddings for numerical features described in [1].

  References:
    * [1] Yury Gorishniy, Ivan Rubachev, Artem Babenko,
          "On Embeddings for Numerical Features in Tabular Deep Learning", 2022
          https://arxiv.org/pdf/2203.05556.pdf

  Attributes:
    embedding_dim: the embedding size, must be an even positive integer.
    sigma: the scale of the weight initialization.
      **This is a super important parameter which significantly affects performance**.
      Its optimal value can be dramatically different for different datasets, so
      no "default value" can exist for this parameter, and it must be tuned for
      each dataset. In the original paper, during hyperparameter tuning, this
      parameter was sampled from the distribution ``LogUniform[1e-2, 1e2]``.
      A similar grid would be ``[1e-2, 1e-1, 1e0, 1e1, 1e2]``.
      If possible, add more intermediate values to this grid.
    output_3d_tensor: whether to output a 3d tensor
    output_tensor_list: whether to output the list of embedding
  """

  def __init__(self, params, name='periodic_embedding', reuse=None, **kwargs):
    """Validate and store the layer configuration.

    Args:
      params: layer parameters; `embedding_dim` and `sigma` are required,
        `add_linear_layer` (default True), `linear_activation` (default
        'relu'), `output_tensor_list` and `output_3d_tensor` (both default
        False) are optional.
      name: layer name.
      reuse: stored on the instance; not otherwise used in this class.
      **kwargs: forwarded to the base `Layer` constructor.

    Raises:
      ValueError: if `embedding_dim` is not an even positive integer.
    """
    super(PeriodicEmbedding, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    params.check_required(['embedding_dim', 'sigma'])
    self.embedding_dim = int(params.embedding_dim)
    if self.embedding_dim <= 0:
      raise ValueError('embedding_dim must be positive')
    if self.embedding_dim % 2:
      raise ValueError('embedding_dim must be even')
    sigma = params.sigma
    self.initializer = tf.random_normal_initializer(stddev=sigma)
    self.add_linear_layer = params.get_or_default('add_linear_layer', True)
    self.linear_activation = params.get_or_default('linear_activation', 'relu')
    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)

  def build(self, input_shape):
    """Create the frequency coefficients and the optional per-feature linear.

    Raises:
      ValueError: if the input is not two-dimensional.
    """
    if input_shape.ndims != 2:
      raise ValueError('inputs of PeriodicEmbedding must have 2 dimensions.')
    self.num_features = int(input_shape[-1])
    num_ps = get_ps_num_from_tf_config()
    partitioner = None
    if num_ps > 0:
      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
    # Half of the embedding comes from sin and half from cos, so only
    # embedding_dim // 2 coefficients are needed per feature.
    emb_dim = self.embedding_dim // 2
    self.coef = self.add_weight(
        'coefficients',
        shape=[1, self.num_features, emb_dim],
        partitioner=partitioner,
        initializer=self.initializer)
    if self.add_linear_layer:
      self.linear = NLinear(
          self.num_features,
          self.embedding_dim,
          self.embedding_dim,
          name='nd_linear')
    super(PeriodicEmbedding, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Embed every numerical feature with sin/cos of learned frequencies.

    Args:
      inputs: 2-D tensor of numerical features, [B, N].
      **kwargs: unused.

    Returns:
      `output` of shape [B, N * embedding_dim]; or `(output, list of
      per-feature embeddings)` when `output_tensor_list` is set; or
      `(output, 3-D embedding tensor)` when `output_3d_tensor` is set.
      `output_tensor_list` takes precedence when both are set.
    """
    features = inputs[..., None]  # [B, N, 1]
    v = 2 * math.pi * self.coef * features  # [B, N, E]
    emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1)  # [B, N, 2E]

    if self.add_linear_layer:
      emb = self.linear(emb)
      act = get_activation(self.linear_activation)
      # Only apply the activation when the lookup returned something callable.
      if callable(act):
        emb = act(emb)
    output = tf.reshape(emb, [-1, self.num_features * self.embedding_dim])

    if self.output_tensor_list:
      return output, tf.unstack(emb, axis=1)
    if self.output_3d_tensor:
      return output, emb
    return output
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class AutoDisEmbedding(Layer):
  """AutoDis embedding layer for numerical features.

  Implements "An Embedding Learning Framework for Numerical Features in CTR
  Prediction": every feature owns `num_bins` meta embeddings, and a small
  per-feature network produces a softmax weighting over them.

  Refer: https://arxiv.org/pdf/2012.08986v2.pdf
  """

  def __init__(self, params, name='auto_dis_embedding', reuse=None, **kwargs):
    """Read the layer configuration.

    Args:
      params: config object; `embedding_dim`, `num_bins` and `temperature`
        are required, `keep_prob` (default 0.8), `output_tensor_list` and
        `output_3d_tensor` (both default False) are optional.
      name: layer name.
      reuse: stored on the instance; not used by this class itself.
      **kwargs: forwarded to the base layer constructor.
    """
    super(AutoDisEmbedding, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    params.check_required(['embedding_dim', 'num_bins', 'temperature'])
    # Required settings.
    self.emb_dim = int(params.embedding_dim)
    self.num_bins = int(params.num_bins)
    self.temperature = params.temperature
    # Optional settings.
    self.keep_prob = params.get_or_default('keep_prob', 0.8)
    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)

  def build(self, input_shape):
    """Create the meta embeddings and the two projection weights.

    Raises:
      ValueError: if the input is not rank 2.
    """
    if input_shape.ndims != 2:
      raise ValueError('inputs of AutoDisEmbedding must have 2 dimensions.')
    self.num_features = int(input_shape[-1])

    # Shard the variables only when TF_CONFIG reports parameter servers.
    ps_count = get_ps_num_from_tf_config()
    partitioner = None
    if ps_count > 0:
      partitioner = tf.fixed_size_partitioner(num_shards=ps_count)

    feature_cnt, bin_cnt = self.num_features, self.num_bins
    self.meta_emb = self.add_weight(
        'meta_embedding',
        shape=[feature_cnt, bin_cnt, self.emb_dim],
        partitioner=partitioner)
    self.proj_w = self.add_weight(
        'project_w', shape=[1, feature_cnt, bin_cnt], partitioner=partitioner)
    self.proj_mat = self.add_weight(
        'project_mat',
        shape=[feature_cnt, bin_cnt, bin_cnt],
        partitioner=partitioner)
    super(AutoDisEmbedding, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Map numeric inputs to a softmax-weighted mix of meta embeddings.

    Returns:
      The embedding flattened to [B, N*D]; plus the per-feature list when
      `output_tensor_list` is set, or the [B, N, D] tensor when
      `output_3d_tensor` is set.
    """
    scalars = tf.expand_dims(inputs, axis=-1)  # [B, N, 1]
    hidden = tf.nn.leaky_relu(self.proj_w * scalars)  # [B, N, num_bin]
    # matmul in older TensorFlow (1.12) cannot broadcast over the batch axis,
    # so the per-feature projection is expressed with einsum instead.
    projected = tf.einsum('nik,bnk->bni', self.proj_mat,
                          hidden)  # [B, N, num_bin]

    # keep_prob acts as the weight of the skip connection from `hidden`.
    skip_weight = self.keep_prob
    logits = projected + skip_weight * hidden  # [B, N, num_bin]
    bin_weights = tf.nn.softmax(logits / self.temperature)  # [B, N, num_bin]

    # Weighted sum of each feature's meta embeddings (einsum again replaces a
    # broadcasting matmul).
    emb = tf.einsum('bnk,nkd->bnd', bin_weights, self.meta_emb)
    flat_shape = [-1, self.emb_dim * self.num_features]
    flattened = tf.reshape(emb, flat_shape)  # [B, N*D]

    if self.output_tensor_list:
      return flattened, tf.unstack(emb, axis=1)
    if self.output_3d_tensor:
      return flattened, emb
    return flattened
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class NaryDisEmbedding(Layer):
  """Numerical Feature Representation with Hybrid N-ary Encoding, CIKM 2022.

  Inputs are scaled by `multiplier`, cast to int32 and passed to the
  `nary_carry` custom op once per configured carry. The ids it returns are
  looked up in a single embedding table and pooled over the segments that the
  op reports.

  Refer: https://dl.acm.org/doi/pdf/10.1145/3511808.3557090
  """

  def __init__(self, params, name='nary_dis_embedding', reuse=None, **kwargs):
    """Read the layer configuration.

    Args:
      params: config object; `embedding_dim` and `carries` are required,
        `num_replicas` (1), `multiplier` (1.0), `intra_ary_pooling` ('sum'),
        `output_3d_tensor` (False) and `output_tensor_list` (False) are
        optional.
      name: layer name.
      reuse: stored on the instance; not used by this class itself.
      **kwargs: forwarded to the base layer constructor.
    """
    super(NaryDisEmbedding, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.nary_carry = custom_ops.nary_carry
    params.check_required(['embedding_dim', 'carries'])
    self.emb_dim = int(params.embedding_dim)
    self.carries = params.get_or_default('carries', [2, 9])
    self.num_replicas = params.get_or_default('num_replicas', 1)
    assert self.num_replicas >= 1, 'num replicas must be >= 1'
    # Table rows reserved for each carry; their sum is the per-feature vocab.
    self.lengths = list(map(self.max_length, self.carries))
    self.vocab_size = int(sum(self.lengths))
    self.multiplier = params.get_or_default('multiplier', 1.0)
    self.intra_ary_pooling = params.get_or_default('intra_ary_pooling', 'sum')
    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
    logging.info(
        '{} carries: {}, lengths: {}, vocab_size: {}, intra_ary: {}, replicas: {}, multiplier: {}'
        .format(self.name, ','.join(map(str, self.carries)),
                ','.join(map(str, self.lengths)), self.vocab_size,
                self.intra_ary_pooling, self.num_replicas, self.multiplier))

  @staticmethod
  def max_length(carry):
    """Return (floor(log_carry(4294967295)) + 1) * carry.

    4294967295 is 2**32 - 1, so the first factor is the number of base-`carry`
    digits of the largest unsigned 32-bit value.
    """
    bits = math.log(4294967295, carry)
    return (math.floor(bits) + 1) * carry

  def build(self, input_shape):
    """Create the shared embedding table for all features and carries."""
    assert isinstance(input_shape,
                      tf.TensorShape), 'NaryDisEmbedding only takes 1 input'
    self.num_features = int(input_shape[-1])
    logging.info('%s has %d input features', self.name, self.num_features)
    vocab_size = self.num_features * self.vocab_size
    # All replicas share one table; each row holds num_replicas embeddings.
    emb_dim = self.emb_dim * self.num_replicas
    num_ps = get_ps_num_from_tf_config()
    partitioner = None
    if num_ps > 0:
      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
    self.embedding_table = self.add_weight(
        'embed_table', shape=[vocab_size, emb_dim], partitioner=partitioner)
    super(NaryDisEmbedding, self).build(input_shape)

  @staticmethod
  def _segment_ids(splits):
    """Expand segment lengths `splits` into one segment id per element."""
    total_length = tf.size(splits)
    if tf.__version__ >= '2.0':
      return tf.repeat(tf.range(total_length), repeats=splits)
    return repeat(tf.range(total_length), repeats=splits)

  def _merge_carries(self, embedding):
    """Rearrange a carry-major [C*B*N, D] tensor into [B*N, C*D]."""
    num_carries = len(self.carries)
    emb_dim = self.emb_dim
    embedding = tf.reshape(embedding,
                           [num_carries, -1, emb_dim])  # [C, B*N, D]
    embedding = tf.transpose(embedding, perm=[1, 0, 2])  # [B*N, C, D]
    return tf.reshape(embedding, [-1, num_carries * emb_dim])  # [B*N, C*D]

  def _split_per_feature(self, embedding):
    """Turn a [B*N, C*D] tensor into a list of N tensors shaped [B, C*D].

    The flattening reshapes in `call` treat the rows as ordered by example
    first and feature second, so the per-feature tensors are obtained by
    restoring the feature axis and unstacking it. Splitting axis 0 into N
    chunks would return blocks of consecutive rows that mix features.
    """
    width = len(self.carries) * self.emb_dim
    per_feature = tf.reshape(embedding,
                             [-1, self.num_features, width])  # [B, N, C*D]
    return tf.unstack(per_feature, axis=1)  # [B, C*D] * N

  def call(self, inputs, **kwargs):
    """Embed the inputs with every configured carry.

    Returns:
      With a single replica: the [B, N*C*D] output, optionally paired with
      the per-feature list (`output_tensor_list`) or the [B, N, C*D] tensor
      (`output_3d_tensor`). With several replicas: a list holding one
      [B, N*C*D] output per replica, followed by one extra entry per replica
      when either optional output is enabled.

    Raises:
      ValueError: if the input is not rank 2 or the pooling method is
        unsupported.
    """
    if inputs.shape.ndims != 2:
      raise ValueError('inputs of NaryDisEmbedding must have 2 dimensions.')
    if self.multiplier != 1.0:
      inputs *= self.multiplier
    inputs = tf.to_int32(inputs)
    offset, emb_indices, emb_splits = 0, [], []
    with ops.device('/CPU:0'):
      for carry, length in zip(self.carries, self.lengths):
        values, splits = self.nary_carry(inputs, carry=carry, offset=offset)
        # The next carry starts after the rows reserved for this one.
        offset += length
        emb_indices.append(values)
        emb_splits.append(splits)
    indices = tf.concat(emb_indices, axis=0)
    splits = tf.concat(emb_splits, axis=0)
    embedding = tf.nn.embedding_lookup(self.embedding_table, indices)

    # Pool the looked-up rows over the segments described by `splits`.
    if self.intra_ary_pooling == 'sum':
      embedding = tf.math.segment_sum(embedding, self._segment_ids(splits))
    elif self.intra_ary_pooling == 'mean':
      embedding = tf.math.segment_mean(embedding, self._segment_ids(splits))
    else:
      raise ValueError('Unsupported intra ary pooling method %s' %
                       self.intra_ary_pooling)
    # B: batch size
    # N: num features
    # C: num carries
    # D: embedding dimension
    # R: num replicas
    # shape of embedding: [B*N*C, R*D]
    N = self.num_features
    C = len(self.carries)
    D = self.emb_dim
    if self.num_replicas == 1:
      embedding = self._merge_carries(embedding)  # [B*N, C*D]
      output = tf.reshape(embedding, [-1, N * C * D])  # [B, N*C*D]
      if self.output_tensor_list:
        return output, self._split_per_feature(embedding)  # [B, C*D] * N
      if self.output_3d_tensor:
        embedding = tf.reshape(embedding, [-1, N, C * D])  # [B, N, C*D]
        return output, embedding
      return output

    # self.num_replicas > 1:
    replicas = tf.split(embedding, self.num_replicas, axis=1)
    outputs = []
    outputs2 = []
    for replica in replicas:
      # shape of replica: [B*N*C, D]
      embedding = self._merge_carries(replica)  # [B*N, C*D]
      output = tf.reshape(embedding, [-1, N * C * D])  # [B, N*C*D]
      outputs.append(output)
      if self.output_tensor_list:
        outputs2.append(self._split_per_feature(embedding))  # [B, C*D] * N
      elif self.output_3d_tensor:
        embedding = tf.reshape(embedding, [-1, N, C * D])  # [B, N, C*D]
        outputs2.append(embedding)
    return outputs + outputs2
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Convenience blocks for building models."""
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.layers.keras.activation import activation_layer
|
|
9
|
+
from easy_rec.python.utils.tf_utils import add_elements_to_collection
|
|
10
|
+
|
|
11
|
+
# On TensorFlow 2.x, rebind `tf` to the v1 compatibility namespace so every
# `tf.*` name used below resolves through `tf.compat.v1`.
# NOTE(review): this is a string comparison of version numbers - confirm it
# behaves as intended for every TensorFlow release this package supports.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GateNN(tf.keras.layers.Layer):
  """Small gating network that outputs multiplicative weights in (0, 2).

  The stack is: Dense(hidden_dim) -> optional BatchNormalization ->
  activation -> optional Dropout -> Dense(output_dim, sigmoid) -> multiply
  by 2.
  """

  def __init__(self,
               params,
               output_units=None,
               name='gate_nn',
               reuse=None,
               **kwargs):
    """Assemble the gate's sub-layers.

    Args:
      params: config object. `output_dim` is read when `output_units` is
        None; `hidden_dim` (defaults to the output size), `initializer`
        ('he_uniform'), `use_bn` (False), `activation` ('relu') and
        `dropout_rate` (0.0) are optional.
      output_units: size of the gate output; overrides `params.output_dim`.
      name: layer name.
      reuse: accepted for interface compatibility; not used here.
      **kwargs: forwarded to the Keras base layer.

    Raises:
      ValueError: if `dropout_rate` is 1.0 or larger.
    """
    super(GateNN, self).__init__(name=name, **kwargs)
    if output_units is None:
      gate_dim = params.output_dim
    else:
      gate_dim = output_units
    hidden_dim = params.get_or_default('hidden_dim', gate_dim)
    kernel_init = params.get_or_default('initializer', 'he_uniform')
    with_bn = params.get_or_default('use_bn', False)
    act_name = params.get_or_default('activation', 'relu')
    drop_prob = params.get_or_default('dropout_rate', 0.0)
    # The dense layers carry no bias when batch normalization is enabled.
    with_bias = not with_bn

    self._sub_layers = []
    self._sub_layers.append(
        tf.keras.layers.Dense(
            units=hidden_dim,
            use_bias=with_bias,
            kernel_initializer=kernel_init))

    if with_bn:
      self._sub_layers.append(
          tf.keras.layers.BatchNormalization(trainable=True))

    self._sub_layers.append(activation_layer(act_name))

    if drop_prob >= 1.0:
      raise ValueError('invalid dropout_ratio: %.3f' % drop_prob)
    if drop_prob > 0.0:
      self._sub_layers.append(tf.keras.layers.Dropout(drop_prob))

    self._sub_layers.append(
        tf.keras.layers.Dense(
            units=gate_dim,
            activation='sigmoid',
            use_bias=with_bias,
            kernel_initializer=kernel_init,
            name='weight'))
    # Stretch the sigmoid output from (0, 1) to (0, 2).
    self._sub_layers.append(lambda x: x * 2)

  def call(self, x, training=None, **kwargs):
    """Run the input through every sub-layer in order and return the gate."""
    training_aware = ('Dropout', 'BatchNormalization', 'Dice')
    for layer in self._sub_layers:
      kind = layer.__class__.__name__
      if kind not in training_aware:
        x = layer(x)
        continue
      x = layer(x, training=training)
      # Layers other than Dropout keep statistics whose update ops have to
      # be registered while training.
      if kind != 'Dropout' and training:
        add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
    return x
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PPNet(tf.keras.layers.Layer):
  """PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information.

  An MLP whose activations are rescaled by `GateNN` outputs computed from a
  separate gate input.

  Config (`params`):
    mlp: required MLP settings - `hidden_units` (required), `use_bn`,
      `use_final_bn`, `use_bias`, `use_final_bias`, `dropout_ratio`,
      `activation`, `initializer`, `final_activation`,
      `use_bn_after_activation` and `l2_regularizer`.
    gate_params: settings handed to every `GateNN`.
    full_gate_input: when True (default) the gate input is the concatenation
      of the gradient-stopped MLP input and the gate input.
    mode: 'lazy' (default) adds a single gate after the last layer; any other
      value adds `gate_0` on the input and a gate after every layer except
      the last one.
  """

  def __init__(self, params, name='ppnet', reuse=None, **kwargs):
    """Build the interleaved stack of MLP layers and gates.

    Args:
      params: config object, see the class docstring.
      name: layer name.
      reuse: stored on the instance; not used by this class itself.
      **kwargs: extra args passed to the Keras Layer base class.
    """
    super(PPNet, self).__init__(name=name, **kwargs)
    params.check_required('mlp')
    self.full_gate_input = params.get_or_default('full_gate_input', True)
    mode = params.get_or_default('mode', 'lazy')
    gate_params = params.gate_params
    params = params.mlp
    params.check_required('hidden_units')
    use_bn = params.get_or_default('use_bn', True)
    use_final_bn = params.get_or_default('use_final_bn', True)
    use_bias = params.get_or_default('use_bias', False)
    use_final_bias = params.get_or_default('use_final_bias', False)
    dropout_rate = list(params.get_or_default('dropout_ratio', []))
    activation = params.get_or_default('activation', 'relu')
    initializer = params.get_or_default('initializer', 'he_uniform')
    final_activation = params.get_or_default('final_activation', None)
    use_bn_after_act = params.get_or_default('use_bn_after_activation', False)
    units = list(params.hidden_units)
    logging.info(
        'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
        ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r',
        name, units, dropout_rate, activation, use_bn, use_final_bn,
        final_activation, use_bias, initializer, use_bn_after_act)
    assert len(units) > 0, 'MLP(%s) takes at least one hidden units' % name
    self.reuse = reuse

    num_dropout = len(dropout_rate)
    self._sub_layers = []

    if mode != 'lazy':
      # gate_0 takes its output size from gate_params.output_dim.
      # NOTE(review): presumably that must match the MLP input width so the
      # element-wise product in call() is valid - confirm against the config.
      self._sub_layers.append(GateNN(gate_params, None, 'gate_0'))
      for i, num_units in enumerate(units[:-1]):
        layer_name = 'layer_%d' % i
        # Layers without a configured dropout ratio get no dropout.
        drop_rate = dropout_rate[i] if i < num_dropout else 0.0
        self.add_rich_layer(num_units, use_bn, drop_rate, activation,
                            initializer, use_bias, use_bn_after_act,
                            layer_name, params.l2_regularizer)
        self._sub_layers.append(
            GateNN(gate_params, num_units, 'gate_%d' % (i + 1)))

    # The last layer has its own batch-norm, bias and activation settings.
    n = len(units) - 1
    drop_rate = dropout_rate[n] if num_dropout > n else 0.0
    layer_name = 'layer_%d' % n
    self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation,
                        initializer, use_final_bias, use_bn_after_act,
                        layer_name, params.l2_regularizer)
    if mode == 'lazy':
      self._sub_layers.append(
          GateNN(gate_params, units[-1], 'gate_%d' % (n + 1)))

  def add_rich_layer(self,
                     num_units,
                     use_bn,
                     dropout_rate,
                     activation,
                     initializer,
                     use_bias,
                     use_bn_after_activation,
                     name,
                     l2_reg=None):
    """Append one dense layer with optional batch norm and dropout.

    The order is dense -> bn -> activation by default, or
    dense -> activation -> bn when `use_bn_after_activation` is set. Dropout,
    if any, comes last.

    Raises:
      ValueError: if `dropout_rate` is 1.0 or larger.
    """
    act_layer = activation_layer(activation, name='%s/act' % name)
    if use_bn and not use_bn_after_activation:
      dense = tf.keras.layers.Dense(
          units=num_units,
          use_bias=use_bias,
          kernel_initializer=initializer,
          kernel_regularizer=l2_reg,
          name='%s/dense' % name)
      self._sub_layers.append(dense)
      bn = tf.keras.layers.BatchNormalization(
          name='%s/bn' % name, trainable=True)
      self._sub_layers.append(bn)
      self._sub_layers.append(act_layer)
    else:
      dense = tf.keras.layers.Dense(
          num_units,
          use_bias=use_bias,
          kernel_initializer=initializer,
          kernel_regularizer=l2_reg,
          name='%s/dense' % name)
      self._sub_layers.append(dense)
      self._sub_layers.append(act_layer)
      if use_bn and use_bn_after_activation:
        bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
        self._sub_layers.append(bn)

    if 0.0 < dropout_rate < 1.0:
      dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name)
      self._sub_layers.append(dropout)
    elif dropout_rate >= 1.0:
      raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)

  def call(self, inputs, training=None, **kwargs):
    """Performs the forward computation of the block.

    Args:
      inputs: a pair `(x, gate_input)`: the MLP input and the tensor that
        drives the gates.
      training: training flag passed to dropout, batch-norm, Dice and gate
        sub-layers.
      **kwargs: accepted for Keras compatibility, not used.

    Returns:
      The output of the last sub-layer.
    """
    x, gate_input = inputs
    if self.full_gate_input:
      with tf.name_scope(self.name):
        # The gates see the MLP input too, but gradients must not flow back
        # into it through the gate branch.
        gate_input = tf.concat([tf.stop_gradient(x), gate_input], axis=-1)

    for layer in self._sub_layers:
      cls = layer.__class__.__name__
      if cls == 'GateNN':
        # Pass the training flag on explicitly: GateNN may contain Dropout
        # and BatchNormalization, and it only registers its batch-norm update
        # ops when it is told that training is on.
        gate = layer(gate_input, training=training)
        x *= gate
      elif cls in ('Dropout', 'BatchNormalization', 'Dice'):
        x = layer(x, training=training)
        if cls in ('BatchNormalization', 'Dice') and training:
          add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
      else:
        x = layer(x)
    return x
|