easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/__init__.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import platform
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
from easy_rec.version import __version__
|
|
10
|
+
|
|
11
|
+
# Put the parent of this package's directory first on sys.path so that
# ``easy_rec`` imports resolve to this copy of the package.
curr_dir, _ = os.path.split(__file__)
parent_dir = os.path.dirname(curr_dir)
sys.path.insert(0, parent_dir)

# Configure root logging (INFO level, timestamped format) at package import.
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
# NOTE(review): indentation is not preserved in this listing; the guarded
# region may extend over the definitions/imports that follow -- confirm
# against the packaged source.
if 'PROCESSOR_TEST' not in os.environ:
  from tensorflow.python.platform import tf_logging
  # In DeepRec, logger.propagate of tf_logging is False, should be True
  tf_logging._logger.propagate = True
|
|
23
|
+
|
|
24
|
+
def get_ops_dir():
  """Locate the directory holding the prebuilt custom-op shared libraries.

  The sub-directory is derived from the installed TensorFlow version string:
  versions containing ``'PAI'`` map to ``1.12_pai``, ``1.12*`` to ``1.12``,
  ``1.15*`` to ``DeepRec`` when ``IS_ON_PAI`` is set in the environment
  (otherwise ``1.15``), and any other version to its ``major.minor`` prefix.

  Returns:
    The path ``<curr_dir>/python/ops/<sub_dir>`` on Linux, or ``None`` on any
    other platform. Existence of the path is not checked here.
  """
  import tensorflow as tf
  # Non-Linux platforms get no ops directory.
  if platform.system() != 'Linux':
    return None

  version = tf.__version__
  base_dir = os.path.join(curr_dir, 'python/ops')
  if 'PAI' in version:
    sub_dir = '1.12_pai'
  elif version.startswith('1.12'):
    sub_dir = '1.12'
  elif version.startswith('1.15'):
    sub_dir = 'DeepRec' if 'IS_ON_PAI' in os.environ else '1.15'
  else:
    # Keep only "major.minor", e.g. '2.12.0' -> '2.12'.
    sub_dir = '.'.join(version.split('.')[:2])
  return os.path.join(base_dir, sub_dir)
|
|
44
|
+
|
|
45
|
+
# NOTE(review): indentation is not preserved in this listing; these statements
# may sit inside the ``PROCESSOR_TEST`` guard above -- confirm against the
# packaged source.
# Resolve the custom-op directory once at import time; fall back to None
# (with a warning) when the expected directory is missing.
ops_dir = get_ops_dir()
if ops_dir is not None and not os.path.exists(ops_dir):
  logging.warning('ops_dir[%s] does not exist' % ops_dir)
  ops_dir = None

# Re-export the main entry points at package level.
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export # isort:skip # noqa: E402
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402

# Best-effort import of tensorflow_io.oss, presumably for its import-time side
# effects (e.g. filesystem registration -- not verified here); any failure is
# deliberately ignored.
try:
  import tensorflow_io.oss
except Exception:
  pass

# Banner printed on every package import.
print('easy_rec version: %s' % __version__)
print('Usage: easy_rec.help()')

# Module-level dict; nothing in this view writes to it.
_global_config = {}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def help():
  """Print a quick-start usage guide for EasyRec to stdout.

  The guide covers training, evaluation, export, config generation from an
  excel sheet and Python inference via ``Predictor``. ``SAVED_MODEL_DIR`` and
  ``INPUT_CSV`` in the inference snippets are placeholders for the reader.
  """
  # '\\' in this (non-raw) literal prints a single backslash, so each
  # multi-line shell command can be pasted into a shell as-is.
  print("""
  1 Train
    1.1 Train 1gpu
      CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.train_eval \\
        --pipeline_config_path deepfm_combo_on_avazu_ctr.config
    1.2 Train 2gpu
      sh scripts/train_2gpu.sh deepfm_combo_on_avazu_ctr.config
  2 Eval
    CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.eval \\
      --pipeline_config_path deepfm_combo_on_avazu_ctr.config
  3 Export
    CUDA_VISIBLE_DEVICES="" python -m easy_rec.python.export \\
      --pipeline_config_path deepfm_combo_on_avazu_ctr.config \\
      --export_dir models/export
  4 Create config from excel
    python -m easy_rec.python.tools.create_config_from_excel \\
      --excel_path dwd_avazu_ctr_multi_tower.xls \\
      --output_path dwd_avazu_ctr_multi_tower.config
  5 Inference
    # use list input
    import csv
    from easy_rec.python.inference.predictor import Predictor
    predictor = Predictor(SAVED_MODEL_DIR)
    with open(INPUT_CSV, 'r') as fin:
      reader = csv.reader(fin)
      inputs = []
      for row in reader:
        inputs.append(row[1:])
    output_res = predictor.predict(inputs, batch_size=32)

    # use dict input
    import csv
    from easy_rec.python.inference.predictor import Predictor
    predictor = Predictor(SAVED_MODEL_DIR)
    field_keys = [ "field1", "field2", "field3", "field4", "field5",
                   "field6", "field7", "field8", "field9", "field10",
                   "field11", "field12", "field13", "field14", "field15",
                   "field16", "field17", "field18", "field19", "field20" ]
    with open(INPUT_CSV, 'r') as fin:
      reader = csv.reader(fin)
      inputs = []
      for row in reader:
        inputs.append({f: row[fid + 1] for fid, f in enumerate(field_keys)})
    output_res = predictor.predict(inputs, batch_size=32)
  """)
