easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import itertools
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from tensorflow.python.keras.layers import Dense
|
|
8
|
+
from tensorflow.python.keras.layers import Layer
|
|
9
|
+
|
|
10
|
+
from easy_rec.python.layers.keras.blocks import MLP
|
|
11
|
+
from easy_rec.python.layers.keras.layer_norm import LayerNormalization
|
|
12
|
+
from easy_rec.python.layers.utils import Parameter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SENet(Layer):
  """SENET (squeeze-and-excitation) layer used in FiBiNET.

  Each field embedding is split into ``num_squeeze_group`` equal groups and
  summarised by the max and the mean of every group (squeeze). A two-layer
  dense bottleneck turns those statistics into one weight per embedding
  element (excitation), and the concatenated embeddings are re-weighted by
  them. Optional skip connection and output layer normalization follow.

  Input shape
    - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
      The ``embedding_size`` of each field can have different value, but each
      must be divisible by ``num_squeeze_group``.

  Output shape
    - A 2D tensor with shape: ``(batch_size,sum_of_embedding_size)``.

  References:
    1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
      Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='SENet', reuse=None, **kwargs):
    super(SENet, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.reuse = reuse
    # Use the built-in Keras layer norm on TF2 and the project's own
    # implementation otherwise.
    norm_cls = (
        tf.keras.layers.LayerNormalization
        if tf.__version__ >= '2.0' else LayerNormalization)
    self.layer_norm = norm_cls(name='output_ln')

  def build(self, input_shape):
    num_groups = self.config.num_squeeze_group
    total_dim = 0
    for field_shape in input_shape:
      assert field_shape.ndims == 2, 'field embeddings must be rank 2 tensors'
      field_dim = int(field_shape[-1])
      assert field_dim >= num_groups and field_dim % num_groups == 0, (
          'field embedding dimension %d must be divisible by %d' %
          (field_dim, num_groups))
      total_dim += field_dim

    # The squeeze step emits two statistics (max and mean) per group per field.
    squeezed_dim = len(input_shape) * num_groups * 2
    self.reduce_layer = Dense(
        units=max(1, squeezed_dim // self.config.reduction_ratio),
        activation='relu',
        kernel_initializer='he_normal',
        name='W1')
    self.excite_layer = Dense(
        units=total_dim, kernel_initializer='glorot_normal', name='W2')
    super(SENet, self).build(input_shape)  # Be sure to call this somewhere!

  def call(self, inputs, **kwargs):
    num_groups = self.config.num_squeeze_group

    # Squeeze: view each field as [B, num_groups, dim // num_groups]
    # (the embedding dimension must be divisible by num_groups), then pool
    # every group with max and mean -> two [B, num_groups] tensors per field.
    grouped = [
        tf.reshape(field_emb,
                   [-1, num_groups, int(field_emb.shape[-1]) // num_groups])
        for field_emb in inputs
    ]
    pooled = [
        pool_fn(group, axis=-1)
        for group in grouped
        for pool_fn in (tf.reduce_max, tf.reduce_mean)
    ]
    squeezed = tf.concat(pooled, axis=1)  # [bs, field_size * num_groups * 2]

    # Excitation: bottleneck then expand back to one weight per element.
    element_weights = self.excite_layer(self.reduce_layer(squeezed))

    # Re-weight the concatenated field embeddings.
    flat_inputs = tf.concat(inputs, axis=-1)
    result = flat_inputs * element_weights

    # Fuse: optional skip connection.
    if self.config.use_skip_connection:
      result += flat_inputs

    if self.config.use_output_layer_norm:
      result = self.layer_norm(result)
    return result
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _full_interaction(v_i, v_j):
  """Row-wise inner product of two batched vectors (the bi-linear+ operator).

  Args:
    v_i: left operand; presumably [batch_size, dim] (BiLinear asserts rank 2).
    v_j: right operand with the same shape as ``v_i``.

  Returns:
    A [batch_size, 1] tensor holding <v_i[b], v_j[b]> for every row b.
  """
  row_vec = tf.expand_dims(v_i, axis=1)  # [bs, 1, dim]
  col_vec = tf.expand_dims(v_j, axis=-1)  # [bs, dim, 1]
  product = tf.matmul(row_vec, col_vec)  # [bs, 1, 1]
  return tf.squeeze(product, axis=1)  # [bs, 1]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class BiLinear(Layer):
  """BilinearInteraction Layer used in FiBiNET.

  Input shape
    - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
      Its length is ``field_size`` and must be at least 2.
      The ``embedding_size`` of each field can have different value only when
      ``type`` is ``'interaction'``; the other types require equal sizes.

  Output shape
    - 2D tensor with shape: ``(batch_size,output_size)``.

  Attributes:
    num_output_units: the number of output units
    type: ['all', 'each', 'interaction'], types of bilinear functions used in this layer
    use_plus: whether to use bi-linear+

  References:
    1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
      Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='bilinear', reuse=None, **kwargs):
    super(BiLinear, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    params.check_required(['num_output_units'])
    bilinear_plus = params.get_or_default('use_plus', True)
    self.output_size = params.num_output_units
    self.bilinear_type = params.get_or_default('type', 'interaction').lower()
    if self.bilinear_type not in ['all', 'each', 'interaction']:
      raise NotImplementedError(
          "bilinear_type only support: ['all', 'each', 'interaction']")
    # bi-linear+ reduces each pair to a scalar inner product;
    # plain bi-linear keeps the element-wise product.
    if bilinear_plus:
      self.func = _full_interaction
    else:
      self.func = tf.multiply
    self.output_layer = Dense(self.output_size, name='output')

  def build(self, input_shape):
    """Create the per-type projection layers.

    Raises:
      TypeError: if the layer input is not a list/tuple of tensors.
      ValueError: if fewer than 2 fields are given, or field dimensions differ
        while ``type`` is not ``'interaction'``.
    """
    # isinstance (not an exact type check) so list subclasses such as
    # Keras' tracking wrappers are accepted too.
    if not isinstance(input_shape, (tuple, list)):
      raise TypeError('input of BiLinear layer must be a list')
    field_num = len(input_shape)
    logging.info('Bilinear Layer with %d inputs', field_num)
    if field_num < 2:
      # With fewer than two fields there is no pair to interact.
      raise ValueError(
          'BiLinear layer needs at least 2 inputs, got %d' % field_num)
    if field_num > 200:
      logging.warning('Too many inputs for bilinear layer: %d', field_num)
    equal_dim = True
    _dim = input_shape[0][-1]
    for shape in input_shape:
      assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
      if shape[-1] != _dim:
        equal_dim = False
    if not equal_dim and self.bilinear_type != 'interaction':
      raise ValueError(
          'all embedding dimensions must be same when not use bilinear type: interaction'
      )
    dim = int(_dim)

    if self.bilinear_type == 'all':
      # A single projection shared by every left-hand field.
      self.dot_layer = Dense(dim, name='all')
    elif self.bilinear_type == 'each':
      # One projection per left-hand field; the last field is never on the left.
      self.dot_layers = [
          Dense(dim, name='each_%d' % i) for i in range(field_num - 1)
      ]
    else:  # interaction
      # One projection per (i, j) pair, stored in itertools.combinations
      # order. Each one maps field i to the size of field j. `call` must walk
      # the pairs in this same order.
      self.dot_layers = [
          Dense(
              units=int(input_shape[j][-1]), name='interaction_%d_%d' % (i, j))
          for i, j in itertools.combinations(range(field_num), 2)
      ]
    super(BiLinear, self).build(input_shape)  # Be sure to call this somewhere!

  def call(self, inputs, **kwargs):
    embeddings = inputs
    field_num = len(embeddings)
    # Same pair order as the one used to build the projection layers.
    pairs = list(itertools.combinations(range(field_num), 2))

    # bi-linear+: dimension of `p` is [bs, f*(f-1)/2]
    # bi-linear:
    # - when equal_dim=True, dimension of `p` is [bs, f*(f-1)/2*k], k is embedding size
    # - when equal_dim=False, dimension of `p` is [bs, (k_2+k_3+...+k_f)+...+(k_i+k_{i+1}+...+k_f)+...+k_f],
    # - where k_i is the embedding size of the ith field
    if self.bilinear_type == 'all':
      v_dot = [self.dot_layer(v_i) for v_i in embeddings[:-1]]
      p = [self.func(v_dot[i], embeddings[j]) for i, j in pairs]
    elif self.bilinear_type == 'each':
      v_dot = [self.dot_layers[i](v_i) for i, v_i in enumerate(embeddings[:-1])]
      p = [self.func(v_dot[i], embeddings[j]) for i, j in pairs]
    else:  # interaction
      # dot_layers[k] belongs to pairs[k]. The list holds only
      # C(field_num, 2) layers, so indexing it by `i * field_num + j` would
      # pick the wrong layer or run past the end.
      p = [
          self.func(layer(embeddings[i]), embeddings[j])
          for layer, (i, j) in zip(self.dot_layers, pairs)
      ]

    return self.output_layer(tf.concat(p, axis=-1))
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class FiBiNet(Layer):
  """FiBiNet++:Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction.

  Always applies an SENet block to the field embeddings. When configured, it
  concatenates a BiLinear interaction block's output with the SENet output,
  and feeds the result through a final MLP.

  References:
    - [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='fibinet', reuse=None, **kwargs):
    super(FiBiNet, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self._config = params.get_pb_config()
    cfg = self._config

    self.senet_layer = SENet(
        Parameter.make_from_pb(cfg.senet),
        name=self.name + '/senet',
        reuse=self.reuse)

    if cfg.HasField('bilinear'):
      self.bilinear_layer = BiLinear(
          Parameter.make_from_pb(cfg.bilinear),
          name=self.name + '/bilinear',
          reuse=self.reuse)

    if not cfg.HasField('mlp'):
      self.final_mlp = None
    else:
      mlp_params = Parameter.make_from_pb(cfg.mlp)
      # The MLP inherits the regularizer configured on this layer's params.
      mlp_params.l2_regularizer = params.l2_regularizer
      self.final_mlp = MLP(mlp_params, name=self.name + '/mlp', reuse=reuse)

  def call(self, inputs, training=None, **kwargs):
    branches = [self.senet_layer(inputs)]
    if self._config.HasField('bilinear'):
      branches.append(self.bilinear_layer(inputs))

    if len(branches) == 1:
      merged = branches[0]
    else:
      merged = tf.concat(branches, axis=-1)

    if self.final_mlp is None:
      return merged
    return self.final_mlp(merged, training=training)
|
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
from easy_rec.python.utils.activation import get_activation
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FM(tf.keras.layers.Layer):
  """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias.

  References
    - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
  Input shape.
    - List (or tuple) of 2D tensor with shape: ``(batch_size,embedding_size)``;
      every ``embedding_size`` must be equal.
    - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)``
  Output shape
    - 2D tensor with shape: ``(batch_size, 1)``.
    - When ``use_variant`` is True the per-dimension cross terms are not
      summed, giving ``(batch_size, embedding_size)``.
  """

  def __init__(self, params, name='fm', reuse=None, **kwargs):
    super(FM, self).__init__(name=name, **kwargs)
    self.use_variant = params.get_or_default('use_variant', False)

  def call(self, inputs, **kwargs):
    """Compute the order-2 FM cross term.

    Raises:
      ValueError: if list inputs have differing embedding dimensions.
    """
    # Accept lists, tuples and list subclasses, matching DotInteraction.
    # An exact `type(...) == list` check sent tuples into the 3D-tensor
    # branch, where they failed on `.shape`.
    if isinstance(inputs, (list, tuple)):
      emb_dims = {int(x.shape[-1]) for x in inputs}
      if len(emb_dims) != 1:
        # Sorted so the error message is deterministic.
        dims = ','.join(str(d) for d in sorted(emb_dims))
        raise ValueError('all embedding dim must be equal in FM layer:' + dims)
      with tf.name_scope(self.name):
        fea = tf.stack(inputs, axis=1)  # [B, field_size, embedding_size]
    else:
      assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors'
      fea = inputs

    with tf.name_scope(self.name):
      # Per embedding dimension:
      # 0.5 * ((sum_i v_i)^2 - sum_i v_i^2) == sum_{i<j} v_i * v_j
      square_of_sum = tf.square(tf.reduce_sum(fea, axis=1))
      sum_of_square = tf.reduce_sum(tf.square(fea), axis=1)
      cross_term = tf.subtract(square_of_sum, sum_of_square)
      if self.use_variant:
        cross_term = 0.5 * cross_term
      else:
        cross_term = 0.5 * tf.reduce_sum(cross_term, axis=-1, keepdims=True)
    return cross_term
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DotInteraction(tf.keras.layers.Layer):
  """Dot interaction layer of the DLRM model.

  See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf,
  section 2.1.3. Sparse activations and dense activations are combined.
  Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the
  same dimension and the output is a batch of Tensors with all distinct pairwise
  dot products of the form dot(e_i, e_j) for i <= j if self_interaction is
  True, otherwise dot(e_i, e_j) for i < j.

  Attributes (read from `params` via `get_or_default`):
    self_interaction: Boolean indicating if features should self-interact
      (default False). If it is True, then the diagonal entries of the
      interaction matrix are also taken.
    skip_gather: An optimization flag (default False). If it's set then the
      upper triangle part of the dot interaction matrix dot(e_i, e_j) is set
      to 0. The resulting activations will be of dimension
      [num_features * num_features] from which half will be zeros. Otherwise
      activations will be only the lower triangle part of the interaction
      matrix. The latter saves space but is much slower.
    name: String name of the layer.
  """

  def __init__(self, params, name=None, reuse=None, **kwargs):
    super(DotInteraction, self).__init__(name=name, **kwargs)
    self._self_interaction = params.get_or_default('self_interaction', False)
    self._skip_gather = params.get_or_default('skip_gather', False)

  def call(self, inputs, **kwargs):
    """Performs the interaction operation on the tensors in the list.

    The tensors represent transformed dense features and embedded categorical
    features.
    Pre-condition: The tensors should all have the same shape.

    Args:
      inputs: Either a list/tuple of features with shapes
        [batch_size, feature_dim], or a single 3D tensor of shape
        [batch_size, num_features, feature_dim].

    Returns:
      activations: Tensor representing interacted features. It has a dimension
        `num_features * num_features` if skip_gather is True, otherwise
        `num_features * (num_features + 1) / 2` if self_interaction is True and
        `num_features * (num_features - 1) / 2` if self_interaction is False.

    Raises:
      ValueError: if the list inputs cannot be stacked (mismatched shapes) or
        a non-list input is not a 3D tensor.
    """
    if isinstance(inputs, (list, tuple)):
      # concat_features shape: batch_size, num_features, feature_dim
      try:
        concat_features = tf.stack(inputs, axis=1)
      except (ValueError, tf.errors.InvalidArgumentError) as e:
        raise ValueError("Input tensors' dimensions must be equal, original "
                         'error message: {}'.format(e))
    else:
      # Explicit check instead of `assert`, which is stripped under `python -O`.
      if inputs.shape.ndims != 3:
        raise ValueError(
            'input of dot func must be a 3D tensor or a list of 2D tensors')
      concat_features = inputs

    batch_size = tf.shape(concat_features)[0]

    # Interact features, select lower-triangular portion, and re-shape.
    # xactions: (batch_size, num_features, num_features) pairwise dot products.
    xactions = tf.matmul(concat_features, concat_features, transpose_b=True)
    num_features = xactions.shape[-1]
    # TF1 yields a Dimension (with `.value`), TF2 an int or None. When the
    # feature count is not known statically, fall back to the dynamic size so
    # the `out_dim` arithmetic below does not fail on None.
    if getattr(num_features, 'value', num_features) is None:
      num_features = tf.shape(xactions)[-1]
    ones = tf.ones_like(xactions)
    if self._self_interaction:
      # Selecting lower-triangular portion including the diagonal.
      lower_tri_mask = tf.linalg.band_part(ones, -1, 0)
      upper_tri_mask = ones - lower_tri_mask
      out_dim = num_features * (num_features + 1) // 2
    else:
      # Selecting lower-triangular portion not including the diagonal.
      upper_tri_mask = tf.linalg.band_part(ones, 0, -1)
      lower_tri_mask = ones - upper_tri_mask
      out_dim = num_features * (num_features - 1) // 2

    if self._skip_gather:
      # Setting upper triangle part of the interaction matrix to zeros.
      activations = tf.where(
          condition=tf.cast(upper_tri_mask, tf.bool),
          x=tf.zeros_like(xactions),
          y=xactions)
      out_dim = num_features * num_features
    else:
      # tf.boolean_mask documents a boolean mask; the masks above are 0/1
      # floats, so cast explicitly.
      activations = tf.boolean_mask(xactions, tf.cast(lower_tri_mask, tf.bool))
    activations = tf.reshape(activations, (batch_size, out_dim))
    return activations
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Cross(tf.keras.layers.Layer):
  """Cross Layer in Deep & Cross Network to learn explicit feature interactions.

  A layer that creates explicit and bounded-degree feature interactions
  efficiently. The `call` method accepts `inputs` either as a tuple/list of
  two tensors `(x0, xi)` or as a single tensor. The first input `x0` is the
  base layer that contains the original features (usually the embedding
  layer); the second input `xi` is the output of the previous `Cross` layer in
  the stack, i.e., the i-th `Cross` layer. For the first `Cross` layer in the
  stack, x0 = xi; passing a single tensor is equivalent to passing it twice.

  The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi,
  where .* designates elementwise multiplication, W could be a full-rank
  matrix, or a low-rank matrix U*V to reduce the computational cost, and
  diag_scale increases the diagonal of W to improve training stability
  (especially for the low-rank case). If `preactivation` is set it is applied
  to (W * xi + bias) before the diag_scale term is added.

  References:
    1. [R. Wang et al.](https://arxiv.org/pdf/2008.13535.pdf)
      See Eq. (1) for full-rank and Eq. (2) for low-rank version.
    2. [R. Wang et al.](https://arxiv.org/pdf/1708.05123.pdf)

  Example:

    ```python
    # after embedding layer in a functional model; `params` is a config
    # object exposing `get_or_default` (see `__init__`):
    input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
    x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)(input)
    x1 = Cross(params, name='cross_1')([x0, x0])
    x2 = Cross(params, name='cross_2')([x0, x1])
    logits = tf.keras.layers.Dense(units=10)(x2)
    model = tf.keras.Model(input, logits)
    ```

  Args (read from `params` via `get_or_default`):
    projection_dim: project dimension to reduce the computational cost.
      Default is `None` such that a full (`input_dim` by `input_dim`) matrix
      W is used. If enabled, a low-rank matrix W = U*V will be used, where U
      is of size `input_dim` by `projection_dim` and V is of size
      `projection_dim` by `input_dim`. `projection_dim` needs to be smaller
      than `input_dim`/2 to improve the model efficiency. In practice, we've
      observed that `projection_dim` = d/4 consistently preserved the
      accuracy of a full-rank version.
    diag_scale: a non-negative float (default 0.0) used to increase the
      diagonal of the kernel W by `diag_scale`, that is, W + diag_scale * I,
      where I is an identity matrix.
    use_bias: whether to add a bias term for this layer (default True). If set
      to False, no bias term will be used.
    preactivation: Activation applied to output matrix of the layer, before
      multiplication with the input. Can be used to control the scale of the
      layer's outputs and improve stability.
    kernel_initializer: Initializer to use on the kernel matrix
      (default 'truncated_normal').
    bias_initializer: Initializer to use on the bias vector (default 'zeros').
    kernel_regularizer: Regularizer to use on the kernel matrix.
    bias_regularizer: Regularizer to use on bias vector.

  Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs, or a
    single (batch_size, `input_dim`) input.
  Output shape: A single (batch_size, `input_dim`) dimensional output.

  Raises:
    ValueError: if `diag_scale` is negative.
  """

  def __init__(self, params, name='cross', reuse=None, **kwargs):
    super(Cross, self).__init__(name=name, **kwargs)
    self._projection_dim = params.get_or_default('projection_dim', None)
    self._diag_scale = params.get_or_default('diag_scale', 0.0)
    self._use_bias = params.get_or_default('use_bias', True)
    preactivation = params.get_or_default('preactivation', None)
    preact = get_activation(preactivation)
    self._preactivation = tf.keras.activations.get(preact)
    kernel_initializer = params.get_or_default('kernel_initializer',
                                               'truncated_normal')
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    bias_initializer = params.get_or_default('bias_initializer', 'zeros')
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    kernel_regularizer = params.get_or_default('kernel_regularizer', None)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    bias_regularizer = params.get_or_default('bias_regularizer', None)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._input_dim = None
    self._supports_masking = True

    if self._diag_scale < 0:  # pytype: disable=unsupported-operands
      raise ValueError(
          '`diag_scale` should be non-negative. Got `diag_scale` = {}'.format(
              self._diag_scale))

  def build(self, input_shape):
    """Creates the projection sub-layers.

    Args:
      input_shape: either a single shape (when the layer is called with one
        tensor, or from `call` via `self.build(x0.shape)`) or a list/tuple of
        shapes for `(x0, x)`. Only the last dimension of x0 is used.
    """
    shape = input_shape
    # A list/tuple whose first element is itself a shape means "several
    # inputs"; take x0's shape. A single shape (TensorShape, or a tuple of
    # ints/None) is used as-is. Indexing input_shape[0][-1] unconditionally
    # would index into the batch dimension of a single shape and crash.
    if isinstance(shape, (list, tuple)) and shape and isinstance(
        shape[0], (list, tuple, tf.TensorShape)):
      shape = shape[0]
    last_dim = shape[-1]

    if self._projection_dim is None:
      self._dense = tf.keras.layers.Dense(
          last_dim,
          kernel_initializer=_clone_initializer(self._kernel_initializer),
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          use_bias=self._use_bias,
          dtype=self.dtype,
          activation=self._preactivation,
      )
    else:
      # Low-rank factorisation W = U * V; U has no bias, V carries the bias
      # and the preactivation.
      self._dense_u = tf.keras.layers.Dense(
          self._projection_dim,
          kernel_initializer=_clone_initializer(self._kernel_initializer),
          kernel_regularizer=self._kernel_regularizer,
          use_bias=False,
          dtype=self.dtype,
      )
      self._dense_v = tf.keras.layers.Dense(
          last_dim,
          kernel_initializer=_clone_initializer(self._kernel_initializer),
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          use_bias=self._use_bias,
          dtype=self.dtype,
          activation=self._preactivation,
      )
    super(Cross, self).build(input_shape)  # Be sure to call this somewhere!

  def call(self, inputs, **kwargs):
    """Computes the feature cross.

    Args:
      inputs: The input tensor(x0, x)
        - x0: The input tensor
        - x: Optional second input tensor. If provided, the layer will compute
          crosses between x0 and x; if not provided, the layer will compute
          crosses between x0 and itself.

    Returns:
      Tensor of crosses.

    Raises:
      ValueError: if the last dimensions of x0 and x differ.
    """
    if isinstance(inputs, (list, tuple)):
      x0, x = inputs
    else:
      x0, x = inputs, inputs

    if not self.built:
      self.build(x0.shape)

    if x0.shape[-1] != x.shape[-1]:
      raise ValueError(
          '`x0` and `x` dimension mismatch! Got `x0` dimension {}, and x '
          'dimension {}. This case is not supported yet.'.format(
              x0.shape[-1], x.shape[-1]))

    if self._projection_dim is None:
      prod_output = self._dense(x)
    else:
      prod_output = self._dense_v(self._dense_u(x))

    # prod_output = tf.cast(prod_output, self.compute_dtype)

    if self._diag_scale:
      prod_output = prod_output + self._diag_scale * x

    return x0 * prod_output + x

  def get_config(self):
    """Returns the layer configuration merged with the base layer config."""
    config = {
        'projection_dim':
            self._projection_dim,
        'diag_scale':
            self._diag_scale,
        'use_bias':
            self._use_bias,
        'preactivation':
            tf.keras.activations.serialize(self._preactivation),
        'kernel_initializer':
            tf.keras.initializers.serialize(self._kernel_initializer),
        'bias_initializer':
            tf.keras.initializers.serialize(self._bias_initializer),
        'kernel_regularizer':
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        'bias_regularizer':
            tf.keras.regularizers.serialize(self._bias_regularizer),
    }
    base_config = super(Cross, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class CIN(tf.keras.layers.Layer):
  """Compressed Interaction Network (CIN) module in the xDeepFM model.

  The CIN layer aims to achieve high-order feature interactions at the
  vector-wise level rather than the bit-wise level.

  Reference:
    [xDeepFM](https://arxiv.org/pdf/1803.05170)
    xDeepFM: Combining Explicit and Implicit Feature Interactions for
    Recommender Systems

  Params (read via `params.get_or_default`):
    hidden_feature_sizes: non-empty list of ints, the number of feature maps
      h_1, ..., h_k produced by each CIN layer.
    kernel_regularizer: optional regularizer for the CIN kernels.
    bias_regularizer: optional regularizer for the CIN biases.

  Raises:
    ValueError: if `hidden_feature_sizes` is empty.
  """

  def __init__(self, params, name='cin', reuse=None, **kwargs):
    super(CIN, self).__init__(name=name, **kwargs)
    self._name = name
    self._hidden_feature_sizes = list(
        params.get_or_default('hidden_feature_sizes', []))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not self._hidden_feature_sizes:
      raise ValueError(
          'parameter hidden_feature_sizes must be a list of int with length greater than 0'
      )

    kernel_regularizer = params.get_or_default('kernel_regularizer', None)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    bias_regularizer = params.get_or_default('bias_regularizer', None)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)

  def build(self, input_shape):
    """Creates one kernel of shape (h_{i+1}, h_i, h_0) and bias per layer.

    Args:
      input_shape: shape of the 3D input (b, h0, d).

    Raises:
      ValueError: if the input is not 3-dimensional.
    """
    if len(input_shape) != 3:
      raise ValueError(
          'Unexpected inputs dimensions %d, expect to be 3 dimensions' %
          (len(input_shape)))

    # [h_0, h_1, ..., h_k] where h_0 is the number of input features.
    hidden_feature_sizes = [input_shape[1]] + self._hidden_feature_sizes
    # Use tf.compat.v1 when available (TF >= 1.13) and plain `tf` otherwise.
    # A lexicographic comparison of version strings would misorder e.g. '10.0'.
    tfv1 = getattr(getattr(tf, 'compat', None), 'v1', tf)
    with tfv1.variable_scope(self._name):
      self.kernel_list = [
          tfv1.get_variable(
              name='cin_kernel_%d' % i,
              shape=[
                  hidden_feature_sizes[i + 1], hidden_feature_sizes[i],
                  hidden_feature_sizes[0]
              ],
              initializer=tf.initializers.he_normal(),
              regularizer=self._kernel_regularizer,
              trainable=True) for i in range(len(self._hidden_feature_sizes))
      ]
      self.bias_list = [
          tfv1.get_variable(
              name='cin_bias_%d' % i,
              shape=[hidden_feature_sizes[i + 1]],
              initializer=tf.keras.initializers.Zeros(),
              regularizer=self._bias_regularizer,
              trainable=True) for i in range(len(self._hidden_feature_sizes))
      ]

    super(CIN, self).build(input_shape)

  def call(self, input, **kwargs):
    """Computes the compressed feature maps.

    Args:
      input: The 3D input tensor with shape (b, h0, d), where b is batch_size,
        h0 is the number of features, d is the feature embedding dimension.

    Returns:
      2D tensor of compressed feature map with shape (b, featuremap_num),
      where b is the batch_size, featuremap_num is sum of the hidden layer sizes
    """
    x_0 = input
    x_i = input
    # (b, 1, h0, d)
    x_0_expanded = tf.expand_dims(x_0, 1)
    pooled_feature_map_list = []
    for i in range(len(self._hidden_feature_sizes)):
      # (b, h_i, 1, d) * (b, 1, h0, d) -> (b, h_i, h0, d)
      x_i_expanded = tf.expand_dims(x_i, 2)
      intermediate_tensor = tf.multiply(x_0_expanded, x_i_expanded)

      # (b, 1, h_i, h0, d). Broadcasting against the kernel below replaces an
      # explicit tf.tile over h_{i+1}, which only materialised identical copies.
      intermediate_tensor_expanded = tf.expand_dims(intermediate_tensor, 1)
      # kernel (h_{i+1}, h_i, h0) -> (1, h_{i+1}, h_i, h0, 1)
      kernel = tf.expand_dims(tf.expand_dims(self.kernel_list[i], -1), 0)

      # (b, h_{i+1}, h_i, h0, d) -> sum over h0 and h_i -> (b, h_{i+1}, d)
      feature_map_elementwise = tf.multiply(intermediate_tensor_expanded,
                                            kernel)
      feature_map = tf.reduce_sum(
          tf.reduce_sum(feature_map_elementwise, axis=3), axis=2)

      # bias (h_{i+1},) -> (1, h_{i+1}, 1)
      feature_map = tf.add(
          feature_map,
          tf.expand_dims(tf.expand_dims(self.bias_list[i], axis=-1), axis=0))
      feature_map = tf.nn.relu(feature_map)

      x_i = feature_map
      # Sum-pool over the embedding dimension: (b, h_{i+1}).
      pooled_feature_map_list.append(tf.reduce_sum(feature_map, axis=-1))
    return tf.concat(
        pooled_feature_map_list, axis=-1)  # shape = (b, h1 + ... + hk)

  def get_config(self):
    """Returns the layer configuration merged with the base layer config."""
    config = {
        'hidden_feature_sizes':
            self._hidden_feature_sizes,
        'kernel_regularizer':
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        'bias_regularizer':
            tf.keras.regularizers.serialize(self._bias_regularizer),
    }
    base_config = super(CIN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def _clone_initializer(initializer):
|
|
416
|
+
return initializer.__class__.from_config(initializer.get_config())
|