easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0

easy_rec/python/model/easy_rec_model.py
@@ -0,0 +1,467 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import os
import re
from abc import abstractmethod

import six
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile

from easy_rec.python.compat import regularizers
from easy_rec.python.layers import input_layer
from easy_rec.python.layers.backbone import Backbone
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import restore_filter
from easy_rec.python.utils.load_class import get_register_class_meta

try:
  import horovod.tensorflow as hvd
  from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
  from sparse_operation_kit import experiment as sok
except Exception:
  dynamic_variable_ops = None
  sok = None

try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  logging.warning('load libload_embed.so failed: %s' % str(ex))
  load_embed_lib = None

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

_EASY_REC_MODEL_CLASS_MAP = {}
_meta_type = get_register_class_meta(
    _EASY_REC_MODEL_CLASS_MAP, have_abstract_class=True)


class EasyRecModel(six.with_metaclass(_meta_type, object)):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    self._base_model_config = model_config
    self._model_config = model_config
    self._is_training = is_training
    self._is_predicting = labels is None
    self._feature_dict = features

    # embedding variable parameters
    self._global_ev_params = None
    if model_config.HasField('ev_params'):
      self._global_ev_params = model_config.ev_params

    if self.embedding_regularization > 0:
      self._emb_reg = regularizers.l2_regularizer(self.embedding_regularization)
    else:
      self._emb_reg = None

    if self.l2_regularization > 0:
      self._l2_reg = regularizers.l2_regularizer(self.l2_regularization)
    else:
      self._l2_reg = None

    # only used by model with wide feature groups, e.g. WideAndDeep
    self._wide_output_dim = -1
    if self.has_backbone:
      wide_dim = Backbone.wide_embed_dim(model_config.backbone)
      if wide_dim:
        self._wide_output_dim = wide_dim
        logging.info('set `wide_output_dim` to %d' % wide_dim)

    self._feature_configs = feature_configs
    self.build_input_layer(model_config, feature_configs)

    self._labels = labels
    self._prediction_dict = {}
    self._loss_dict = {}
    self._metric_dict = {}

    # add sample weight from inputs
    self._sample_weight = 1.0
    if constant.SAMPLE_WEIGHT in features:
      self._sample_weight = features[constant.SAMPLE_WEIGHT]

    self._backbone_output = None
    self._backbone_net = self.build_backbone_network()

  def build_backbone_network(self):
    if self.has_backbone:
      return Backbone(
          self._base_model_config.backbone,
          self._feature_dict,
          input_layer=self._input_layer,
          l2_reg=self._l2_reg)
    return None

  @property
  def has_backbone(self):
    return self._base_model_config.HasField('backbone')

  @property
  def backbone(self):
    if self._backbone_output:
      return self._backbone_output
    if self._backbone_net:
      kwargs = {
          'loss_dict': self._loss_dict,
          'metric_dict': self._metric_dict,
          'prediction_dict': self._prediction_dict,
          'labels': self._labels,
          constant.SAMPLE_WEIGHT: self._sample_weight
      }
      return self._backbone_net(self._is_training, **kwargs)
    return None

  @property
  def embedding_regularization(self):
    return self._base_model_config.embedding_regularization

  @property
  def kd(self):
    return self._base_model_config.kd

  @property
  def feature_groups(self):
    return self._base_model_config.feature_groups

  @property
  def l2_regularization(self):
    model_config = getattr(self._base_model_config,
                           self._base_model_config.WhichOneof('model'))
    l2_regularization = 0.0
    if hasattr(model_config, 'dense_regularization') and \
        model_config.HasField('dense_regularization'):
      # backward compatibility
      logging.warn(
          'dense_regularization is deprecated, please use l2_regularization')
      l2_regularization = model_config.dense_regularization
    elif hasattr(model_config, 'l2_regularization'):
      l2_regularization = model_config.l2_regularization
    return l2_regularization

  def build_input_layer(self, model_config, feature_configs):
    self._input_layer = input_layer.InputLayer(
        feature_configs,
        model_config.feature_groups,
        wide_output_dim=self._wide_output_dim,
        ev_params=self._global_ev_params,
        embedding_regularizer=self._emb_reg,
        kernel_regularizer=self._l2_reg,
        variational_dropout_config=model_config.variational_dropout
        if model_config.HasField('variational_dropout') else None,
        is_training=self._is_training,
        is_predicting=self._is_predicting)

  @abstractmethod
  def build_predict_graph(self):
    pass

  @abstractmethod
  def build_loss_graph(self):
    pass

  def build_metric_graph(self, eval_config):
    return self._metric_dict

  @abstractmethod
  def get_outputs(self):
    pass

  def build_output_dict(self):
    """For exporting: get standard output nodes."""
    outputs = {}
    for name in self.get_outputs():
      if name not in self._prediction_dict:
        raise KeyError(
            'output node {} not in prediction_dict, can not be exported'.format(
                name))
      outputs[name] = self._prediction_dict[name]
    return outputs

  def build_feature_output_dict(self):
    """For exporting: get output feature nodes."""
    outputs = {}
    for feature_name in self._feature_dict:
      out_name = 'feature_' + feature_name
      feature_value = self._feature_dict[feature_name]
      if isinstance(feature_value, tf.SparseTensor):
        sparse_values = feature_value.values
        if sparse_values.dtype != tf.string:
          sparse_values = tf.as_string(sparse_values)
        feature_value = tf.sparse_to_dense(feature_value.indices,
                                           feature_value.dense_shape,
                                           sparse_values, '')
      elif feature_value.dtype != tf.string:
        feature_value = tf.as_string(feature_value)
      feature_value = tf.reduce_join(feature_value, axis=-1, separator=',')
      outputs[out_name] = feature_value
    return outputs

  def build_rtp_output_dict(self):
    """For exporting: get output nodes for RTP infering."""
    return {}

  def restore(self,
              ckpt_path,
              include_global_step=False,
              ckpt_var_map_path='',
              force_restore_shape_compatible=False):
    """Restore variables from ckpt_path.

    steps:
      1. list the variables in graph that need to be restored
      2. inspect checkpoint and find the variables that could restore from checkpoint
         substitute scope names in case necessary
      3. call tf.train.init_from_checkpoint to restore the variables

    Args:
      ckpt_path: checkpoint path to restore from
      include_global_step: whether to restore global_step variable
      ckpt_var_map_path: variable map from graph variables to variables in a checkpoint
        each line consists of: variable name in graph  variable name in ckpt
      force_restore_shape_compatible: if variable shape is incompatible, clip or pad
        variables in checkpoint, and then restore

    Returns:
      IncompatibleShapeRestoreHook if force_shape_compatible else None
    """
    name2var_map = self._get_restore_vars(ckpt_var_map_path)
    logging.info('start to restore from %s' % ckpt_path)

    ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
    ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
      ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)

    vars_in_ckpt = {}
    incompatible_shape_var_map = {}
    fail_restore_vars = []
    for variable_name, variable in sorted(name2var_map.items()):
      if variable_name in ckpt_var2shape_map:
        print('restore %s' % variable_name)
        ckpt_var_shape = ckpt_var2shape_map[variable_name]
        if type(variable) == list:
          shape_arr = [x.get_shape() for x in variable]
          var_shape = list(shape_arr[0])
          for x in shape_arr[1:]:
            var_shape[0] += x[0]
          var_shape = tensor_shape.TensorShape(var_shape)
          variable = variables.PartitionedVariable(
              variable_name,
              var_shape,
              variable[0].dtype,
              variable,
              partitions=[len(variable)] + [1] * (len(var_shape) - 1))
        else:
          var_shape = variable.shape.as_list()
        if ckpt_var_shape == var_shape:
          vars_in_ckpt[variable_name] = list(variable) if isinstance(
              variable, variables.PartitionedVariable) else variable
        elif len(ckpt_var_shape) == len(var_shape):
          if force_restore_shape_compatible:
            # create a variable compatible with checkpoint to restore
            dtype = variable[0].dtype if isinstance(variable,
                                                    list) else variable.dtype
            with tf.variable_scope('incompatible_shape_restore'):
              tmp_var = tf.get_variable(
                  name=variable_name + '_T_E_M_P',
                  shape=ckpt_var_shape,
                  trainable=False,
                  # add to a special collection for easy reference
                  # by tf.get_collection('T_E_M_P_RESTROE')
                  collections=['T_E_M_P_RESTROE'],
                  dtype=dtype)
            vars_in_ckpt[variable_name] = tmp_var
            incompatible_shape_var_map[variable] = tmp_var
            print('incompatible restore %s[%s, %s]' %
                  (variable_name, str(var_shape), str(ckpt_var_shape)))
          else:
            logging.warning(
                'Variable [%s] is available in checkpoint, but '
                'incompatible shape with model variable.', variable_name)
        else:
          logging.warning(
              'Variable [%s] is available in checkpoint, but '
              'incompatible shape dims with model variable.', variable_name)
      elif 'EmbeddingVariable' in str(type(variable)):
        if '%s-keys' % variable_name not in ckpt_var2shape_map:
          continue
        print('restore embedding_variable %s' % variable_name)
        from tensorflow.python.training import saver
        names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable])
        saveable_objects = []
        for name, op in names_to_saveables.items():
          for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
            saveable_objects.append(s)
        init_op = saveable_objects[0].restore([ckpt_path], None)
        variable._initializer_op = init_op
      elif type(variable) == list and 'EmbeddingVariable' in str(
          type(variable[0])):
        if '%s/part_0-keys' % variable_name not in ckpt_var2shape_map:
          continue
        print('restore partitioned embedding_variable %s' % variable_name)
        from tensorflow.python.training import saver
        for part_var in variable:
          names_to_saveables = saver.BaseSaverBuilder.OpListToDict([part_var])
          saveable_objects = []
          for name, op in names_to_saveables.items():
            for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name):
              saveable_objects.append(s)
          init_op = saveable_objects[0].restore([ckpt_path], None)
          part_var._initializer_op = init_op
      elif sok is not None and isinstance(variable, sok.DynamicVariable):
        print('restore dynamic_variable %s' % variable_name)
        keys, vals = load_embed_lib.load_kv_embed(
            task_index=hvd.rank(),
            task_num=hvd.size(),
            embed_dim=variable._dimension,
            var_name='embed-' + variable.name.replace('/', '__'),
            ckpt_path=ckpt_path)
        with ops.control_dependencies([variable._initializer_op]):
          variable._initializer_op = dynamic_variable_ops.dummy_var_assign(
              variable.handle, keys, vals)
      else:
        fail_restore_vars.append(variable_name)
    for variable_name in fail_restore_vars:
      if 'Momentum' not in variable_name:
        logging.warning('Variable [%s] is not available in checkpoint',
                        variable_name)

    tf.train.init_from_checkpoint(ckpt_path, vars_in_ckpt)

    if force_restore_shape_compatible:
      return estimator_utils.IncompatibleShapeRestoreHook(
          incompatible_shape_var_map)
    else:
      return None

  def _get_restore_vars(self, ckpt_var_map_path):
    """Restore by specify variable map between graph variables and ckpt variables.

    Args:
      ckpt_var_map_path: variable map from graph variables to variables in a checkpoint
        each line consists of: variable name in graph  variable name in ckpt

    Returns:
      the list of variables which need to restore from checkpoint
    """
    # here must use global_variables, because variables such as moving_mean
    # and moving_variance is usually not trainable in detection models
    all_vars = tf.global_variables()
    PARTITION_PATTERN = '/part_[0-9]+'
    VAR_SUFIX_PATTERN = ':[0-9]$'

    name2var = {}
    for one_var in all_vars:
      var_name = re.sub(VAR_SUFIX_PATTERN, '', one_var.name)
      if re.search(PARTITION_PATTERN,
                   var_name) and one_var._save_slice_info is not None:
        var_name = re.sub(PARTITION_PATTERN, '', var_name)
        is_part = True
      else:
        is_part = False
      if var_name in name2var:
        assert is_part, 'multiple vars: %s' % var_name
        name2var[var_name].append(one_var)
      else:
        name2var[var_name] = [one_var] if is_part else one_var

    if ckpt_var_map_path != '':
      if not gfile.Exists(ckpt_var_map_path):
        logging.warning('%s not exist' % ckpt_var_map_path)
        return name2var

      # load var map
      name_map = {}
      with gfile.GFile(ckpt_var_map_path, 'r') as fin:
        for one_line in fin:
          one_line = one_line.strip()
          line_tok = [x for x in one_line.split() if x != '']
          if len(line_tok) != 2:
            logging.warning('Failed to process: %s' % one_line)
            continue
          name_map[line_tok[0]] = line_tok[1]
      update_map = {}
      old_keys = []
      for var_name in name2var:
        if var_name in name_map:
          in_ckpt_name = name_map[var_name]
          update_map[in_ckpt_name] = name2var[var_name]
          old_keys.append(var_name)
      for tmp_key in old_keys:
        del name2var[tmp_key]
      name2var.update(update_map)
      return name2var
    else:
      var_filter, scope_update = self.get_restore_filter()
      if var_filter is not None:
        name2var = {
            var_name: name2var[var_name]
            for var in name2var
            if var_filter.keep(var.name)
        }
      # drop scope prefix if necessary
      if scope_update is not None:
        name2var = {
            scope_update(var_name): name2var[var_name] for var_name in name2var
        }
      return name2var

  def get_restore_filter(self):
    """Get restore variable filter.

    Return:
      filter: type of Filter in restore_filter.py
      scope_drop: type of ScopeDrop in restore_filter.py
    """
    if len(self._base_model_config.restore_filters) == 0:
      return None, None

    for x in self._base_model_config.restore_filters:
      logging.info('restore will filter out pattern %s' % x)

    all_filters = [
        restore_filter.KeywordFilter(x, True)
        for x in self._base_model_config.restore_filters
    ]

    return restore_filter.CombineFilter(all_filters,
                                        restore_filter.Logical.AND), None

  def get_grouped_vars(self, opt_num):
    """Group the vars into different optimization groups.

    Each group will be optimized by a separate optimizer.

    Args:
      opt_num: number of optimizers from easyrec config.

    Return:
      list of list of variables.
    """
    assert opt_num == 2, 'could only support 2 optimizers, one for embedding, one for the other layers'

    embedding_vars = []
    deep_vars = []
    for tmp_var in variables.trainable_variables():
      if tmp_var.name.startswith(
          'input_layer') or '/embedding_weights' in tmp_var.name:
        embedding_vars.append(tmp_var)
      else:
        deep_vars.append(tmp_var)
    return [embedding_vars, deep_vars]
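
EasyRecModel above fixes the contract that every concrete model in the package follows: a subclass implements build_predict_graph, build_loss_graph and get_outputs, fills self._prediction_dict and self._loss_dict, and reads the features, labels and sample weight prepared by the base class. A minimal sketch of such a subclass is shown below; it is illustrative only, and ToyRankModel, the 'all' feature group name and the 'label' column are assumptions rather than anything shipped in this package.

import tensorflow as tf

from easy_rec.python.model.easy_rec_model import EasyRecModel

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class ToyRankModel(EasyRecModel):
  """Toy binary ranking model built on the EasyRecModel contract (sketch)."""

  def build_predict_graph(self):
    # pull the 'all' feature group through the shared input layer
    all_fea, _ = self._input_layer(self._feature_dict, 'all')
    logits = tf.layers.dense(all_fea, 1, name='toy_logits')
    self._prediction_dict['logits'] = tf.squeeze(logits, axis=-1)
    self._prediction_dict['probs'] = tf.sigmoid(self._prediction_dict['logits'])
    return self._prediction_dict

  def build_loss_graph(self):
    # assumes a single binary label column named 'label'
    label = tf.cast(self._labels['label'], tf.float32)
    self._loss_dict['cross_entropy_loss'] = tf.losses.sigmoid_cross_entropy(
        label, self._prediction_dict['logits'], weights=self._sample_weight)
    return self._loss_dict

  def get_outputs(self):
    # names must exist in self._prediction_dict (see build_output_dict above)
    return ['probs', 'logits']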

easy_rec/python/model/esmm.py
@@ -0,0 +1,242 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.model.multi_task_model import MultiTaskModel
from easy_rec.python.protos.esmm_pb2 import ESMM as ESMMConfig
from easy_rec.python.protos.loss_pb2 import LossType

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
losses = tf.losses


class ESMM(MultiTaskModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(ESMM, self).__init__(model_config, feature_configs, features, labels,
                               is_training)
    assert self._model_config.WhichOneof('model') == 'esmm', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.esmm
    assert isinstance(self._model_config, ESMMConfig)

    self._group_num = len(self._model_config.groups)
    self._group_features = []
    if self.has_backbone:
      logging.info('use bottom backbone network')
    elif self._group_num > 0:
      logging.info('group_num: {0}'.format(self._group_num))
      for group_id in range(self._group_num):
        group = self._model_config.groups[group_id]
        group_feature, _ = self._input_layer(self._feature_dict, group.input)
        self._group_features.append(group_feature)
    else:
      group_feature, _ = self._input_layer(self._feature_dict, 'all')
      self._group_features.append(group_feature)

    # This model only supports two tasks (cvr+ctr or playtime+ctr).
    # In order to be consistent with the paper,
    # we call these two towers cvr_tower (main tower) and ctr_tower (aux tower).
    self._cvr_tower_cfg = self._model_config.cvr_tower
    self._ctr_tower_cfg = self._model_config.ctr_tower
    self._init_towers([self._cvr_tower_cfg, self._ctr_tower_cfg])

    assert self._model_config.ctr_tower.loss_type == LossType.CLASSIFICATION, \
        'ctr tower must be binary classification.'
    for task_tower_cfg in self._task_towers:
      assert task_tower_cfg.num_class == 1, 'Does not support multiclass classification problem'

  def build_loss_graph(self):
    """Build loss graph.

    Returns:
      self._loss_dict: Weighted loss of ctr and cvr.
    """
    cvr_tower_name = self._cvr_tower_cfg.tower_name
    ctr_tower_name = self._ctr_tower_cfg.tower_name
    cvr_label_name = self._label_name_dict[cvr_tower_name]
    ctr_label_name = self._label_name_dict[ctr_tower_name]
    if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION:
      ctcvr_label = tf.cast(
          self._labels[cvr_label_name] * self._labels[ctr_label_name],
          tf.float32)
      cvr_losses = tf.keras.backend.binary_crossentropy(
          ctcvr_label, self._prediction_dict['probs_ctcvr'])
      cvr_loss = tf.reduce_sum(cvr_losses, name='ctcvr_loss')
      # The weight defaults to 1.
      self._loss_dict['weighted_cross_entropy_loss_%s' %
                      cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss

    elif self._cvr_tower_cfg.loss_type == LossType.L2_LOSS:
      logging.info('l2 loss is used')
      cvr_dtype = self._labels[cvr_label_name].dtype
      ctcvr_label = self._labels[cvr_label_name] * tf.cast(
          self._labels[ctr_label_name], cvr_dtype)
      cvr_loss = tf.losses.mean_squared_error(
          labels=ctcvr_label,
          predictions=self._prediction_dict['y_ctcvr'],
          weights=self._sample_weight)
      self._loss_dict['weighted_l2_loss_%s' %
                      cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss
    _labels = tf.cast(self._labels[ctr_label_name], tf.float32)
    _logits = self._prediction_dict['logits_%s' % ctr_tower_name]
    cross = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=_labels, logits=_logits, name='ctr_loss')
    ctr_loss = tf.reduce_sum(cross)
    self._loss_dict['weighted_cross_entropy_loss_%s' %
                    ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Build metric graph.

    Args:
      eval_config: Evaluation configuration.

    Returns:
      metric_dict: Calculate AUC of ctr, cvr and ctrvr.
    """
    metric_dict = {}

    cvr_tower_name = self._cvr_tower_cfg.tower_name
    ctr_tower_name = self._ctr_tower_cfg.tower_name
    cvr_label_name = self._label_name_dict[cvr_tower_name]
    ctr_label_name = self._label_name_dict[ctr_tower_name]
    for metric in self._cvr_tower_cfg.metrics_set:
      # CTCVR metric
      ctcvr_label_name = cvr_label_name + '_ctcvr'
      cvr_dtype = self._labels[cvr_label_name].dtype
      self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
          self._labels[ctr_label_name], cvr_dtype)
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._cvr_tower_cfg.loss_type,
              label_name=ctcvr_label_name,
              num_class=self._cvr_tower_cfg.num_class,
              suffix='_ctcvr'))

      # CVR metric
      cvr_label_masked_name = cvr_label_name + '_masked'
      ctr_mask = self._labels[ctr_label_name] > 0
      self._labels[cvr_label_masked_name] = tf.boolean_mask(
          self._labels[cvr_label_name], ctr_mask)
      pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
      pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
      self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
          self._prediction_dict[pred_name], ctr_mask)
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._cvr_tower_cfg.loss_type,
              label_name=cvr_label_masked_name,
              num_class=self._cvr_tower_cfg.num_class,
              suffix='_%s_masked' % cvr_tower_name))

    for metric in self._ctr_tower_cfg.metrics_set:
      # CTR metric
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._ctr_tower_cfg.loss_type,
              label_name=ctr_label_name,
              num_class=self._ctr_tower_cfg.num_class,
              suffix='_%s' % ctr_tower_name))
    return metric_dict

  def _add_to_prediction_dict(self, output):
    super(ESMM, self)._add_to_prediction_dict(output)
    if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION:
      prob = tf.multiply(
          self._prediction_dict['probs_%s' % self._cvr_tower_cfg.tower_name],
          self._prediction_dict['probs_%s' % self._ctr_tower_cfg.tower_name])
      # pctcvr = pctr * pcvr
      self._prediction_dict['probs_ctcvr'] = prob

    else:
      prob = tf.multiply(
          self._prediction_dict['y_%s' % self._cvr_tower_cfg.tower_name],
          self._prediction_dict['probs_%s' % self._ctr_tower_cfg.tower_name])
      # pctcvr = pctr * pcvr
      self._prediction_dict['y_ctcvr'] = prob

  def build_predict_graph(self):
    """Forward function.

    Returns:
      self._prediction_dict: Prediction result of two tasks.
    """
    if self.has_backbone:
      all_fea = self.backbone
    elif self._group_num > 0:
      group_fea_arr = []
      # Both towers share the underlying network.
      for group_id in range(self._group_num):
        group_fea = self._group_features[group_id]
        group = self._model_config.groups[group_id]
        group_name = group.input
        dnn_model = dnn.DNN(group.dnn, self._l2_reg, group_name,
                            self._is_training)
        group_fea = dnn_model(group_fea)
        group_fea_arr.append(group_fea)
      all_fea = tf.concat(group_fea_arr, axis=1)
    else:
      all_fea = self._group_features[0]

    cvr_tower_name = self._cvr_tower_cfg.tower_name
    dnn_model = dnn.DNN(
        self._cvr_tower_cfg.dnn,
        self._l2_reg,
        name=cvr_tower_name,
        is_training=self._is_training)
    cvr_tower_output = dnn_model(all_fea)
    cvr_tower_output = tf.layers.dense(
        inputs=cvr_tower_output,
        units=1,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn_output' % cvr_tower_name)

    ctr_tower_name = self._ctr_tower_cfg.tower_name
    dnn_model = dnn.DNN(
        self._ctr_tower_cfg.dnn,
        self._l2_reg,
        name=ctr_tower_name,
        is_training=self._is_training)
    ctr_tower_output = dnn_model(all_fea)
    ctr_tower_output = tf.layers.dense(
        inputs=ctr_tower_output,
        units=1,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn_output' % ctr_tower_name)

    tower_outputs = {
        cvr_tower_name: cvr_tower_output,
        ctr_tower_name: ctr_tower_output
    }
    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict

  def get_outputs(self):
    """Get model outputs.

    Returns:
      outputs: The list of tensor names output by the model.
    """
    outputs = super(ESMM, self).get_outputs()
    if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION:
      outputs.append('probs_ctcvr')
    elif self._cvr_tower_cfg.loss_type == LossType.L2_LOSS:
      outputs.append('y_ctcvr')
    else:
      raise ValueError('invalid cvr_tower loss type: %s' %
                       str(self._cvr_tower_cfg.loss_type))
    return outputs
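
As a quick sanity check on the arithmetic used above (toy NumPy values, not taken from the package): the CTCVR label is the elementwise product of the CTR and CVR labels, CVR metrics are evaluated only on clicked samples via the boolean mask, and pCTCVR is the product of the two tower probabilities.

import numpy as np

ctr_label = np.array([1, 0, 1, 1])  # clicked or not
cvr_label = np.array([1, 0, 0, 1])  # converted or not

# label used for probs_ctcvr: converted AND clicked
ctcvr_label = cvr_label * ctr_label           # -> [1, 0, 0, 1]

# CVR metrics only look at clicked samples (tf.boolean_mask above)
ctr_mask = ctr_label > 0
cvr_label_masked = cvr_label[ctr_mask]        # -> [1, 0, 1]

# pCTCVR = pCTR * pCVR, matching self._prediction_dict['probs_ctcvr']
probs_ctr = np.array([0.9, 0.2, 0.7, 0.5])
probs_cvr = np.array([0.4, 0.1, 0.3, 0.8])
probs_ctcvr = probs_ctr * probs_cvr           # -> [0.36, 0.02, 0.21, 0.4]
print(ctcvr_label, cvr_label_masked, probs_ctcvr)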