|
|
File without changes
|
|
File without changes
|
|
easy_rec/python/builders/hyperparams_builder.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
"""Builder function to construct tf-slim arg_scope for convolution, fc ops."""
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
|
|
19
|
+
from easy_rec.python.compat import regularizers
|
|
20
|
+
|
|
21
|
+
# On TensorFlow 2.x, rebind ``tf`` to the v1 compatibility API so that names
# such as tf.truncated_normal_initializer used below still resolve.
# NOTE: lexicographic string comparison; correct for 1.x and 2.x versions.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def build_regularizer(regularizer):
  """Create a regularizer from a ``Hyperparams.regularizer`` config.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto; the branch is
      selected by its ``regularizer_oneof`` field.

  Returns:
    The regularizer built by ``easy_rec.python.compat.regularizers`` for the
    selected branch (l1, l2 or l1_l2), with every scale cast to ``float``.

  Raises:
    ValueError: if the oneof names no supported regularizer (including when
      it is unset).
  """
  which = regularizer.WhichOneof('regularizer_oneof')
  if which == 'l1_regularizer':
    l1_cfg = regularizer.l1_regularizer
    return regularizers.l1_regularizer(scale=float(l1_cfg.scale))
  if which == 'l2_regularizer':
    l2_cfg = regularizer.l2_regularizer
    return regularizers.l2_regularizer(scale=float(l2_cfg.scale))
  if which == 'l1_l2_regularizer':
    mixed_cfg = regularizer.l1_l2_regularizer
    return regularizers.l1_l2_regularizer(
        scale_l1=float(mixed_cfg.scale_l1),
        scale_l2=float(mixed_cfg.scale_l2))
  raise ValueError('Unknown regularizer function: {}'.format(which))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def build_initializer(initializer):
  """Build a tf initializer from config.

  Args:
    initializer: hyperparams_pb2 initializer config proto; the branch is
      selected by its ``initializer_oneof`` field.

  Returns:
    tf initializer: truncated normal or random normal (with the configured
    ``mean`` and ``stddev``), glorot normal, or constant (built from the
    configured ``consts`` values).

  Raises:
    ValueError: On unknown (or unset) initializer.
  """
  initializer_oneof = initializer.WhichOneof('initializer_oneof')
  if initializer_oneof == 'truncated_normal_initializer':
    return tf.truncated_normal_initializer(
        mean=initializer.truncated_normal_initializer.mean,
        stddev=initializer.truncated_normal_initializer.stddev)
  if initializer_oneof == 'random_normal_initializer':
    return tf.random_normal_initializer(
        mean=initializer.random_normal_initializer.mean,
        stddev=initializer.random_normal_initializer.stddev)
  if initializer_oneof == 'glorot_normal_initializer':
    return tf.glorot_normal_initializer()
  if initializer_oneof == 'constant_initializer':
    # Copy the repeated proto field into a plain Python list for tf.
    return tf.constant_initializer(
        list(initializer.constant_initializer.consts))
  raise ValueError('Unknown initializer function: {}'.format(initializer_oneof))
