easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel has been flagged as potentially problematic by the registry.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/core/learning_schedules.py (new file)

@@ -0,0 +1,228 @@
# -*- encoding:utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules."""

import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, the learning rate ramps linearly from
  burnin_learning_rate up to learning_rate_base over the burn-in period,
  before transitioning to a regular exponential decay schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay the
      learning rate.
    burnin_learning_rate: initial learning rate during the burn-in period. If
      0.0 (which is the default), then the burn-in learning rate is simply
      set to learning_rate_base.
    burnin_steps: number of steps to use the burn-in learning rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether to use staircase decay.

  Returns:
    a (scalar) float tensor representing learning rate
  """
  if burnin_learning_rate == 0:
    burnin_rate = learning_rate_base
  else:
    slope = (learning_rate_base - burnin_learning_rate) / burnin_steps
    burnin_rate = slope * tf.cast(global_step,
                                  tf.float32) + burnin_learning_rate
  post_burnin_learning_rate = tf.train.exponential_decay(
      learning_rate_base,
      global_step - burnin_steps,
      learning_rate_decay_steps,
      learning_rate_decay_factor,
      staircase=staircase)
  return tf.maximum(
      tf.where(
          tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
          burnin_rate, post_burnin_learning_rate),
      min_learning_rate,
      name='learning_rate')

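As a usage sketch for the schedule above (hypothetical wiring, not part of the diff; TF1 graph mode is assumed, as throughout this file):

# Hypothetical usage sketch: the rate ramps linearly from 0.01 to 0.1 over the
# first 1000 steps, then decays exponentially, never dropping below 1e-5.
global_step = tf.train.get_or_create_global_step()
lr = exponential_decay_with_burnin(
    global_step,
    learning_rate_base=0.1,
    learning_rate_decay_steps=10000,
    learning_rate_decay_factor=0.9,
    burnin_learning_rate=0.01,
    burnin_steps=1000,
    min_learning_rate=1e-5)
opt = tf.train.MomentumOptimizer(lr, momentum=0.9)
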
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm
    Restarts. ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, then transitions to a cosine decay
  schedule.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.

  Returns:
    a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if learning_rate_base < warmup_learning_rate:
    raise ValueError('learning_rate_base must be larger than '
                     'or equal to warmup_learning_rate.')
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger than or equal to '
                     'warmup_steps.')
  learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
      np.pi *
      (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps)
      / float(total_steps - warmup_steps - hold_base_rate_steps)))
  if hold_base_rate_steps > 0:
    learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
                             learning_rate, learning_rate_base)
  if warmup_steps > 0:
    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
    warmup_rate = slope * tf.cast(global_step,
                                  tf.float32) + warmup_learning_rate
    learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                             learning_rate)
  return tf.where(
      global_step > total_steps, 0.0, learning_rate, name='learning_rate')

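A NumPy mirror of the graph ops above can sanity-check a configuration before training; this is a hypothetical sketch, not part of the package:

import numpy as np

def cosine_decay_with_warmup_np(step, base_lr, total_steps,
                                warmup_lr=0.0, warmup_steps=0, hold_steps=0):
  # same branch order as the tensor version above
  lr = 0.5 * base_lr * (1 + np.cos(
      np.pi * (step - warmup_steps - hold_steps) /
      float(total_steps - warmup_steps - hold_steps)))
  if hold_steps > 0 and step <= warmup_steps + hold_steps:
    lr = base_lr
  if warmup_steps > 0 and step < warmup_steps:
    lr = (base_lr - warmup_lr) / warmup_steps * step + warmup_lr
  return 0.0 if step > total_steps else lr

assert np.isclose(cosine_decay_with_warmup_np(0, 0.1, 1000, 0.01, 100), 0.01)
assert np.isclose(cosine_decay_with_warmup_np(100, 0.1, 1000, 0.01, 100), 0.1)
# halfway through the cosine phase the rate is half the base rate
assert np.isclose(cosine_decay_with_warmup_np(550, 0.1, 1000, 0.01, 100), 0.05)
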
def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates. This list is assumed to consist of increasing positive integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries. The length of this list must be exactly
      len(boundaries) + 1.
    warmup: Whether to linearly interpolate learning rate for steps in
      [0, boundaries[0]].

  Returns:
    a (scalar) float tensor representing learning rate
  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  if any([b < 0 for b in boundaries]) or any(
      [not isinstance(b, int) for b in boundaries]):
    raise ValueError('boundaries must be a list of positive integers')
  if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any([not isinstance(r, float) for r in rates]):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')

  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')

  if warmup and boundaries:
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)
  rate_index = tf.reduce_max(
      tf.where(
          tf.greater_equal(global_step, boundaries),
          list(range(num_boundaries)), [0] * num_boundaries))
  return tf.reduce_sum(
      rates * tf.one_hot(rate_index, depth=num_boundaries),
      name='learning_rate')

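A worked example of the warmup branch (the values follow directly from the code above; the call is a hypothetical sketch):

# With boundaries=[4], rates=[0.0, 0.1] and warmup=True, slope = 0.025, so the
# effective schedule becomes boundaries=[0, 1, 2, 3, 4] with
# rates=[0.0, 0.025, 0.05, 0.075, 0.1]: the rate climbs one step at a time and
# stays at 0.1 from global_step=4 onward.
lr = manual_stepping(global_step, boundaries=[4], rates=[0.0, 0.1], warmup=True)
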
def transformer_policy(global_step,
                       learning_rate,
                       d_model,
                       warmup_steps,
                       step_scaling_rate=1.0,
                       max_lr=None,
                       coefficient=1.0,
                       dtype=tf.float32):
  """Transformer's learning rate schedule.

  Transformer's learning rate policy from
  https://arxiv.org/pdf/1706.03762.pdf
  with a hat (max_lr) (also called the "noam" learning rate decay scheme).

  Args:
    global_step: global step TensorFlow tensor.
    learning_rate (float): initial learning rate to use.
    d_model (int): model dimensionality.
    warmup_steps (int): number of warm-up steps.
    step_scaling_rate (float): scaling rate applied to both the step number
      and warmup_steps.
    max_lr (float): maximal learning rate, i.e. hat.
    coefficient (float): optimizer adjustment.
      Recommended 0.002 if using "Adam", else 1.0.
    dtype: dtype for this policy.

  Returns:
    learning rate at step ``global_step``.
  """
  step_num = tf.cast(global_step, dtype=dtype)
  ws = tf.cast(warmup_steps, dtype=dtype)
  step_num *= step_scaling_rate
  ws *= step_scaling_rate

  decay = coefficient * d_model**-0.5 * tf.minimum((step_num + 1) * ws**-1.5,
                                                   (step_num + 1)**-0.5)

  new_lr = decay * learning_rate
  if max_lr is not None:
    return tf.minimum(max_lr, new_lr)
  return new_lr

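A small check of the noam shape (hypothetical numbers, derived from the formula above):

# The two arms of tf.minimum cross at step_num + 1 == ws, so the peak decay is
# coefficient * (d_model * warmup_steps) ** -0.5. For example:
d_model, warmup_steps, coefficient = 512, 4000, 2.0
peak_decay = coefficient * (d_model * warmup_steps)**-0.5  # ~0.0014
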
easy_rec/python/core/metrics.py (new file)

@@ -0,0 +1,402 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import os
from collections import defaultdict

import numpy as np
import tensorflow as tf
from sklearn import metrics as sklearn_metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope

from easy_rec.python.utils.estimator_utils import get_task_index_and_num
from easy_rec.python.utils.io_util import read_data_from_json_path
from easy_rec.python.utils.io_util import save_data_to_json_path
from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

def max_f1(label, predictions):
  """Calculate the largest F1 metric under different thresholds.

  Args:
    label: Ground truth (correct) target values.
    predictions: Estimated targets as returned by a model.
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  num_thresholds = 200
  kepsilon = 1e-7
  thresholds = [
      (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
  ]
  thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

  f1_scores = []
  precision_update_ops = []
  recall_update_ops = []
  for threshold in thresholds:
    pred = (predictions > threshold)
    precision, precision_update_op = metrics_tf.precision(
        labels=label, predictions=pred, name='precision_%s' % threshold)
    recall, recall_update_op = metrics_tf.recall(
        labels=label, predictions=pred, name='recall_%s' % threshold)
    f1_score = (2 * precision * recall) / (precision + recall + 1e-12)
    precision_update_ops.append(precision_update_op)
    recall_update_ops.append(recall_update_op)
    f1_scores.append(f1_score)

  f1 = tf.math.reduce_max(tf.stack(f1_scores))
  f1_update_op = tf.group(precision_update_ops + recall_update_ops)
  return f1, f1_update_op

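Like the tf.metrics ops it wraps, max_f1 returns a (value_op, update_op) pair; a hypothetical sketch of wiring it into an Estimator (`labels`, `probs` and `loss` are assumed tensors, not from this file):

# Hypothetical model_fn fragment.
eval_metric_ops = {'max_f1': max_f1(labels, probs)}
spec = tf.estimator.EstimatorSpec(
    mode=tf.estimator.ModeKeys.EVAL, loss=loss,
    eval_metric_ops=eval_metric_ops)
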
def _separated_auc_impl(labels, predictions, keys, reduction='mean'):
  """Computes the AUC grouped by the key separately.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    keys: keys to group by, an int or string `Tensor` whose shape matches
      `predictions`.
    reduction: reduction method for the AUCs of different keys
      * "mean": simple mean of different keys
      * "mean_by_sample_num": weighted mean with sample num of different keys
      * "mean_by_positive_num": weighted mean with positive sample num of
        different keys
  """
  assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
      'reduction method must be one of mean | mean_by_sample_num | mean_by_positive_num'
  separated_label = defaultdict(list)
  separated_prediction = defaultdict(list)
  separated_weights = defaultdict(int)

  def update_pyfunc(labels, predictions, keys):
    for label, prediction, key in zip(labels, predictions, keys):
      separated_label[key].append(label)
      separated_prediction[key].append(prediction)
      if reduction == 'mean':
        separated_weights[key] = 1
      elif reduction == 'mean_by_sample_num':
        separated_weights[key] += 1
      elif reduction == 'mean_by_positive_num':
        separated_weights[key] += label

  def value_pyfunc():
    metrics = []
    weights = []
    for key in separated_label.keys():
      per_label = np.asarray(separated_label[key]).reshape([-1])
      per_prediction = np.asarray(separated_prediction[key]).reshape([-1])
      # AUC is undefined for keys whose labels are all positive or all
      # negative, so such keys are skipped.
      if np.all(per_label == 1) or np.all(per_label == 0):
        continue
      metric = sklearn_metrics.roc_auc_score(per_label, per_prediction)
      metrics.append(metric)
      weights.append(separated_weights[key])
    if len(metrics) > 0:
      return np.average(metrics, weights=weights).astype(np.float32)
    else:
      return np.float32(0.0)

  update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
  value_op = tf.py_func(value_pyfunc, [], tf.float32)
  return value_op, update_op

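The three reductions differ only in the weights passed to np.average; a hypothetical numeric check:

import numpy as np

# Two keys with (auc, sample_num, positive_num) = (0.8, 10, 3) and (0.6, 30, 6):
aucs, samples, positives = [0.8, 0.6], [10, 30], [3, 6]
print(np.average(aucs))                     # 'mean'                 -> 0.70
print(np.average(aucs, weights=samples))    # 'mean_by_sample_num'   -> 0.65
print(np.average(aucs, weights=positives))  # 'mean_by_positive_num' -> ~0.667
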
def fast_auc(labels, predictions, name, num_thresholds=1e5):
  """Computes a histogram-based approximation of AUC.

  Predictions are bucketed into `num_thresholds` bins and per-bin positive
  and negative counts are accumulated across batches; the final value is the
  fraction of (positive, negative) pairs ranked correctly, with within-bin
  ties counted as one half.
  """
  num_thresholds = int(num_thresholds)

  def value_pyfunc(pos_neg_arr, total_pos_neg):
    partial_sum_pos = 0
    auc = 0
    total_neg = total_pos_neg[0]
    total_pos = total_pos_neg[1]
    for i in range(num_thresholds + 1):
      partial_sum_pos += pos_neg_arr[1][i]
      auc += (total_pos - partial_sum_pos) * pos_neg_arr[0][i] * 2
      auc += pos_neg_arr[0][i] * pos_neg_arr[1][i]
    auc = np.double(auc) / np.double(total_pos * total_neg * 2)
    logging.info('fast_auc[%s]: total_pos=%d total_neg=%d total=%d' %
                 (name, total_pos, total_neg, total_pos + total_neg))
    return np.float32(auc)

  with variable_scope.variable_scope(name_or_scope=name), tf.name_scope(name):
    neg_pos_var = variable_scope.get_variable(
        name='neg_pos_cnt',
        shape=[2, num_thresholds + 1],
        trainable=False,
        collections=[tf.GraphKeys.METRIC_VARIABLES],
        initializer=tf.zeros_initializer(),
        dtype=tf.int64)
    total_var = variable_scope.get_variable(
        name='total_cnt',
        shape=[2],
        trainable=False,
        collections=[tf.GraphKeys.METRIC_VARIABLES],
        initializer=tf.zeros_initializer(),
        dtype=tf.int64)
    pred_bins = math_ops.cast(predictions * num_thresholds, dtype=tf.int32)
    labels = math_ops.cast(labels, dtype=tf.int32)
    labels = array_ops.reshape(labels, [-1, 1])
    pred_bins = array_ops.reshape(pred_bins, [-1, 1])
    update_op0 = state_ops.scatter_nd_add(
        neg_pos_var, tf.concat([labels, pred_bins], axis=1),
        array_ops.ones(tf.shape(labels)[0], dtype=tf.int64))
    total_pos = math_ops.reduce_sum(labels)
    total_neg = array_ops.shape(labels)[0] - total_pos
    total_add = math_ops.cast(tf.stack([total_neg, total_pos]), dtype=tf.int64)
    update_op1 = state_ops.assign_add(total_var, total_add)
    return tf.py_func(value_pyfunc, [neg_pos_var, total_var],
                      tf.float32), tf.group([update_op0, update_op1])

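value_pyfunc above is equivalent to the following vectorized NumPy form; this is a hypothetical mirror for illustration, not part of the package:

import numpy as np

def histogram_auc(neg_cnt, pos_cnt):
  # pairs where the positive falls in a strictly higher bin count fully;
  # same-bin (tied) pairs count one half
  pos_above = np.sum(pos_cnt) - np.cumsum(pos_cnt)
  pairs = np.sum(neg_cnt * (2 * pos_above + pos_cnt))
  return pairs / (2.0 * np.sum(pos_cnt) * np.sum(neg_cnt))

assert histogram_auc(np.array([1, 0]), np.array([0, 1])) == 1.0  # separable
assert histogram_auc(np.array([1]), np.array([1])) == 0.5        # all tied
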
def _distribute_separated_auc_impl(labels,
                                   predictions,
                                   keys,
                                   reduction='mean',
                                   metric_name='separated_auc'):
  """Computes the AUC grouped by the key separately, in distributed evaluation.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    keys: keys to group by, an int or string `Tensor` whose shape matches
      `predictions`.
    reduction: reduction method for the AUCs of different keys
      * "mean": simple mean of different keys
      * "mean_by_sample_num": weighted mean with sample num of different keys
      * "mean_by_positive_num": weighted mean with positive sample num of
        different keys
    metric_name: the name of the computed metric, used to name the temporary
      result files.
  """
  assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
      'reduction method must be one of mean | mean_by_sample_num | mean_by_positive_num'
  separated_label = defaultdict(list)
  separated_prediction = defaultdict(list)
  separated_weights = defaultdict(int)
  tf_config = json.loads(os.environ['TF_CONFIG'])
  cur_job_name = tf_config['task']['type']
  cur_task_index, task_num = get_task_index_and_num()
  cur_work_device = 'job_' + cur_job_name + '__' + 'task_' + str(cur_task_index)
  eval_tmp_results_dir = os.environ['eval_tmp_results_dir']
  assert tf.gfile.IsDirectory(
      eval_tmp_results_dir), 'eval_tmp_results_dir does not exist'

  def update_pyfunc(labels, predictions, keys):
    for label, prediction, key in zip(labels, predictions, keys):
      key = str(key)
      separated_label[key].append(label.item())
      separated_prediction[key].append(prediction.item())
      if reduction == 'mean':
        separated_weights[key] = 1
      elif reduction == 'mean_by_sample_num':
        separated_weights[key] += 1
      elif reduction == 'mean_by_positive_num':
        separated_weights[key] += label.item()
    # each worker dumps its partial state so that it can be merged at the end
    for name, data in zip(
        ['separated_label', 'separated_prediction', 'separated_weights'],
        [separated_label, separated_prediction, separated_weights]):
      cur_json_name = metric_name + '__' + cur_work_device + '__' + name + '.json'
      cur_json_path = os.path.join(eval_tmp_results_dir, cur_json_name)
      save_data_to_json_path(cur_json_path, data)

  def value_pyfunc():
    # merge the partial states dumped by the other workers
    for task_i in range(1, task_num):
      work_device_i = 'job_worker__task_' + str(task_i)
      for name in [
          'separated_label', 'separated_prediction', 'separated_weights'
      ]:
        json_name_i = metric_name + '__' + work_device_i + '__' + name + '.json'
        json_path_i = os.path.join(eval_tmp_results_dir, json_name_i)
        data_i = read_data_from_json_path(json_path_i)
        if name == 'separated_label':
          separated_label.update({
              key: separated_label.get(key, []) + data_i.get(key, [])
              for key in set(
                  list(separated_label.keys()) + list(data_i.keys()))
          })
        elif name == 'separated_prediction':
          separated_prediction.update({
              key: separated_prediction.get(key, []) + data_i.get(key, [])
              for key in set(
                  list(separated_prediction.keys()) + list(data_i.keys()))
          })
        elif name == 'separated_weights':
          if reduction == 'mean':
            separated_weights.update(data_i)
          else:
            separated_weights.update({
                key: separated_weights.get(key, 0) + data_i.get(key, 0)
                for key in set(
                    list(separated_weights.keys()) + list(data_i.keys()))
            })
        else:
          assert False, 'Not supported name {}'.format(name)
    metrics = []
    weights = []
    for key in separated_label.keys():
      per_label = np.asarray(separated_label[key]).reshape([-1])
      per_prediction = np.asarray(separated_prediction[key]).reshape([-1])
      if np.all(per_label == 1) or np.all(per_label == 0):
        continue
      metric = sklearn_metrics.roc_auc_score(per_label, per_prediction)
      metrics.append(metric)
      weights.append(separated_weights[key])
    if len(metrics) > 0:
      return np.average(metrics, weights=weights).astype(np.float32)
    else:
      return np.float32(0.0)

  update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
  value_op = tf.py_func(value_pyfunc, [], tf.float32)
  return value_op, update_op

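The distributed variant relies on two pieces of process environment (plus the 'distribute_eval' switch checked by gauc and session_auc below); a hypothetical setup with placeholder values:

# Both variables are read by _distribute_separated_auc_impl above.
os.environ['distribute_eval'] = 'True'
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['host0:2222', 'host1:2222']},
    'task': {'type': 'worker', 'index': 0},
})
os.environ['eval_tmp_results_dir'] = '/tmp/eval_tmp_results'  # must exist
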
def gauc(labels, predictions, uids, reduction='mean'):
  """Computes the AUC grouped by user separately.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    uids: user ids, an int or string `Tensor` whose shape matches
      `predictions`.
    reduction: reduction method for the AUCs of different users
      * "mean": simple mean of different users
      * "mean_by_sample_num": weighted mean with sample num of different users
      * "mean_by_positive_num": weighted mean with positive sample num of
        different users
  """
  if os.environ.get('distribute_eval') == 'True':
    return _distribute_separated_auc_impl(
        labels, predictions, uids, reduction, metric_name='gauc')
  return _separated_auc_impl(labels, predictions, uids, reduction)

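A usage sketch (hypothetical tensors; `features['user_id']`, `labels` and `probs` are assumptions, not from this file):

# Per-user AUC weighted by each user's sample count.
metric_dict['gauc'] = gauc(
    labels, probs, uids=features['user_id'], reduction='mean_by_sample_num')
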
def session_auc(labels, predictions, session_ids, reduction='mean'):
  """Computes the AUC grouped by session separately.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    session_ids: session ids, an int or string `Tensor` whose shape matches
      `predictions`.
    reduction: reduction method for the AUCs of different sessions
      * "mean": simple mean of different sessions
      * "mean_by_sample_num": weighted mean with sample num of different
        sessions
      * "mean_by_positive_num": weighted mean with positive sample num of
        different sessions
  """
  if os.environ.get('distribute_eval') == 'True':
    return _distribute_separated_auc_impl(
        labels, predictions, session_ids, reduction, metric_name='session_auc')
  return _separated_auc_impl(labels, predictions, session_ids, reduction)

def metric_learning_recall_at_k(k,
                                embeddings,
                                labels,
                                session_ids=None,
                                embed_normed=False):
  """Computes the recall_at_k metric for metric learning.

  Args:
    k: an int scalar, or a list/tuple/set of ints
    embeddings: the output of the last hidden layer, a tf.float32 `Tensor`
      with shape [batch_size, embedding_size]
    labels: a `Tensor` with shape [batch_size]
    session_ids: session ids, a `Tensor` with shape [batch_size]
    embed_normed: indicator of whether the input embeddings are l2_normalized
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  # make sure embeddings are l2-normalized
  if not embed_normed:
    embeddings = tf.nn.l2_normalize(embeddings, axis=1)
  embed_shape = get_shape_list(embeddings)
  batch_size = embed_shape[0]
  sim_mat = tf.matmul(embeddings, embeddings, transpose_b=True)
  # push self-similarity below any real similarity so top_k skips the diagonal
  sim_mat = sim_mat - tf.eye(batch_size) * 2.0
  indices_not_equal = tf.logical_not(tf.eye(batch_size, dtype=tf.bool))
  # Uses broadcasting where the 1st argument has shape (1, batch_size) and
  # the 2nd (batch_size, 1)
  labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
  if session_ids is not None and session_ids is not labels:
    sessions_equal = tf.equal(
        tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
    labels_equal = tf.logical_and(sessions_equal, labels_equal)
  mask = tf.logical_and(indices_not_equal, labels_equal)
  mask_pos = tf.where(
      mask, sim_mat,
      -array_ops.ones_like(sim_mat))  # shape: (batch_size, batch_size)
  if isinstance(k, int):
    _, pos_top_k_idx = tf.nn.top_k(mask_pos, k)  # shape: (batch_size, k)
    return metrics_tf.recall_at_k(
        labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=k)
  if any((isinstance(k, list), isinstance(k, tuple), isinstance(k, set))):
    metrics = {}
    for kk in k:
      if kk < 1:
        continue
      _, pos_top_k_idx = tf.nn.top_k(mask_pos, kk)
      metrics['recall@' + str(kk)] = metrics_tf.recall_at_k(
          labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=kk)
    return metrics
  else:
    raise ValueError('k should be an `int` or a list/tuple/set of ints.')

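A usage sketch (hypothetical tensors): passing a tuple for k returns a dict keyed 'recall@k'.

# Recall@1/5/10 on tower outputs that are already l2-normalized.
metric_dict.update(
    metric_learning_recall_at_k((1, 5, 10), embeddings, labels,
                                embed_normed=True))
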
def metric_learning_average_precision_at_k(k,
                                           embeddings,
                                           labels,
                                           session_ids=None,
                                           embed_normed=False):
  """Computes the average_precision_at_k metric for metric learning.

  Arguments have the same meaning as in metric_learning_recall_at_k.
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  # make sure embeddings are l2-normalized
  if not embed_normed:
    embeddings = tf.nn.l2_normalize(embeddings, axis=1)
  embed_shape = get_shape_list(embeddings)
  batch_size = embed_shape[0]
  sim_mat = tf.matmul(embeddings, embeddings, transpose_b=True)
  sim_mat = sim_mat - tf.eye(batch_size) * 2.0
  mask = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
  if session_ids is not None and session_ids is not labels:
    sessions_equal = tf.equal(
        tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
    mask = tf.logical_and(sessions_equal, mask)
  label_indices = _get_matrix_mask_indices(mask)
  if isinstance(k, int):
    return metrics_tf.average_precision_at_k(label_indices, sim_mat, k)
  if any((isinstance(k, list), isinstance(k, tuple), isinstance(k, set))):
    metrics = {}
    for kk in k:
      if kk < 1:
        continue
      metrics['MAP@' + str(kk)] = metrics_tf.average_precision_at_k(
          label_indices, sim_mat, kk)
    return metrics
  else:
    raise ValueError('k should be an `int` or a list/tuple/set of ints.')

def _get_matrix_mask_indices(matrix, num_rows=None):
  if num_rows is None:
    num_rows = get_shape_list(matrix)[0]
  indices = tf.where(matrix)
  num_indices = tf.shape(indices)[0]
  elem_per_row = tf.bincount(
      tf.cast(indices[:, 0], tf.int32), minlength=num_rows)
  max_elem_per_row = tf.reduce_max(elem_per_row)
  row_start = tf.concat([[0], tf.cumsum(elem_per_row[:-1])], axis=0)
  r = tf.range(max_elem_per_row)
  idx = tf.expand_dims(row_start, 1) + r
  idx = tf.minimum(idx, num_indices - 1)
  result = tf.gather(indices[:, 1], idx)
  # mark padding slots (beyond each row's own count) with -1
  result = tf.where(
      tf.expand_dims(elem_per_row, 1) > r, result, -array_ops.ones_like(result))
  # then fill the padding slots with the row's largest valid column index
  max_index_per_row = tf.reduce_max(result, axis=1, keepdims=True)
  max_index_per_row = tf.tile(max_index_per_row, [1, max_elem_per_row])
  result = tf.where(result >= 0, result, max_index_per_row)
  return result
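A worked example of the helper's padding behavior, derived by tracing the code above:

# For mask = [[True, False, True],
#             [False, True, False]]
# tf.where yields indices [[0, 0], [0, 2], [1, 1]] and elem_per_row = [2, 1],
# so the returned matrix is [[0, 2],
#                            [1, 1]]
# (short rows are padded with their own largest valid column index).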