easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/loss/circle_loss.py
@@ -0,0 +1,82 @@
+# coding=utf-8
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def circle_loss(embeddings,
+                labels,
+                sessions=None,
+                margin=0.25,
+                gamma=32,
+                embed_normed=False):
+  """Paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization.
+
+  Link: http://arxiv.org/pdf/2002.10857.pdf
+
+  Args:
+    embeddings: A `Tensor` with shape [batch_size, embedding_size]. The embedding of each sample.
+    labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+    sessions: a `Tensor` with shape [batch_size]. session ids of each sample.
+    margin: the margin between positive similarity and negative similarity
+    gamma: parameter of circle loss
+    embed_normed: bool, whether the input embeddings are l2-normalized
+  """
+  norm_embeddings = embeddings if embed_normed else tf.nn.l2_normalize(
+      embeddings, axis=-1)
+  pair_wise_cosine_matrix = tf.matmul(
+      norm_embeddings, norm_embeddings, transpose_b=True)
+
+  positive_mask = get_anchor_positive_triplet_mask(labels, sessions)
+  negative_mask = 1 - positive_mask - tf.eye(tf.shape(labels)[0])
+
+  delta_p = 1 - margin
+  delta_n = margin
+
+  ap = tf.nn.relu(-tf.stop_gradient(pair_wise_cosine_matrix * positive_mask) +
+                  1 + margin)
+  an = tf.nn.relu(
+      tf.stop_gradient(pair_wise_cosine_matrix * negative_mask) + margin)
+
+  logit_p = -ap * (pair_wise_cosine_matrix -
+                   delta_p) * gamma * positive_mask - (1 - positive_mask) * 1e12
+  logit_n = an * (pair_wise_cosine_matrix -
+                  delta_n) * gamma * negative_mask - (1 - negative_mask) * 1e12
+
+  joint_neg_loss = tf.reduce_logsumexp(logit_n, axis=-1)
+  joint_pos_loss = tf.reduce_logsumexp(logit_p, axis=-1)
+  loss = tf.nn.softplus(joint_neg_loss + joint_pos_loss)
+  return tf.reduce_mean(loss)
+
+
+def get_anchor_positive_triplet_mask(labels, sessions=None):
+  """Return a 2D mask where mask[a, p] is 1.0 iff a and p are distinct and have the same session and label.
+
+  Args:
+    labels: a `Tensor` with shape [batch_size]
+    sessions: a `Tensor` with shape [batch_size]
+
+  Returns:
+    mask: tf.float32 `Tensor` with shape [batch_size, batch_size]
+  """
+  # Check that i and j are distinct
+  indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
+  indices_not_equal = tf.logical_not(indices_equal)
+
+  # Check if labels[i] == labels[j]
+  # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
+  labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
+
+  # Check if sessions[i] == sessions[j]
+  if sessions is None or sessions is labels:
+    class_equal = labels_equal
+  else:
+    sessions_equal = tf.equal(
+        tf.expand_dims(sessions, 0), tf.expand_dims(sessions, 1))
+    class_equal = tf.logical_and(sessions_equal, labels_equal)
+
+  # Combine the three masks
+  mask = tf.logical_and(indices_not_equal, class_equal)
+  return tf.cast(mask, tf.float32)
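For reference, circle_loss can be exercised on a toy batch as below. This is a minimal sketch, not part of the diff: it assumes a TF1-style graph session (or tf.compat.v1 with eager execution disabled), and the batch values are illustrative only; the import path comes from the RECORD listing above.

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.circle_loss import circle_loss

# Toy batch: 8 samples, 16-dim embeddings, binary labels (e.g. click / no-click).
embeddings = tf.constant(np.random.randn(8, 16), dtype=tf.float32)
labels = tf.constant([1, 0, 1, 0, 1, 1, 0, 0])
# margin/gamma mirror the function defaults shown above.
loss = circle_loss(embeddings, labels, margin=0.25, gamma=32)

with tf.Session() as sess:
  print(sess.run(loss))  # scalar loss value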
easy_rec/python/loss/contrastive_loss.py
@@ -0,0 +1,79 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def l2_loss(x1, x2):
+  """Compute the euclidean distance of two embeddings."""
+  distance = tf.reduce_sum(tf.square(x1 - x2), axis=-1)
+  return tf.reduce_mean(distance)
+
+
+def info_nce_loss(query, positive, temperature=0.1):
+  """Calculates the InfoNCE loss for self-supervised learning.
+
+  This contrastive loss enforces the embeddings of similar (positive) samples to be close
+  and those of different (negative) samples to be distant.
+  A query embedding is compared with one positive key and with one or more negative keys.
+
+  References:
+    https://arxiv.org/abs/1807.03748v2
+    https://arxiv.org/abs/2010.05113
+  """
+  # Check input dimensionality.
+  if query.shape.ndims != 2:
+    raise ValueError('<query> must have 2 dimensions.')
+  if positive.shape.ndims != 2:
+    raise ValueError('<positive> must have 2 dimensions.')
+  # Embedding vectors should have the same number of components.
+  if query.shape[-1] != positive.shape[-1]:
+    raise ValueError(
+        'Vectors of <query> and <positive> should have the same number of components.'
+    )
+
+  # Negative keys are implicitly off-diagonal positive keys.
+
+  # Cosine between all combinations
+  logits = tf.matmul(query, positive, transpose_b=True)
+  logits /= temperature
+
+  # Positive keys are the entries on the diagonal
+  batch_size = tf.shape(query)[0]
+  labels = tf.range(batch_size)
+
+  return tf.losses.sparse_softmax_cross_entropy(labels, logits)
+
+
+def get_mask_matrix(batch_size):
+  """Bool mask of shape [2B, 2B] that is True exactly for negative pairs."""
+  mat = tf.ones((batch_size, batch_size), dtype=tf.bool)
+  diag = tf.zeros([batch_size], dtype=tf.bool)
+  mask = tf.linalg.set_diag(mat, diag)
+  mask = tf.tile(mask, [2, 2])
+  return mask
+
+
+def nce_loss(z_i, z_j, temperature=1.0):
+  """Contrastive nce loss for homogeneous embeddings.
+
+  Refer paper: Contrastive Learning for Sequential Recommendation
+  """
+  batch_size = tf.shape(z_i)[0]
+  N = 2 * batch_size
+  z = tf.concat((z_i, z_j), axis=0)
+  sim = tf.matmul(z, tf.transpose(z)) / temperature
+  sim_i_j = tf.matrix_diag_part(
+      tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
+  sim_j_i = tf.matrix_diag_part(
+      tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
+  positive_samples = tf.reshape(tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
+  mask = get_mask_matrix(batch_size)
+  negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))
+
+  labels = tf.zeros(N, dtype=tf.int32)
+  logits = tf.concat((positive_samples, negative_samples), axis=1)
+
+  loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+  return loss
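A minimal usage sketch for the two contrastive losses above (not part of the diff; TF1-style session assumed, toy values illustrative). info_nce_loss treats the diagonal of the query/positive similarity matrix as positives; nce_loss is the NT-Xent-style variant over two homogeneous views of the same batch.

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.contrastive_loss import info_nce_loss, nce_loss

# Two l2-normalized views of the same 8 items (e.g. two feature augmentations).
z_i = tf.nn.l2_normalize(
    tf.constant(np.random.randn(8, 32), dtype=tf.float32), axis=-1)
z_j = tf.nn.l2_normalize(
    tf.constant(np.random.randn(8, 32), dtype=tf.float32), axis=-1)

loss_a = info_nce_loss(z_i, z_j, temperature=0.1)
loss_b = nce_loss(z_i, z_j, temperature=1.0)

with tf.Session() as sess:
  print(sess.run([loss_a, loss_b]))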
easy_rec/python/loss/f1_reweight_loss.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def f1_reweight_sigmoid_cross_entropy(labels,
+                                      logits,
+                                      beta_square,
+                                      label_smoothing=0,
+                                      weights=None):
+  """Refer paper: Adaptive Scaling for Sparse Detection in Information Extraction."""
+  probs = tf.nn.sigmoid(logits)
+  if len(logits.shape.as_list()) == 1:
+    logits = tf.expand_dims(logits, -1)
+  if len(labels.shape.as_list()) == 1:
+    labels = tf.expand_dims(labels, -1)
+  labels = tf.to_float(labels)
+  batch_size = tf.shape(labels)[0]
+  batch_size_float = tf.to_float(batch_size)
+  num_pos = tf.reduce_sum(labels, axis=0)
+  num_neg = batch_size_float - num_pos
+  tp = tf.reduce_sum(probs, axis=0)
+  tn = batch_size_float - tp
+  neg_weight = tp / (beta_square * num_pos + num_neg - tn + 1e-8)
+  neg_weight_tile = tf.tile(tf.expand_dims(neg_weight, 0), [batch_size, 1])
+  final_weights = tf.where(
+      tf.equal(labels, 1.0), tf.ones_like(labels), neg_weight_tile)
+  if weights is not None:
+    weights = tf.cast(weights, tf.float32)
+    if len(weights.shape.as_list()) == 1:
+      weights = tf.expand_dims(weights, -1)
+    final_weights *= weights
+  return tf.losses.sigmoid_cross_entropy(
+      labels, logits, final_weights, label_smoothing=label_smoothing)
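A minimal sketch for the F1-reweighted loss above (not part of the diff; TF1-style session assumed, labels/logits illustrative). beta_square is the square of the beta in the targeted F-beta score, so beta_square=4.0 corresponds to F2, weighting recall over precision.

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy

labels = tf.constant([1, 0, 0, 1, 0, 0, 0, 1])
logits = tf.constant([1.2, -0.8, 0.1, 2.3, -1.5, 0.4, -0.2, 0.9])
loss = f1_reweight_sigmoid_cross_entropy(labels, logits, beta_square=4.0)

with tf.Session() as sess:
  print(sess.run(loss))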
easy_rec/python/loss/focal_loss.py
@@ -0,0 +1,93 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def sigmoid_focal_loss_with_logits(labels,
+                                   logits,
+                                   gamma=2.0,
+                                   alpha=None,
+                                   ohem_ratio=1.0,
+                                   sample_weights=None,
+                                   label_smoothing=0,
+                                   name=''):
+  """Implements the focal loss function.
+
+  Focal loss was first introduced in the RetinaNet paper
+  (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
+  classification when you have highly imbalanced classes. It down-weights
+  well-classified examples and focuses on hard examples. The loss value is
+  much higher for a sample that is misclassified by the classifier than for
+  a well-classified example. One of the best use-cases of focal loss is its
+  usage in object detection, where the imbalance between the background
+  class and other classes is extremely high.
+
+  Args:
+    labels: `[batch_size]` target integer labels in `{0, 1}`.
+    logits: Float `[batch_size]` logits outputs of the network.
+    alpha: balancing factor.
+    gamma: modulating factor.
+    ohem_ratio: the percent of hard examples to be mined
+    sample_weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `losses` dimension).
+    label_smoothing: If greater than `0` then smooth the labels.
+    name: the name of loss
+
+  Returns:
+    Weighted loss float `Tensor`, reduced to a scalar.
+
+  Raises:
+    ValueError: If the shape of `sample_weight` is invalid or value of
+      `gamma` is less than zero
+  """
+  loss_name = name if name else 'focal_loss'
+  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+  if gamma and gamma < 0:
+    raise ValueError('Value of gamma should be greater than or equal to zero')
+  logging.info(
+      '[{}] gamma: {}, alpha: {}, ohem_ratio: {}, label smoothing: {}'.format(
+          loss_name, gamma, alpha, ohem_ratio, label_smoothing))
+
+  y_true = tf.cast(labels, logits.dtype)
+
+  # convert the predictions into probabilities
+  y_pred = tf.nn.sigmoid(logits)
+  epsilon = 1e-7
+  y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
+  p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
+  weights = tf.pow((1 - p_t), gamma)
+
+  if alpha is not None:
+    alpha_factor = y_true * alpha + ((1 - alpha) * (1 - y_true))
+    weights *= alpha_factor
+
+  if sample_weights is not None:
+    if tf.is_numeric_tensor(sample_weights):
+      logging.info('[%s] use sample weight' % loss_name)
+      weights *= tf.cast(sample_weights, tf.float32)
+    elif sample_weights != 1.0:
+      logging.info('[%s] use sample weight: %f' % (loss_name, sample_weights))
+      weights *= sample_weights
+
+  if ohem_ratio == 1.0:
+    return tf.losses.sigmoid_cross_entropy(
+        y_true, logits, weights=weights, label_smoothing=label_smoothing)
+
+  # Online hard example mining: keep only the top ohem_ratio fraction of
+  # per-sample losses.
+  losses = tf.losses.sigmoid_cross_entropy(
+      y_true,
+      logits,
+      weights=weights,
+      label_smoothing=label_smoothing,
+      reduction=tf.losses.Reduction.NONE)
+  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+  k = tf.to_int32(tf.math.rint(k))
+  topk = tf.nn.top_k(losses, k)
+  losses = tf.boolean_mask(topk.values, topk.values > 0)
+  return tf.reduce_mean(losses)
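A minimal sketch for the focal loss above (not part of the diff; TF1-style session assumed, toy batch illustrative). gamma=2.0 and alpha=0.25 are the values recommended in the RetinaNet paper; setting ohem_ratio below 1.0 additionally keeps only the hardest fraction of per-sample losses.

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits

labels = tf.constant([1, 0, 0, 0, 0, 0, 0, 1])  # imbalanced toy batch
logits = tf.constant([0.3, -2.1, 1.7, -0.9, -3.0, 0.2, -1.1, 2.4])
loss = sigmoid_focal_loss_with_logits(
    labels, logits, gamma=2.0, alpha=0.25, ohem_ratio=0.5)

with tf.Session() as sess:
  print(sess.run(loss))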
easy_rec/python/loss/jrc_loss.py
@@ -0,0 +1,128 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import numpy as np
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def jrc_loss(labels,
+             logits,
+             session_ids,
+             alpha=0.5,
+             loss_weight_strategy='fixed',
+             sample_weights=1.0,
+             same_label_loss=True,
+             name=''):
+  """Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model.
+
+  https://arxiv.org/abs/2208.06164
+
+  Args:
+    labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+    logits: a `Tensor` with shape [batch_size, 2]. e.g. the value of last neuron before activation.
+    session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to maximize the GAUC metric. e.g. user_id
+    alpha: the weight to balance ranking loss and calibration loss
+    loss_weight_strategy: str, the loss weight strategy for balancing ce_loss and ge_loss
+    sample_weights: Coefficients for the loss. This must be scalar or broadcastable to
+      `labels` (i.e. same rank and each dimension is either 1 or the same).
+    same_label_loss: enable ge_loss for samples with the same label in a session or not.
+    name: the name of loss
+  """
+  loss_name = name if name else 'jrc_loss'
+  logging.info('[{}] alpha: {}, loss_weight_strategy: {}'.format(
+      loss_name, alpha, loss_weight_strategy))
+
+  ce_loss = tf.losses.sparse_softmax_cross_entropy(
+      labels, logits, weights=sample_weights)
+
+  labels = tf.expand_dims(labels, 1)  # [B, 1]
+  labels = tf.concat([1 - labels, labels], axis=1)  # [B, 2]
+
+  batch_size = tf.shape(logits)[0]
+
+  # Mask: shape [B, B], mask[i, j]=1 indicates the i-th sample
+  # and j-th sample are in the same context
+  mask = tf.equal(
+      tf.expand_dims(session_ids, 1), tf.expand_dims(session_ids, 0))
+  mask = tf.to_float(mask)
+
+  # Tile logits and label: [B, 2] -> [B, B, 2]
+  logits = tf.tile(tf.expand_dims(logits, 1), [1, batch_size, 1])
+  y = tf.tile(tf.expand_dims(labels, 1), [1, batch_size, 1])
+
+  # Set logits that are not in the same context to -inf
+  mask3d = tf.expand_dims(mask, 2)
+  y = tf.to_float(y) * mask3d
+  logits = logits + (1 - mask3d) * -1e9
+  y_neg, y_pos = y[:, :, 0], y[:, :, 1]
+  l_neg, l_pos = logits[:, :, 0], logits[:, :, 1]
+
+  if tf.is_numeric_tensor(sample_weights):
+    logging.info('[%s] use sample weight' % loss_name)
+    weights = tf.expand_dims(tf.cast(sample_weights, tf.float32), 0)
+    pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1]))
+    y_pos *= pairwise_weights
+    y_neg *= pairwise_weights
+
+  # Compute list-wise generative loss -log p(x|y, z)
+  if same_label_loss:
+    logging.info('[%s] enable same_label_loss' % loss_name)
+    loss_pos = -tf.reduce_sum(y_pos * tf.nn.log_softmax(l_pos, axis=0), axis=0)
+    loss_neg = -tf.reduce_sum(y_neg * tf.nn.log_softmax(l_neg, axis=0), axis=0)
+    ge_loss = tf.reduce_mean(
+        (loss_pos + loss_neg) / tf.reduce_sum(mask, axis=0))
+  else:
+    logging.info('[%s] disable same_label_loss' % loss_name)
+    diag = tf.one_hot(tf.range(batch_size), batch_size)
+    l_pos = l_pos + (1 - diag) * y_pos * -1e9
+    l_neg = l_neg + (1 - diag) * y_neg * -1e9
+    loss_pos = -tf.linalg.diag_part(y_pos * tf.nn.log_softmax(l_pos, axis=0))
+    loss_neg = -tf.linalg.diag_part(y_neg * tf.nn.log_softmax(l_neg, axis=0))
+    ge_loss = tf.reduce_mean(loss_pos + loss_neg)
+
+  tf.summary.scalar('loss/%s_ce' % loss_name, ce_loss)
+  tf.summary.scalar('loss/%s_ge' % loss_name, ge_loss)
+
+  # The final JRC loss
+  if loss_weight_strategy == 'fixed':
+    loss = alpha * ce_loss + (1 - alpha) * ge_loss
+  elif loss_weight_strategy == 'random_uniform':
+    weight = tf.random_uniform([])
+    loss = weight * ce_loss + (1 - weight) * ge_loss
+    tf.summary.scalar('loss/%s_ce_weight' % loss_name, weight)
+    tf.summary.scalar('loss/%s_ge_weight' % loss_name, 1 - weight)
+  elif loss_weight_strategy == 'random_normal':
+    weights = tf.random_normal([2])
+    loss_weight = tf.nn.softmax(weights)
+    loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+    tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+    tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+  elif loss_weight_strategy == 'random_bernoulli':
+    bern = tf.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
+    weights = bern.sample(2)
+    loss_weight = tf.cond(
+        tf.equal(tf.reduce_sum(weights), 1), lambda: weights,
+        lambda: tf.convert_to_tensor([0.5, 0.5]))
+    loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+    tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+    tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+  elif loss_weight_strategy == 'uncertainty':
+    uncertainty1 = tf.Variable(
+        0, name='%s_ranking_loss_weight' % loss_name, dtype=tf.float32)
+    tf.summary.scalar('loss/%s_ranking_uncertainty' % loss_name, uncertainty1)
+    uncertainty2 = tf.Variable(
+        0, name='%s_calibration_loss_weight' % loss_name, dtype=tf.float32)
+    tf.summary.scalar('loss/%s_calibration_uncertainty' % loss_name,
+                      uncertainty2)
+    loss = tf.exp(-uncertainty1) * ce_loss + 0.5 * uncertainty1
+    loss += tf.exp(-uncertainty2) * ge_loss + 0.5 * uncertainty2
+  else:
+    raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
+                     loss_weight_strategy)
+  if np.isscalar(sample_weights) and sample_weights != 1.0:
+    return loss * sample_weights
+  return loss
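A minimal sketch for jrc_loss above (not part of the diff; TF1-style session assumed, toy values illustrative). Note the two-logit layout per sample ([non-click, click]) and the session grouping that scopes the listwise calibration term:

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.jrc_loss import jrc_loss

labels = tf.constant([1, 0, 0, 1, 0, 1])  # click / no-click
logits = tf.constant(np.random.randn(6, 2), dtype=tf.float32)
session_ids = tf.constant([101, 101, 101, 202, 202, 202])
loss = jrc_loss(labels, logits, session_ids, alpha=0.5)

with tf.Session() as sess:
  print(sess.run(loss))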
easy_rec/python/loss/listwise_loss.py
@@ -0,0 +1,161 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.utils.load_class import load_by_path
+
+
+def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
+  mask = tf.equal(x, session_ids)
+  logits = tf.boolean_mask(logits, mask)
+  labels = tf.boolean_mask(labels, mask)
+  y = tf.nn.softmax(labels) if label_is_logits else labels
+  y_hat = tf.nn.log_softmax(logits)
+  return -tf.reduce_sum(y * y_hat)
+
+
+def _list_prob_loss(x, labels, logits, session_ids):
+  mask = tf.equal(x, session_ids)
+  logits = tf.boolean_mask(logits, mask)
+  labels = tf.boolean_mask(labels, mask)
+  y = labels / tf.reduce_sum(labels)
+  y_hat = tf.nn.log_softmax(logits)
+  return -tf.reduce_sum(y * y_hat)
+
+
+def listwise_rank_loss(labels,
+                       logits,
+                       session_ids,
+                       transform_fn=None,
+                       temperature=1.0,
+                       label_is_logits=False,
+                       scale_logits=False,
+                       weights=1.0,
+                       name='listwise_loss'):
+  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+      -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits` representing graded
+      relevance.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to maximize the GAUC metric. e.g. user_id
+    transform_fn: an affine transformation function of labels
+    temperature: (Optional) The temperature to use for scaling the logits.
+    label_is_logits: Whether `labels` is expected to be a logits tensor.
+      By default, we consider that `labels` encodes a probability distribution.
+    scale_logits: Whether to scale the logits.
+    weights: sample weights
+    name: the name of loss
+  """
+  loss_name = name if name else 'listwise_rank_loss'
+  logging.info('[{}] temperature: {}, scale logits: {}'.format(
+      loss_name, temperature, scale_logits))
+  labels = tf.to_float(labels)
+  if scale_logits:
+    with tf.variable_scope(loss_name):
+      w = tf.get_variable(
+          'scale_w',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.ones_initializer())
+      b = tf.get_variable(
+          'scale_b',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.zeros_initializer())
+    logits = logits * tf.abs(w) + b
+  if temperature != 1.0:
+    logits /= temperature
+    if label_is_logits:
+      labels /= temperature
+  if transform_fn is not None:
+    trans_fn = load_by_path(transform_fn)
+    labels = trans_fn(labels)
+
+  sessions, _ = tf.unique(tf.squeeze(session_ids))
+  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+  losses = tf.map_fn(
+      lambda x: _list_wise_loss(x, labels, logits, session_ids,
+                                label_is_logits),
+      sessions,
+      dtype=tf.float32)
+  if tf.is_numeric_tensor(weights):
+    logging.error('[%s] use unsupported sample weight' % loss_name)
+    return tf.reduce_mean(losses)
+  else:
+    return tf.reduce_mean(losses) * weights
+
+
+def listwise_distill_loss(labels,
+                          logits,
+                          session_ids,
+                          transform_fn=None,
+                          temperature=1.0,
+                          label_clip_max_value=512,
+                          scale_logits=False,
+                          weights=1.0,
+                          name='listwise_distill_loss'):
+  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+      -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits` representing the rank position of a base model.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to maximize the GAUC metric. e.g. user_id
+    transform_fn: a transformation function of labels.
+    temperature: (Optional) The temperature to use for scaling the logits.
+    label_clip_max_value: clip the labels to this value.
+    scale_logits: Whether to scale the logits.
+    weights: sample weights
+    name: the name of loss
+  """
+  loss_name = name if name else 'listwise_rank_loss'
+  logging.info('[{}] temperature: {}'.format(loss_name, temperature))
+  labels = tf.to_float(labels)  # supposed to be positions of a teacher model
+  labels = tf.clip_by_value(labels, 1, label_clip_max_value)
+  if transform_fn is not None:
+    trans_fn = load_by_path(transform_fn)
+    labels = trans_fn(labels)
+  else:
+    labels = tf.log1p(label_clip_max_value) - tf.log(labels)
+
+  if scale_logits:
+    with tf.variable_scope(loss_name):
+      w = tf.get_variable(
+          'scale_w',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.ones_initializer())
+      b = tf.get_variable(
+          'scale_b',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.zeros_initializer())
+    logits = logits * tf.abs(w) + b
+  if temperature != 1.0:
+    logits /= temperature
+
+  sessions, _ = tf.unique(tf.squeeze(session_ids))
+  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+  losses = tf.map_fn(
+      lambda x: _list_prob_loss(x, labels, logits, session_ids),
+      sessions,
+      dtype=tf.float32)
+  if tf.is_numeric_tensor(weights):
+    logging.error('[%s] use unsupported sample weight' % loss_name)
+    return tf.reduce_mean(losses)
+  else:
+    return tf.reduce_mean(losses) * weights
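A minimal sketch for listwise_rank_loss above (not part of the diff; TF1-style session assumed, toy values illustrative). Graded relevance labels are grouped by session_ids; with label_is_logits=True the labels are softmax-normalized per session before the cross entropy:

import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.listwise_loss import listwise_rank_loss

labels = tf.constant([2.0, 1.0, 0.0, 0.0, 3.0, 1.0])  # graded relevance
logits = tf.constant([0.9, 0.2, -0.5, -0.1, 1.4, 0.3])
session_ids = tf.constant([1, 1, 1, 2, 2, 2])
loss = listwise_rank_loss(
    labels, logits, session_ids, temperature=1.0, label_is_logits=True)

with tf.Session() as sess:
  print(sess.run(loss))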
easy_rec/python/loss/multi_similarity.py
@@ -0,0 +1,68 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.loss.circle_loss import get_anchor_positive_triplet_mask
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def ms_loss(embeddings,
+            labels,
+            session_ids=None,
+            alpha=2.0,
+            beta=50.0,
+            lamb=1.0,
+            eps=0.1,
+            ms_mining=False,
+            embed_normed=False):
+  """Refer paper: Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning.
+
+  ref: http://openaccess.thecvf.com/content_CVPR_2019/papers/
+  Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf
+  """
+  # make sure embeddings are l2-normalized
+  if not embed_normed:
+    embeddings = tf.nn.l2_normalize(embeddings, axis=1)
+  labels = tf.reshape(labels, [-1, 1])
+
+  embed_shape = get_shape_list(embeddings)
+  batch_size = embed_shape[0]
+
+  mask_pos = get_anchor_positive_triplet_mask(labels, session_ids)
+  mask_neg = 1 - mask_pos - tf.eye(batch_size)
+
+  sim_mat = tf.matmul(
+      embeddings, embeddings, transpose_a=False, transpose_b=True)
+  sim_mat = tf.maximum(sim_mat, 0.0)
+
+  pos_mat = tf.multiply(sim_mat, mask_pos)
+  neg_mat = tf.multiply(sim_mat, mask_neg)
+
+  if ms_mining:
+    max_val = tf.reduce_max(neg_mat, axis=1, keepdims=True)
+    tmp_max_val = tf.reduce_max(pos_mat, axis=1, keepdims=True)
+    min_val = tf.reduce_min(
+        tf.multiply(sim_mat - tmp_max_val, mask_pos), axis=1,
+        keepdims=True) + tmp_max_val
+
+    max_val = tf.tile(max_val, [1, batch_size])
+    min_val = tf.tile(min_val, [1, batch_size])
+
+    mask_pos = tf.where(pos_mat < max_val + eps, mask_pos,
+                        tf.zeros_like(mask_pos))
+    mask_neg = tf.where(neg_mat > min_val - eps, mask_neg,
+                        tf.zeros_like(mask_neg))
+
+  pos_exp = tf.exp(-alpha * (pos_mat - lamb))
+  pos_exp = tf.where(mask_pos > 0.0, pos_exp, tf.zeros_like(pos_exp))
+
+  neg_exp = tf.exp(beta * (neg_mat - lamb))
+  neg_exp = tf.where(mask_neg > 0.0, neg_exp, tf.zeros_like(neg_exp))
+
+  pos_term = tf.log(1.0 + tf.reduce_sum(pos_exp, axis=1)) / alpha
+  neg_term = tf.log(1.0 + tf.reduce_sum(neg_exp, axis=1)) / beta
+
+  loss = tf.reduce_mean(pos_term + neg_term)
+  return loss
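A minimal sketch for ms_loss above (not part of the diff; TF1-style session assumed, hyperparameters copied from the function defaults). alpha and beta weight the positive and negative pair terms, lamb is the similarity threshold, and eps is the mining slack when ms_mining is enabled:

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  tf.disable_eager_execution()

from easy_rec.python.loss.multi_similarity import ms_loss

embeddings = tf.constant(np.random.randn(8, 16), dtype=tf.float32)
labels = tf.constant([1, 0, 1, 0, 1, 1, 0, 0])
loss = ms_loss(embeddings, labels, alpha=2.0, beta=50.0, lamb=1.0, eps=0.1)

with tf.Session() as sess:
  print(sess.run(loss))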