|
|
easy_rec/python/builders/loss_builder.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
|
|
9
|
+
from easy_rec.python.loss.jrc_loss import jrc_loss
|
|
10
|
+
from easy_rec.python.loss.listwise_loss import listwise_distill_loss
|
|
11
|
+
from easy_rec.python.loss.listwise_loss import listwise_rank_loss
|
|
12
|
+
from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
|
|
13
|
+
from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss
|
|
14
|
+
from easy_rec.python.loss.pairwise_loss import pairwise_logistic_loss
|
|
15
|
+
from easy_rec.python.loss.pairwise_loss import pairwise_loss
|
|
16
|
+
from easy_rec.python.protos.loss_pb2 import LossType
|
|
17
|
+
|
|
18
|
+
from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
|
|
19
|
+
|
|
20
|
+
from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
|
|
21
|
+
|
|
22
|
+
# On TensorFlow 2.x, rebind ``tf`` to the v1 compatibility API so that the
# tf.losses.* / tf.to_float / tf.log calls below still resolve.
# NOTE: lexicographic string comparison; correct for 1.x and 2.x versions.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def build(loss_type,
          label,
          pred,
          loss_weight=1.0,
          num_class=1,
          loss_param=None,
          **kwargs):
  """Build the loss tensor for a single loss configuration.

  Args:
    loss_type: ``LossType`` enum value selecting the loss.
    label: label tensor.
    pred: prediction tensor; logits or probabilities depending on the loss
      (e.g. CLASSIFICATION passes it as ``logits``).
    loss_weight: scalar or tensor weight forwarded to the underlying loss.
    num_class: only used by CLASSIFICATION: 1 selects sigmoid cross entropy,
      anything else sparse softmax cross entropy.
    loss_param: optional per-loss config proto; when ``None`` each branch
      falls back to the hard-coded defaults below.
    **kwargs: ``loss_name`` is popped and used as the loss name (default
      'unknown'); ``session_ids`` is read by the JRC / pairwise / listwise
      losses; all remaining keys are forwarded to the tf.losses calls of the
      CLASSIFICATION, CROSS_ENTROPY_LOSS and L2 branches.

  Returns:
    The loss tensor.

  Raises:
    ValueError: if ``loss_type`` is not supported.
  """
  loss_name = kwargs.pop('loss_name', 'unknown')
  if loss_type == LossType.CLASSIFICATION:
    if num_class == 1:
      return tf.losses.sigmoid_cross_entropy(
          label, logits=pred, weights=loss_weight, **kwargs)
    assert label.dtype in [tf.int32, tf.int64], \
        'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
    return tf.losses.sparse_softmax_cross_entropy(
        labels=label, logits=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.CROSS_ENTROPY_LOSS:
    return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
    losses = tf.keras.backend.binary_crossentropy(label, pred, from_logits=True)
    # Apply loss_weight with the same SUM_BY_NONZERO_WEIGHTS reduction used by
    # the other tf.losses branches; for the default weight of 1.0 this equals
    # tf.reduce_mean(losses). loss_collection=None: this loss is not added to
    # the LOSSES graph collection.
    return tf.losses.compute_weighted_loss(
        losses, weights=loss_weight, loss_collection=None)
  elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
    logging.info('%s is used', LossType.Name(loss_type))
    return tf.losses.mean_squared_error(
        labels=label, predictions=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.ZILN_LOSS:
    loss = zero_inflated_lognormal_loss(label, pred)
    # NOTE(review): only a scalar loss_weight is applied; a tensor weight is
    # ignored here -- confirm whether per-sample weights should be supported.
    if np.isscalar(loss_weight) and loss_weight != 1.0:
      return loss * loss_weight
    return loss
  elif loss_type == LossType.JRC_LOSS:
    session = kwargs.get('session_ids', None)
    if loss_param is None:
      return jrc_loss(label, pred, session, name=loss_name)
    return jrc_loss(
        label,
        pred,
        session,
        loss_param.alpha,
        loss_weight_strategy=loss_param.loss_weight_strategy,
        sample_weights=loss_weight,
        same_label_loss=loss_param.same_label_loss,
        name=loss_name)
  elif loss_type == LossType.PAIR_WISE_LOSS:
    session = kwargs.get('session_ids', None)
    margin = 0 if loss_param is None else loss_param.margin
    temp = 1.0 if loss_param is None else loss_param.temperature
    return pairwise_loss(
        label,
        pred,
        session_ids=session,
        margin=margin,
        temperature=temp,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_LOGISTIC_LOSS:
    session = kwargs.get('session_ids', None)
    temp = 1.0 if loss_param is None else loss_param.temperature
    ohem_ratio = 1.0 if loss_param is None else loss_param.ohem_ratio
    hinge_margin = None
    if loss_param is not None and loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    lbl_margin = False if loss_param is None else loss_param.use_label_margin
    return pairwise_logistic_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        hinge_margin=hinge_margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        use_label_margin=lbl_margin,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_HINGE_LOSS:
    session = kwargs.get('session_ids', None)
    temp, ohem_ratio, margin = 1.0, 1.0, 1.0
    label_is_logits, use_label_margin, use_exponent = True, True, False
    if loss_param is not None:
      temp = loss_param.temperature
      ohem_ratio = loss_param.ohem_ratio
      margin = loss_param.margin
      label_is_logits = loss_param.label_is_logits
      use_label_margin = loss_param.use_label_margin
      use_exponent = loss_param.use_exponent
    return pairwise_hinge_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        margin=margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        label_is_logits=label_is_logits,
        use_label_margin=use_label_margin,
        use_exponent=use_exponent,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_FOCAL_LOSS:
    session = kwargs.get('session_ids', None)
    if loss_param is None:
      return pairwise_focal_loss(
          label, pred, session_ids=session, weights=loss_weight, name=loss_name)
    hinge_margin = None
    if loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    return pairwise_focal_loss(
        label,
        pred,
        session_ids=session,
        gamma=loss_param.gamma,
        alpha=loss_param.alpha if loss_param.HasField('alpha') else None,
        hinge_margin=hinge_margin,
        ohem_ratio=loss_param.ohem_ratio,
        temperature=loss_param.temperature,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.LISTWISE_RANK_LOSS:
    session = kwargs.get('session_ids', None)
    trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_is_logits = loss_param.label_is_logits
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    return listwise_rank_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_is_logits=label_is_logits,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
    session = kwargs.get('session_ids', None)
    trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_clip_max_value = loss_param.label_clip_max_value
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    return listwise_distill_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_clip_max_value=label_clip_max_value,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.F1_REWEIGHTED_LOSS:
    f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
    label_smoothing = 0 if loss_param is None else loss_param.label_smoothing
    return f1_reweight_sigmoid_cross_entropy(
        label,
        pred,
        f1_beta_square,
        weights=loss_weight,
        label_smoothing=label_smoothing)
  elif loss_type == LossType.BINARY_FOCAL_LOSS:
    if loss_param is None:
      return sigmoid_focal_loss_with_logits(
          label, pred, sample_weights=loss_weight, name=loss_name)
    gamma = loss_param.gamma
    alpha = None
    if loss_param.HasField('alpha'):
      alpha = loss_param.alpha
    return sigmoid_focal_loss_with_logits(
        label,
        pred,
        gamma=gamma,
        alpha=alpha,
        ohem_ratio=loss_param.ohem_ratio,
        sample_weights=loss_weight,
        label_smoothing=loss_param.label_smoothing,
        name=loss_name)
  else:
    raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _kd_prob_to_logits(prob, epsilon):
  """Convert binary probabilities to logits, i.e. log(p / (1 - p)).

  Args:
    prob: tensor of probabilities.
    epsilon: `prob` is clipped to [epsilon, 1 - epsilon] before the log so the
      result stays finite for inputs at (or beyond) 0 and 1.

  Returns:
    tensor of logits with the same shape as `prob`.
  """
  prob = tf.clip_by_value(prob, epsilon, 1 - epsilon)
  return tf.log(prob / (1 - prob))


def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
  """Build knowledge distillation losses.

  Args:
    kds: list of knowledge distillation configs of type KD.
    prediction_dict: dict of predict_name to predict tensors.
    label_dict: ordered dict of label_name to label tensors; the soft
      (teacher) label of each KD is looked up by `kd.soft_label_name`.
    feature_dict: dict of feature name to feature value; read for the
      optional task-space indicator and, for losses delegated to `build`,
      the optional session ids.

  Returns:
    dict mapping loss name to loss tensor, one entry per KD config (configs
    that resolve to the same name overwrite each other). The name is
    `kd.loss_name` when set, otherwise
    'kd_loss_<pred_name>_<soft_label_name>' with '/' replaced by '_'.
  """
  loss_dict = {}
  for kd in kds:
    assert kd.pred_name in prediction_dict, \
        'invalid predict_name: %s available ones: %s' % (
            kd.pred_name, ','.join(prediction_dict.keys()))

    loss_name = kd.loss_name
    if not loss_name:
      loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
      loss_name += '_' + kd.soft_label_name.replace('/', '_')

    loss_weight = kd.loss_weight
    if kd.HasField('task_space_indicator_name') and kd.HasField(
        'task_space_indicator_value'):
      # Samples whose indicator feature equals the configured value are
      # weighted by in_task_space_weight, all others by out_task_space_weight.
      in_task_space = tf.to_float(
          tf.equal(feature_dict[kd.task_space_indicator_name],
                   kd.task_space_indicator_value))
      loss_weight = loss_weight * (
          kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
          (1 - in_task_space))

    label = label_dict[kd.soft_label_name]
    pred = prediction_dict[kd.pred_name]
    epsilon = tf.keras.backend.epsilon()
    # A prediction of rank < 2 is treated as a single (binary) score.
    num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1]

    # Step 1: move label / pred into logit space, apply the temperature and
    # map back to the representation the loss in step 2 expects.
    if kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
      if not kd.label_is_logits:  # label is prob
        label = _kd_prob_to_logits(label, epsilon)
      if not kd.pred_is_logits:
        pred = _kd_prob_to_logits(pred, epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      # pred stays as logits: the loss below uses from_logits=True.
      label = tf.nn.sigmoid(label)  # convert to prob
    elif kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
      if not kd.label_is_logits:  # label is prob
        if num_class == 1:  # for binary classification
          label = _kd_prob_to_logits(label, epsilon)
        else:
          label = tf.math.log(label + epsilon)
          # Constant shift; the softmax below is invariant to it.
          label -= tf.reduce_max(label)
      if not kd.pred_is_logits:
        if num_class == 1:  # for binary classification
          pred = _kd_prob_to_logits(pred, epsilon)
        else:
          pred = tf.math.log(pred + epsilon)
          # Constant shift; the softmax below is invariant to it.
          pred -= tf.reduce_max(pred)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        label = tf.nn.softmax(label)
        pred = tf.nn.softmax(pred)
      else:
        label = tf.nn.sigmoid(label)  # convert to prob
        pred = tf.nn.sigmoid(pred)  # convert to prob
    elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
      if num_class == 1:
        # Binary case: a sigmoid is applied below, so probabilities must be
        # mapped with the logit log(p / (1 - p)). Using log(p) here would
        # turn p into p / (1 + p) after the sigmoid.
        if not kd.label_is_logits:
          label = _kd_prob_to_logits(label, epsilon)
        if not kd.pred_is_logits:
          pred = _kd_prob_to_logits(pred, epsilon)
      else:
        # Multi-class case: softmax(log(p) / T) is the temperature-scaled
        # distribution.
        if not kd.label_is_logits:
          label = tf.math.log(label + epsilon)
        if not kd.pred_is_logits:
          pred = tf.math.log(pred + epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        label = tf.nn.softmax(label)
        pred = tf.nn.softmax(pred)
      elif num_class == 1:
        label = tf.nn.sigmoid(label)
        pred = tf.nn.sigmoid(pred)

    # Step 2: compute the configured loss.
    if kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
      if num_class == 1:
        # Build the two-class distribution [1 - p, p]. reshape (rather than
        # expand_dims) keeps this correct for both [B] and [B, 1] inputs.
        label = tf.reshape(label, [-1, 1])  # [B, 1]
        labels = tf.concat([1 - label, label], axis=1)  # [B, 2]
        pred = tf.reshape(pred, [-1, 1])  # [B, 1]
        preds = tf.concat([1 - pred, pred], axis=1)  # [B, 2]
      else:
        labels = label
        preds = pred
      losses = tf.keras.losses.KLD(labels, preds)
      loss_dict[loss_name] = tf.reduce_mean(
          losses, name=loss_name) * loss_weight
    elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
      losses = tf.keras.backend.binary_crossentropy(
          label, pred, from_logits=True)
      loss_dict[loss_name] = tf.reduce_mean(
          losses, name=loss_name) * loss_weight
    elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
      loss_dict[loss_name] = tf.losses.log_loss(
          label, pred, weights=loss_weight)
    elif kd.loss_type == LossType.L2_LOSS:
      loss_dict[loss_name] = tf.losses.mean_squared_error(
          labels=label, predictions=pred, weights=loss_weight)
    else:
      # Any other loss type is delegated to the generic `build`, forwarding
      # the KD's oneof loss_param (if set) and session ids when that param
      # message declares a session_name field.
      loss_param = kd.WhichOneof('loss_param')
      kwargs = {}
      if loss_param is not None:
        loss_param = getattr(kd, loss_param)
        if hasattr(loss_param, 'session_name'):
          kwargs['session_ids'] = feature_dict[loss_param.session_name]
      loss_dict[loss_name] = build(
          kd.loss_type,
          label,
          pred,
          loss_weight=loss_weight,
          loss_param=loss_param,
          **kwargs)
  return loss_dict
|