easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
from tensorflow.python.ops.losses.losses_impl import compute_weighted_loss
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
|
|
9
|
+
from easy_rec.python.utils.shape_utils import get_shape_list
|
|
10
|
+
|
|
11
|
+
if tf.__version__ >= '2.0':
|
|
12
|
+
tf = tf.compat.v1
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def pairwise_loss(labels,
                  logits,
                  session_ids=None,
                  margin=0,
                  temperature=1.0,
                  weights=1.0,
                  name=''):
  """Deprecated pairwise loss; see `pairwise_logistic_loss` instead.

  Every ordered pair (i, j) with labels[i] > labels[j] is kept. When
  `session_ids` is given, the pair must also share a session id. The score
  difference logits[i] - logits[j] - margin of each kept pair is scored with a
  sigmoid cross entropy against a target of 1.

  Args:
    labels: a `Tensor` with shape [batch_size], e.g. click or not click in the
      session.
    logits: a `Tensor` with shape [batch_size], e.g. the value of the last
      neuron before activation.
    session_ids: optional `Tensor` with shape [batch_size]. Only pairs whose
      session ids are equal are compared (e.g. user_id, to target GAUC).
    margin: value subtracted from every pairwise score difference.
    temperature: logits are divided by this value when it is not 1.0.
    weights: a python scalar, or a numeric `Tensor` of per-sample weights. In
      the tensor case it is expanded to [batch_size, 1], and pair (i, j) takes
      the weight of sample i.
    name: the name of the loss, used in log lines and the summary tag.

  Returns:
    The loss returned by `tf.losses.sigmoid_cross_entropy` over the kept pairs.
  """
  logging.warning(
      'The old `pairwise_loss` is being deprecated. '
      'Please use the new `pairwise_logistic_loss` or `pairwise_focal_loss`')
  loss_name = name or 'pairwise_loss'
  logging.info('[{}] margin: {}, temperature: {}'.format(
      loss_name, margin, temperature))

  if temperature != 1.0:
    logits /= temperature

  # Row i / column j broadcast yields the [batch, batch] matrix of s_i - s_j.
  row_scores = tf.expand_dims(logits, -1)
  col_scores = tf.expand_dims(logits, 0)
  pair_scores = row_scores - col_scores - margin

  valid_pairs = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    valid_pairs = tf.logical_and(valid_pairs, same_session)

  pair_scores = tf.boolean_mask(pair_scores, valid_pairs)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(pair_scores))

  if not tf.is_numeric_tensor(weights):
    pair_weights = weights
  else:
    logging.info('[%s] use sample weight' % loss_name)
    weight_col = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weight_col, 2)
    # Broadcast sample i's weight across its row, then keep the same pairs.
    weight_grid = tf.tile(weight_col, tf.stack([1, batch_size]))
    pair_weights = tf.boolean_mask(weight_grid, valid_pairs)

  return tf.losses.sigmoid_cross_entropy(
      tf.ones_like(pair_scores), pair_scores, weights=pair_weights)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def pairwise_focal_loss(labels,
                        logits,
                        session_ids=None,
                        hinge_margin=None,
                        gamma=2,
                        alpha=None,
                        ohem_ratio=1.0,
                        temperature=1.0,
                        weights=1.0,
                        name=''):
  """Pairwise focal loss over ordered pairs with differing labels.

  For every pair (i, j) with labels[i] > labels[j], the score difference
  logits[i] - logits[j] is used as the logit of a positive (target 1) example.
  These examples are scored with `sigmoid_focal_loss_with_logits`.

  Args:
    labels: a `Tensor` with shape [batch_size] representing graded relevance.
    logits: a `Tensor` with shape [batch_size].
    session_ids: optional `Tensor` with shape [batch_size]. When given, only
      pairs whose session ids are equal are kept.
    hinge_margin: optional threshold. When given, only pairs whose score
      difference is strictly less than it are kept.
    gamma: forwarded unchanged to `sigmoid_focal_loss_with_logits`.
    alpha: forwarded unchanged to `sigmoid_focal_loss_with_logits`.
    ohem_ratio: must lie in (0, 1]; forwarded to
      `sigmoid_focal_loss_with_logits`.
    temperature: logits are divided by this value when it is not 1.0.
    weights: a python scalar, or a numeric `Tensor` of per-sample weights. In
      the tensor case pair (i, j) takes the weight of sample i.
    name: the name of the loss, used in log lines and the summary tag.

  Returns:
    The value returned by `sigmoid_focal_loss_with_logits` for the kept pairs.

  Raises:
    AssertionError: if `ohem_ratio` is outside (0, 1].
  """
  loss_name = name or 'pairwise_focal_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] hinge margin: {}, gamma: {}, alpha: {}, ohem_ratio: {}, temperature: {}'
      .format(loss_name, hinge_margin, gamma, alpha, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature
  score_diff = tf.expand_dims(logits, -1) - tf.expand_dims(logits, 0)

  keep = tf.greater(tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if hinge_margin is not None:
    # Drop pairs that are already separated by at least `hinge_margin`.
    keep = tf.logical_and(keep, tf.less(score_diff, hinge_margin))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    keep = tf.logical_and(keep, same_session)

  score_diff = tf.boolean_mask(score_diff, keep)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(score_diff))

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    sample_w = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(sample_w, 2)
    pair_w = tf.boolean_mask(
        tf.tile(sample_w, tf.stack([1, batch_size])), keep)
  else:
    pair_w = weights

  return sigmoid_focal_loss_with_logits(
      tf.ones_like(score_diff),
      score_diff,
      gamma=gamma,
      alpha=alpha,
      ohem_ratio=ohem_ratio,
      sample_weights=pair_w)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def pairwise_logistic_loss(labels,
                           logits,
                           session_ids=None,
                           temperature=1.0,
                           hinge_margin=None,
                           weights=1.0,
                           ohem_ratio=1.0,
                           use_label_margin=False,
                           name=''):
  r"""Computes pairwise logistic loss between `labels` and `logits`, equivalent to RankNet loss.

  Definition:
    $$
    \mathcal{L}(\{y\}, \{s\}) =
    \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j - m_{ij})))
    $$
  where $m_{ij} = y_i - y_j$ when `use_label_margin` is set, `hinge_margin`
  when that is given instead, and 0 otherwise. The per-pair losses are reduced
  as described under Returns rather than plainly summed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample; only pairs with equal ids are compared (e.g. user_id, to
      target GAUC).
    temperature: (Optional) The temperature to use for scaling the logits.
      When `use_label_margin` is set, labels are scaled by it as well.
    hinge_margin: the margin between positive and negative logits. It is
      ignored when `use_label_margin` is set.
    weights: A scalar, or a `Tensor` with shape [batch_size] for each sample.
      In the tensor case pair (i, j) takes the weight of sample i.
    ohem_ratio: the fraction of hard (largest-loss) pairs to be mined, in
      (0, 1].
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin.
    name: the name of loss.

  Returns:
    A scalar loss `Tensor`. With `ohem_ratio == 1` it is the result of
    `compute_weighted_loss` with its default reduction. Otherwise it is
    computed as follows. The `round(ohem_ratio * num_pairs)` largest weighted
    pair losses are taken, and the mean of the positive ones among them is
    returned. If none remain, the result is 0 instead of NaN.

  Raises:
    AssertionError: if `ohem_ratio` is outside (0, 1].
  """
  loss_name = name if name else 'pairwise_logistic_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info('[{}] hinge margin: {}, ohem_ratio: {}, temperature: {}'.format(
      loss_name, hinge_margin, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature
    if use_label_margin:
      # The label difference is used as a margin on the logit scale, so it
      # must be scaled the same way as the logits.
      labels /= temperature

  pairwise_logits = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  if use_label_margin:
    pairwise_logits -= tf.math.subtract(
        tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  elif hinge_margin is not None:
    pairwise_logits -= hinge_margin

  pairwise_mask = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    group_equal = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)

  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
  num_pair = tf.size(pairwise_logits)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)

  # Numerically stable form of log(1 + exp(-pairwise_logits)).
  losses = tf.nn.relu(-pairwise_logits) + tf.math.log1p(
      tf.exp(-tf.abs(pairwise_logits)))

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    # Row i of the tiled matrix carries sample i's weight for every j.
    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
  else:
    pairwise_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(losses, pairwise_weights)

  # Online hard example mining: keep only the k largest weighted pair losses.
  losses = compute_weighted_loss(
      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
  k = tf.to_int32(tf.math.rint(k))
  topk = tf.nn.top_k(losses, k)
  losses = tf.boolean_mask(topk.values, topk.values > 0)
  # `tf.reduce_mean` of an empty tensor is NaN. This happens when the batch
  # has no valid pair, when k rounds to 0, or when no mined loss is positive.
  # Clamping the count to 1 makes such a batch contribute 0 (matching the
  # ohem_ratio == 1 path) and is identical to the mean otherwise.
  num_kept = tf.maximum(tf.to_float(tf.size(losses)), 1.0)
  return tf.reduce_sum(losses) / num_kept
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def pairwise_hinge_loss(labels,
                        logits,
                        session_ids=None,
                        temperature=1.0,
                        margin=1.0,
                        weights=1.0,
                        ohem_ratio=1.0,
                        label_is_logits=True,
                        use_label_margin=True,
                        use_exponent=False,
                        name=''):
  r"""Computes pairwise hinge loss between `labels` and `logits`.

  Definition:
    $$
    \mathcal{L}(\{y\}, \{s\}) =
    \sum_i \sum_j I[y_i > y_j] \max(0, m_{ij} - (s_i - s_j))
    $$
  where $m_{ij} = y_i - y_j$ when `use_label_margin` is set and `margin`
  otherwise. With `use_exponent`, labels and logits are first passed through
  a sigmoid. The hinge term then becomes $\max(0, \exp(d) - 1)$, where
  $d = m_{ij} - (s_i - s_j)$ is clipped to [-88, 88]. The per-pair losses
  are reduced as described under Returns rather than plainly summed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample; only pairs with equal ids are compared (e.g. user_id, to
      target GAUC).
    temperature: (Optional) The temperature to use for scaling the logits.
    margin: the margin between positive and negative logits. It is used only
      when `use_label_margin` is False.
    weights: A scalar, or a `Tensor` with shape [batch_size] for each sample.
      In the tensor case pair (i, j) takes the weight of sample i.
    ohem_ratio: the fraction of hard (largest-loss) pairs to be mined, in
      (0, 1].
    label_is_logits: Whether `labels` is expected to be a logits tensor. If
      so, it is scaled by `temperature` together with `logits`.
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin.
    use_exponent: whether to use exponential difference.
    name: the name of loss.

  Returns:
    A scalar loss `Tensor`. With `ohem_ratio == 1` it is the result of
    `compute_weighted_loss` with its default reduction. Otherwise it is
    computed as follows. The `round(ohem_ratio * num_pairs)` largest weighted
    pair losses are taken, and the mean of the positive ones among them is
    returned. If none remain, the result is 0 instead of NaN.

  Raises:
    AssertionError: if `ohem_ratio` is outside (0, 1].
  """
  loss_name = name if name else 'pairwise_hinge_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] margin: {}, ohem_ratio: {}, temperature: {}, use_exponent: {}, label_is_logits: {}, use_label_margin: {}'
      .format(loss_name, margin, ohem_ratio, temperature, use_exponent,
              label_is_logits, use_label_margin))

  if temperature != 1.0:
    logits /= temperature
    if label_is_logits:
      labels /= temperature
  if use_exponent:
    labels = tf.nn.sigmoid(labels)
    # Squash the model scores themselves. Deriving them from `labels` (as a
    # previous version did) makes the loss independent of `logits`, so it
    # produces no gradient for the model.
    logits = tf.nn.sigmoid(logits)

  pairwise_logits = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  pairwise_labels = tf.math.subtract(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))

  pairwise_mask = tf.greater(pairwise_labels, 0)
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    group_equal = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)

  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
  pairwise_labels = tf.boolean_mask(pairwise_labels, pairwise_mask)
  num_pair = tf.size(pairwise_logits)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)

  if use_label_margin:
    diff = pairwise_labels - pairwise_logits
  else:
    diff = margin - pairwise_logits
  if use_exponent:
    # exp(88) is close to the float32 maximum (3.4028235e+38); clip to avoid inf.
    threshold = 88.0
    safe_diff = tf.clip_by_value(diff, -threshold, threshold)
    losses = tf.nn.relu(tf.exp(safe_diff) - 1.0)
  else:
    losses = tf.nn.relu(diff)

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    # Row i of the tiled matrix carries sample i's weight for every j.
    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
  else:
    pairwise_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(losses, pairwise_weights)

  # Online hard example mining: keep only the k largest weighted pair losses.
  losses = compute_weighted_loss(
      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
  k = tf.to_int32(tf.math.rint(k))
  topk = tf.nn.top_k(losses, k)
  losses = tf.boolean_mask(topk.values, topk.values > 0)
  # Hinge losses are exactly 0 for well-separated pairs, so `losses` can be
  # empty here. Empty batches and k == 0 also lead here. `tf.reduce_mean`
  # would then return NaN. Clamping the count to 1 makes such a batch
  # contribute 0, and it is identical to the mean otherwise.
  num_kept = tf.maximum(tf.to_float(tf.size(losses)), 1.0)
  return tf.reduce_sum(losses) / num_kept
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
if tf.__version__ >= '2.0':
|
|
6
|
+
tf = tf.compat.v1
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def support_vector_guided_softmax_loss(pos_score,
                                       neg_scores,
                                       margin=0,
                                       t=1,
                                       smooth=1.0,
                                       threshold=0,
                                       weights=1.0):
  """Support vector guided softmax loss.

  Refer paper: Support Vector Guided Softmax Loss for Face Recognition
  (https://arxiv.org/abs/1812.11317).

  First, `margin` is subtracted from the positive score. A negative is treated
  as "hard" when (shifted positive - negative) < `threshold`. Hard negatives
  are re-scored as `t * s + t - 1`, and all other negatives keep their score.
  The positive and negatives are concatenated along axis 1 (positive first)
  and optionally multiplied by `smooth`. A sparse softmax cross entropy with
  class 0 as the target is then computed.

  Args:
    pos_score: positive score `Tensor`. It is concatenated with `neg_scores`
      on axis 1, so presumably shaped [batch_size, 1] (TODO confirm).
    neg_scores: negative score `Tensor`, presumably [batch_size, num_neg].
    margin: value subtracted from `pos_score`.
    t: re-scoring coefficient applied to hard negatives.
    smooth: multiplier applied to all logits when it is not 1.0.
    threshold: separation below which a negative counts as hard.
    weights: forwarded to `tf.losses.sparse_softmax_cross_entropy`.

  Returns:
    The loss from `tf.losses.sparse_softmax_cross_entropy`, with NaN replaced
    by 0.
  """
  shifted_pos = pos_score - margin
  is_easy = tf.greater_equal(shifted_pos - neg_scores, threshold)
  # I_k from the paper: 1.0 marks a hard (support vector) negative.
  hard = tf.where(is_easy, tf.zeros_like(is_easy, tf.float32),
                  tf.ones_like(is_easy, tf.float32))
  reweighted_neg = hard * (neg_scores * t + t - 1) + (1 - hard) * neg_scores
  logits = tf.concat([shifted_pos, reweighted_neg], axis=1)
  if smooth != 1.0:
    logits *= smooth

  # The positive always sits in column 0, so every target class id is 0.
  target = tf.zeros_like(pos_score, dtype=tf.int32)
  loss = tf.losses.sparse_softmax_cross_entropy(
      target, logits, weights=weights)
  # A NaN loss (e.g. a batch without positive samples) is reported as 0.
  return tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def softmax_loss_with_negative_mining(user_emb,
                                      item_emb,
                                      labels,
                                      num_negative_samples=4,
                                      embed_normed=False,
                                      weights=1.0,
                                      gamma=1.0,
                                      margin=0,
                                      t=1,
                                      seed=None):
  """Compute the softmax loss based on the cosine similarity explained below.

  Given mini batches for `user_emb` and `item_emb`, this function computes,
  for each element in `user_emb`, the cosine similarity between it and the
  corresponding `item_emb`. It also computes the cosine similarity between
  `user_emb` and some other elements of `item_emb` (referred to as negative
  samples). The negative samples are formed on the fly by rolling `item_emb`
  along the batch axis by a random shift in [1, batch_size). The support
  vector guided softmax loss is then computed on these similarities, using
  only the rows whose label is positive.

  Args:
    user_emb: A `Tensor` with shape [batch_size, embedding_size]. The
      embedding of user.
    item_emb: A `Tensor` with shape [batch_size, embedding_size]. The
      embedding of item.
    labels: a `Tensor` with shape [batch_size], e.g. click or not click in the
      session. Its values must be 0 or 1; only rows with labels > 0 are used.
    num_negative_samples: the num of negative samples, should be in range
      [1, batch_size).
    embed_normed: bool, whether the input embeddings are already l2
      normalized. If False they are normalized here.
    weights: `weights` acts as a coefficient for the loss. If a scalar is
      provided, then the loss is simply scaled by the given value. If
      `weights` is a tensor of shape `[batch_size]`, then the loss weights
      apply to each corresponding sample.
    gamma: smooth coefficient of softmax (scales all logits).
    margin: the margin between positive pair and negative pair.
    t: coefficient of support vector guided softmax loss.
    seed: A Python integer. Used to create the random seeds for the shifts.
      Negative sample `i` uses op-level seed `seed + i`, so the negatives are
      not all identical when a seed is given. See `tf.set_random_seed` for
      behavior.

  Returns:
    support vector guided softmax loss of positive labels.
  """
  assert 0 < num_negative_samples, '`num_negative_samples` should be greater than 0'

  batch_size = tf.shape(item_emb)[0]
  is_valid = tf.assert_less(
      num_negative_samples,
      batch_size,
      message='`num_negative_samples` should be less than batch_size')
  with tf.control_dependencies([is_valid]):
    if not embed_normed:
      user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
      item_emb = tf.nn.l2_normalize(item_emb, axis=-1)

    vectors = [item_emb]
    for i in range(num_negative_samples):
      # In graph mode, random ops that share the same op-level seed produce
      # the same values. Reusing `seed` for every shift would make all
      # negatives identical, so give each shift op its own seed.
      op_seed = None if seed is None else seed + i
      shift = tf.random_uniform([], 1, batch_size, dtype=tf.int32, seed=op_seed)
      vectors.append(tf.roll(item_emb, shift, axis=0))
    # all_embeddings's shape: (batch_size, num_negative_samples + 1, vec_dim)
    all_embeddings = tf.stack(vectors, axis=1)

    mask = tf.greater(labels, 0)
    mask_user_emb = tf.boolean_mask(user_emb, mask)
    mask_item_emb = tf.boolean_mask(all_embeddings, mask)
    # Per-sample weights must be filtered like the embeddings. A scalar
    # (rank-0) tensor cannot be boolean-masked and already applies to every
    # row, so it is passed through unchanged.
    if isinstance(weights, tf.Tensor) and weights.shape.ndims != 0:
      weights = tf.boolean_mask(weights, mask)

    # sim_scores's shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
    sim_scores = tf.keras.backend.batch_dot(
        mask_user_emb, mask_item_emb, axes=(1, 2))
    # Column 0 is the true item; the remaining columns are the negatives.
    pos_score = tf.slice(sim_scores, [0, 0], [-1, 1])
    neg_scores = tf.slice(sim_scores, [0, 1], [-1, -1])

    loss = support_vector_guided_softmax_loss(
        pos_score,
        neg_scores,
        margin=margin,
        t=t,
        smooth=gamma,
        weights=weights)
  return loss
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Zero-inflated lognormal loss for lifetime value prediction."""
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
import tensorflow_probability as tfp
|
|
6
|
+
|
|
7
|
+
tfd = tfp.distributions
|
|
8
|
+
|
|
9
|
+
if tf.__version__ >= '2.0':
|
|
10
|
+
tf = tf.compat.v1
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def zero_inflated_lognormal_pred(logits):
  """Calculates predicted mean of zero inflated lognormal logits.

  The last axis of `logits` is read as three channels:
    * `logits[..., :1]`: logit of the probability that the target is positive;
    * `logits[..., 1:2]`: `loc` of the lognormal;
    * `logits[..., 2:]`: pre-softplus `scale` of the lognormal.

  Arguments:
    logits: [batch_size, 3] tensor of logits.

  Returns:
    positive_probs: [batch_size, 1] tensor of positive probability.
    preds: [batch_size, 1] tensor of predicted mean, computed as
      positive_probs * exp(loc + scale ** 2 / 2).
  """
  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  prob_logit = logits[..., :1]
  loc = logits[..., 1:2]
  raw_scale = logits[..., 2:]

  positive_probs = tf.keras.backend.sigmoid(prob_logit)
  scale = tf.keras.backend.softplus(raw_scale)
  # Mean of LogNormal(loc, scale) is exp(loc + scale^2 / 2).
  half_variance = 0.5 * tf.keras.backend.square(scale)
  lognormal_mean = tf.keras.backend.exp(loc + half_variance)
  return positive_probs, positive_probs * lognormal_mean
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def zero_inflated_lognormal_loss(labels, logits, name=''):
  """Computes the zero inflated lognormal loss.

  The returned value is the sum of two terms. Each is also written to a
  scalar summary:
    * `loss/<name>_classify`: binary cross entropy between `labels > 0` and
      the first logit channel (treated as a logit), averaged with
      `tf.keras.backend.mean`;
    * `loss/<name>_regression`: negative LogNormal log-likelihood of the
      labels, with `loc = logits[..., 1:2]` and
      `scale = max(softplus(logits[..., 2:]), sqrt(epsilon))`, averaged over
      all rows. Rows with a non-positive label contribute zero.

  Usage with tf.keras API:

  ```python
  model = tf.keras.Model(inputs, outputs)
  model.compile('sgd', loss=zero_inflated_lognormal_loss)
  ```

  Arguments:
    labels: True targets, tensor of shape [batch_size] or [batch_size, 1].
      Rank-1 labels are expanded to [batch_size, 1].
    logits: Logits of output layer, tensor of shape [batch_size, 3].
    name: the name of loss, used in the summary names; 'ziln_loss' when empty.

  Returns:
    Zero inflated lognormal loss value.
  """
  loss_name = name if name else 'ziln_loss'
  labels = tf.cast(labels, dtype=tf.float32)
  if labels.shape.ndims == 1:
    labels = tf.expand_dims(labels, 1)  # [B, 1]
  positive = tf.cast(labels > 0, tf.float32)

  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  # `TensorShape.as_list()` raises on a shape of unknown rank, so the static
  # compatibility check only runs when the label rank is known.
  if labels.shape.ndims is not None:
    logits.shape.assert_is_compatible_with(
        tf.TensorShape(labels.shape[:-1].as_list() + [3]))

  positive_logits = logits[..., :1]
  classification_loss = tf.keras.backend.binary_crossentropy(
      positive, positive_logits, from_logits=True)
  classification_loss = tf.keras.backend.mean(classification_loss)
  tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

  loc = logits[..., 1:2]
  # Keep the scale strictly positive (at least sqrt(epsilon)).
  scale = tf.math.maximum(
      tf.keras.backend.softplus(logits[..., 2:]),
      tf.math.sqrt(tf.keras.backend.epsilon()))
  # Non-positive labels are replaced by 1 so log_prob is evaluated at a point
  # inside the LogNormal support. Those rows are multiplied by `positive`
  # (= 0) below, so they do not contribute.
  safe_labels = positive * labels + (
      1 - positive) * tf.keras.backend.ones_like(labels)
  regression_loss = -tf.keras.backend.mean(
      positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels))
  tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)
  return classification_loss + regression_loss
|