easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# =============================================================================
|
|
15
|
+
# """Evaluation of Top k hitrate."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
from __future__ import division
|
|
18
|
+
from __future__ import print_function
|
|
19
|
+
|
|
20
|
+
import json
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import sys
|
|
24
|
+
|
|
25
|
+
import graphlearn as gl
|
|
26
|
+
import tensorflow as tf
|
|
27
|
+
|
|
28
|
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
|
|
29
|
+
from easy_rec.python.utils import config_util
|
|
30
|
+
from easy_rec.python.utils import io_util
|
|
31
|
+
from easy_rec.python.utils.config_util import process_multi_file_input_path
|
|
32
|
+
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
|
|
33
|
+
from easy_rec.python.utils.hit_rate_utils import load_graph
|
|
34
|
+
from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
|
|
35
|
+
from easy_rec.python.utils.hive_utils import HiveUtils
|
|
36
|
+
|
|
37
|
+
# On TensorFlow 2.x, rebind `tf` to the v1 compatibility module so the
# tf.app.flags / tf.gfile / tf.train.Server APIs used below keep working.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
39
|
+
|
|
40
|
+
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
|
|
41
|
+
|
|
42
|
+
# Root logger configuration for this script: INFO level, with level, time
# and source location prefixed to every message.
_LOG_FORMAT = ('[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : '
               '%(message)s')
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
|
|
45
|
+
|
|
46
|
+
# Command-line flags for the hitrate evaluation job.
# Integer flags use int defaults (not numeric strings) so that every
# DEFINE_integer in this file declares its default consistently.
tf.app.flags.DEFINE_string('item_emb_table', '', 'item embedding table name')
tf.app.flags.DEFINE_string('gt_table', '', 'ground truth table name')
tf.app.flags.DEFINE_string('hitrate_details_result', '',
                           'hitrate detail file path')
tf.app.flags.DEFINE_string('total_hitrate_result', '',
                           'total hitrate result file path')

tf.app.flags.DEFINE_string('pipeline_config_path', '', 'pipeline config path')
tf.app.flags.DEFINE_integer('batch_size', 512, 'batch size')
tf.app.flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
tf.app.flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
tf.app.flags.DEFINE_integer('top_k', 5, 'top_k hitrate.')
tf.app.flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
tf.app.flags.DEFINE_bool('knn_strict', False, 'use exact search.')
tf.app.flags.DEFINE_integer('timeout', 60, 'timeout')
tf.app.flags.DEFINE_integer('num_interests', 1, 'max number of interests')
tf.app.flags.DEFINE_string('gt_table_field_sep', '\t', 'gt_table_field_sep')
tf.app.flags.DEFINE_string('item_emb_table_field_sep', '\t',
                           'item_emb_table_field_sep')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')

FLAGS = tf.app.flags.FLAGS
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def compute_hitrate(g, gt_all, hitrate_writer, gt_table=None):
  """Compute hitrate of each worker.

  Args:
    g: a GL Graph instance.
    gt_all: iterable of ground-truth batches (e.g. the generator returned by
      `gt_hdfs`); each batch is converted to a list and passed to
      `compute_hitrate_batch`.
    hitrate_writer: writable file-like object that receives one
      tab-separated detail row per ground-truth record.
    gt_table: ground truth table name. Not used by this function; kept for
      backward compatibility with existing callers.

  Returns:
    total_hits: total hits of this worker.
    total_gt_count: total count of ground truth items of this worker.
  """
  total_hits = 0.0
  total_gt_count = 0.0

  for gt_record in gt_all:
    gt_record = list(gt_record)
    hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
        compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k)
    total_hits += hits
    total_gt_count += gt_count

    src_ids = [str(ids) for ids in src_ids]
    hitrates = [str(hitrate) for hitrate in hitrates]
    topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
    # Distances are nested one level deeper than recall ids: the innermost
    # groups are joined with '|', and the groups with ','.
    topk_dists = [
        ','.join('|'.join(str(x) for x in dist) for dist in dists)
        for dists in recall_distances
    ]
    bad_cases = [','.join(str(x) for x in bad_case) for bad_case in bad_cases]
    bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]

    rows = zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases,
               bad_dists)
    # Terminate every row (including the last row of the batch) with '\n' so
    # rows written by consecutive batches are not concatenated onto one line.
    hitrate_writer.write(''.join('\t'.join(row) + '\n' for row in rows))
  print('total_hits: ', total_hits)
  print('total_gt_count: ', total_gt_count)
  return total_hits, total_gt_count
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def gt_hdfs(gt_table, batch_size, gt_file_sep):
  """Read ground-truth records from files and yield them in batches.

  Args:
    gt_table: a comma-separated list of paths / glob patterns, a pattern
      containing '*', a directory (every entry directly under it is read),
      or a single path.
    batch_size: number of records per yielded batch (the last batch may be
      smaller).
    gt_file_sep: field separator used inside each line.

  Yields:
    Lists of tuples, one tuple per non-blank line. Field 0 (id) and field 3
    (emb_num) are converted to int; the other fields stay strings.
  """
  if '*' in gt_table or ',' in gt_table:
    file_paths = tf.gfile.Glob(gt_table.split(','))
  elif tf.gfile.IsDirectory(gt_table):
    file_paths = tf.gfile.Glob(os.path.join(gt_table, '*'))
  else:
    file_paths = tf.gfile.Glob(gt_table)

  batch_list = []
  for file_path in file_paths:
    with tf.gfile.GFile(file_path, 'r') as fin:
      for line in fin:
        line = line.strip()
        if not line:
          # A blank line (e.g. a trailing empty line) would split into [''],
          # and indexing field 3 would raise IndexError; skip it.
          continue
        gt_list = line.split(gt_file_sep)
        # make id , emb_num to int
        gt_list[0], gt_list[3] = int(gt_list[0]), int(gt_list[3])
        batch_list.append(tuple(gt_list))
        if len(batch_list) >= batch_size:
          yield batch_list
          batch_list = []
  if batch_list:
    yield batch_list
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def main():
  """Run distributed top-k hitrate evaluation described by TF_CONFIG.

  Reads the cluster spec and this task's type/index from the TF_CONFIG
  environment variable. 'ps' tasks serve the TF server forever. Worker tasks
  initialize the GL graph on the item embedding table, compute hitrate on
  their share of the ground truth, and write per-record details to
  `<hitrate_details_result>/part-<task_index>`. They then reduce totals
  across workers and write the total hitrate to `total_hitrate_result` once
  the reduced worker count equals the cluster's worker count.
  """
  tf_config = json.loads(os.environ['TF_CONFIG'])
  worker_count = len(tf_config['cluster']['worker'])
  task_index = tf_config['task']['index']
  job_name = tf_config['task']['type']

  hitrate_details_result = FLAGS.hitrate_details_result
  total_hitrate_result = FLAGS.total_hitrate_result
  i_emb_table = FLAGS.item_emb_table
  gt_table = FLAGS.gt_table

  pipeline_config = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  logging.info('i_emb_table %s', i_emb_table)

  input_type = pipeline_config.data_config.input_type
  input_type_name = DatasetConfig.InputType.Name(input_type)
  if input_type_name == 'CSVInput':
    i_emb_table = process_multi_file_input_path(i_emb_table)
  else:
    hive_utils = HiveUtils(
        data_config=pipeline_config.data_config,
        hive_config=pipeline_config.hive_train_input)
    i_emb_table = hive_utils.get_table_location(i_emb_table)

  g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
                 FLAGS.knn_strict)
  gl.set_tracker_mode(0)
  gl.set_field_delimiter(FLAGS.item_emb_table_field_sep)

  cluster = tf.train.ClusterSpec({
      'ps': tf_config['cluster']['ps'],
      'worker': tf_config['cluster']['worker']
  })
  server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

  if job_name == 'ps':
    server.join()
    return

  # GL hosts reuse each worker's host with port 8880 + worker index. The
  # port is computed arithmetically (identical to '888' + str(i) for i < 10)
  # so the 11th and later workers get valid ports instead of e.g. 88810.
  worker_hosts = ','.join(
      '%s:%d' % (host.split(':')[0], 8880 + i)
      for i, host in enumerate(tf_config['cluster']['worker']))
  g.init(task_index=task_index, task_count=worker_count, hosts=worker_hosts)
  # Your model, use g to do some operation, such as sampling

  if input_type_name == 'CSVInput':
    gt_all = gt_hdfs(gt_table, FLAGS.batch_size, FLAGS.gt_table_field_sep)
  else:
    gt_reader = HiveUtils(
        data_config=pipeline_config.data_config,
        hive_config=pipeline_config.hive_train_input,
        selected_cols='*')
    gt_all = gt_reader.hive_read_lines(gt_table, FLAGS.batch_size)
  if not tf.gfile.IsDirectory(hitrate_details_result):
    tf.gfile.MakeDirs(hitrate_details_result)
  hitrate_details_result = os.path.join(hitrate_details_result,
                                        'part-%s' % task_index)
  details_writer = tf.gfile.GFile(hitrate_details_result, 'w')
  # Release the details file and the GL graph even if computation or the
  # cross-worker reduction fails.
  try:
    print('Start compute hitrate...')
    total_hits, total_gt_count = compute_hitrate(g, gt_all, details_writer,
                                                 gt_table)
    var_total_hitrate, var_worker_count = reduce_hitrate(
        cluster, total_hits, total_gt_count, task_index)

    with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(task_index == 0)) as sess:
      outs = sess.run([var_total_hitrate, var_worker_count])

    # write after all workers have completed the calculation of hitrate.
    print('outs: ', outs)
    if outs[1] == worker_count:
      logging.info(outs)
      with tf.gfile.GFile(total_hitrate_result, 'w') as total_writer:
        total_writer.write(str(outs[0]))
  finally:
    details_writer.close()
    g.close()
  print('Compute hitrate done.')
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
if __name__ == '__main__':
  # Pass argv through io_util.filter_unknown_args (presumably dropping
  # arguments FLAGS does not define) before starting the job.
  cleaned_argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = cleaned_argv
  main()
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# =============================================================================
|
|
15
|
+
"""Evaluation of Top k hitrate."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
from __future__ import division
|
|
18
|
+
from __future__ import print_function
|
|
19
|
+
|
|
20
|
+
import sys
|
|
21
|
+
|
|
22
|
+
import tensorflow as tf
|
|
23
|
+
|
|
24
|
+
from easy_rec.python.utils import io_util
|
|
25
|
+
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
|
|
26
|
+
from easy_rec.python.utils.hit_rate_utils import load_graph
|
|
27
|
+
from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
|
|
28
|
+
|
|
29
|
+
# Command-line flags for the PAI hitrate evaluation job.
# Integer flags use int defaults (not numeric strings) so that every
# DEFINE_integer in this file declares its default consistently.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('task_index', None, 'Task index')
flags.DEFINE_integer('task_count', None, 'Task count')
flags.DEFINE_string('job_name', None, 'worker or ps or aligraph')
flags.DEFINE_string('ps_hosts', '', 'ps hosts')
flags.DEFINE_string('worker_hosts', '', 'worker hosts')
flags.DEFINE_string('tables', '', 'input odps tables name')
flags.DEFINE_string('outputs', '', 'ouput odps tables name')
flags.DEFINE_integer('batch_size', 512, 'batch size')
flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
flags.DEFINE_integer('top_k', 5, 'top_k hitrate.')
flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
flags.DEFINE_bool('knn_strict', False, 'use exact search.')
flags.DEFINE_integer('timeout', 60, 'timeout')
flags.DEFINE_integer('num_interests', 1, 'max number of interests')
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def compute_hitrate(g, gt_reader, hitrate_writer):
  """Compute hitrate of each worker.

  Reads `FLAGS.batch_size` records at a time from `gt_reader` until it raises
  `tf.python_io.OutOfRangeException`. For each batch it writes detail rows
  (src id, top-k recalls, distances, hitrate, bad cases, bad distances) to
  `hitrate_writer`.

  Args:
    g: a GL Graph instance.
    gt_reader: odps reader of input trigger_items_table.
    hitrate_writer: odps writer of hitrate table.

  Returns:
    total_hits: total hits of this worker.
    total_gt_count: total count of ground truth items of this worker.
  """
  total_hits = 0.0
  total_gt_count = 0.0
  while True:
    # Only the read signals end of the table. Keeping the try this narrow
    # means an OutOfRangeException raised by the computation or the writer
    # propagates instead of silently ending the loop and dropping results.
    try:
      gt_record = gt_reader.read(FLAGS.batch_size)
    except tf.python_io.OutOfRangeException:
      break
    hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
        compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k)
    total_hits += hits
    total_gt_count += gt_count
    topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
    topk_dists = [
        ','.join(str(x) for x in dists) for dists in recall_distances
    ]
    bad_cases = [','.join(str(x) for x in case) for case in bad_cases]
    bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]

    hitrate_writer.write(
        list(
            zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases,
                bad_dists)),
        indices=[0, 1, 2, 3, 4, 5])
  return total_hits, total_gt_count
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def main():
  """Compute hitrate on this worker and, once all workers report, globally.

  ps tasks only serve the cluster (server.join() blocks). Each worker reads
  its slice of the ground-truth table and writes per-source details to the
  first output table. It then contributes its totals through reduce_hitrate.
  The worker whose reduced worker count equals `worker_count` writes the
  overall hitrate to the second output table.
  """
  worker_hosts = FLAGS.worker_hosts.split(',')
  worker_count = len(worker_hosts)
  input_tables = FLAGS.tables.split(',')
  if FLAGS.recall_type == 'u2i':
    # u2i: exactly two tables are expected (item embeddings, ground truth).
    i_emb_table, gt_table = input_tables
  else:
    # Otherwise the item embedding and ground-truth tables are the last two.
    i_emb_table, gt_table = input_tables[-2], input_tables[-1]
  # Both recall types load the graph the same way, so do it once.
  g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
                 FLAGS.knn_strict)
  hitrate_details_table, total_hitrate_table = FLAGS.outputs.split(',')

  cluster = tf.train.ClusterSpec({
      'ps': FLAGS.ps_hosts.split(','),
      'worker': worker_hosts
  })
  server = tf.train.Server(
      cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
  if FLAGS.job_name == 'ps':
    server.join()
    return

  g.init(task_index=FLAGS.task_index, task_count=worker_count)
  gt_reader = tf.python_io.TableReader(
      gt_table,
      slice_id=FLAGS.task_index,
      slice_count=worker_count,
      capacity=2048)
  details_writer = None
  try:
    details_writer = tf.python_io.TableWriter(
        hitrate_details_table, slice_id=FLAGS.task_index)
    print('Start compute hitrate...')
    total_hits, total_gt_count = compute_hitrate(g, gt_reader, details_writer)
    var_total_hitrate, var_worker_count = reduce_hitrate(
        cluster, total_hits, total_gt_count, FLAGS.task_index)

    with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(FLAGS.task_index == 0)) as sess:
      outs = sess.run([var_total_hitrate, var_worker_count])

    # write after all workers have completed the calculation of hitrate.
    if outs[1] == worker_count:
      with tf.python_io.TableWriter(total_hitrate_table) as total_writer:
        total_writer.write([outs[0]], indices=[0])
  finally:
    # Release the reader, writer and graph even if the computation fails.
    gt_reader.close()
    if details_writer is not None:
      details_writer.close()
    g.close()
  print('Compute hitrate done.')
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
if __name__ == '__main__':
  # Filter sys.argv through io_util.filter_unknown_args before main() runs;
  # presumably this drops arguments not declared as FLAGS (see io_util).
  known_argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = known_argv
  main()
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
|
|
10
|
+
from easy_rec.python.input.input import Input
|
|
11
|
+
from easy_rec.python.utils import config_util
|
|
12
|
+
from easy_rec.python.utils import fg_util
|
|
13
|
+
from easy_rec.python.utils import io_util
|
|
14
|
+
from easy_rec.python.utils.check_utils import check_env_and_input_path
|
|
15
|
+
from easy_rec.python.utils.check_utils import check_sequence
|
|
16
|
+
|
|
17
|
+
# Use the TF1-compatible API on TF2+. Compare the major version numerically:
# a plain string comparison would order e.g. '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
19
|
+
|
|
20
|
+
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s')

# Command-line flags for the data pre-check tool.
tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config file.')
tf.app.flags.DEFINE_multi_string(
    'data_input_path', None, help='data input path')

FLAGS = tf.app.flags.FLAGS
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_input_fn(data_config,
                  feature_configs,
                  data_path=None,
                  export_config=None):
  """Build the estimator input function used for pre-checking data.

  The Input subclass is selected by the enum name of data_config.input_type.
  When TF_CONFIG is set, the worker count and task index are taken from it;
  otherwise a single worker with index 0 is assumed.

  Args:
    data_config: dataset config
    feature_configs: FeatureConfig
    data_path: input_data_path
    export_config: configuration for exporting models,
      only used to build input_fn when exporting models

  Returns:
    the input function returned by the Input object's create_input()
    (built with check_mode=True)
  """
  # Invert the enum's name -> value mapping to look up the name by value.
  name_by_value = dict(
      (value, name) for name, value in data_config.InputType.items())
  input_class = Input.create_class(name_by_value[data_config.input_type])

  worker_num, task_index = 1, 0
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    worker_num = len(tf_config['cluster']['worker'])
    task_index = tf_config['task']['index']

  input_obj = input_class(
      data_config,
      feature_configs,
      data_path,
      task_index=task_index,
      task_num=worker_num,
      check_mode=True)
  return input_obj.create_input(export_config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def loda_pipeline_config(pipeline_config_path):
  """Load a pipeline config file and expand shared feature configs.

  If the config sets fg_json_path, the fg json is merged into the config
  first (fg_util.load_fg_json_to_config).

  Args:
    pipeline_config_path: path to the pipeline config file.

  Returns:
    the loaded pipeline config object.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(
      pipeline_config_path, False)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  config_util.auto_expand_share_feature_configs(pipeline_config)
  return pipeline_config


# Correctly spelled alias; the misspelled name above is kept so existing
# callers continue to work.
load_pipeline_config = loda_pipeline_config
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def run_check(pipeline_config, input_path):
  """Read the whole input once and run check_sequence on every batch.

  Args:
    pipeline_config: the loaded pipeline config.
    input_path: comma-separated input path(s) to check.
  """
  logging.info('data_input_path: %s', input_path)
  check_env_and_input_path(pipeline_config, input_path)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  eval_input_fn = _get_input_fn(pipeline_config.data_config, feature_configs,
                                input_path)
  eval_spec = tf.estimator.EvalSpec(
      name='val',
      input_fn=eval_input_fn,
      steps=None,
      throttle_secs=10,
      exporters=[])
  input_iter = eval_spec.input_fn(
      mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
  # Create the get_next op once. Calling get_next() inside the loop would add
  # new ops to the graph on every iteration and grow it without bound.
  input_feas, _ = input_iter.get_next()
  with tf.Session() as sess:
    try:
      while True:
        features = sess.run(input_feas)
        check_sequence(pipeline_config, features)
    except tf.errors.OutOfRangeError:
      logging.info('pre-check finish...')
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _resolve_input_path(data_input_path, pipeline_config):
  """Return the comma-separated input path(s) to check.

  Explicit data_input_path values win. Otherwise the non-empty ones among
  pipeline_config.train_input_path / eval_input_path are joined, so an unset
  path does not produce an empty component such as 'train,'.
  """
  if data_input_path:
    return ','.join(data_input_path)
  config_paths = [
      path for path in (pipeline_config.train_input_path,
                        pipeline_config.eval_input_path) if path
  ]
  assert config_paths, 'input_path should not be empty when checking!'
  return ','.join(config_paths)


def main(argv):
  """tf.app.run entry point: load the pipeline config and check the inputs."""
  assert FLAGS.pipeline_config_path, \
      'pipeline_config_path should not be empty when checking!'
  pipeline_config = loda_pipeline_config(FLAGS.pipeline_config_path)
  input_path = _resolve_input_path(FLAGS.data_input_path, pipeline_config)
  run_check(pipeline_config, input_path)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
if __name__ == '__main__':
  # Filter sys.argv through io_util.filter_unknown_args before tf.app.run()
  # parses flags; presumably this drops arguments not declared as FLAGS.
  argv_known = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = argv_known
  tf.app.run()
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
import easy_rec
|
|
12
|
+
from easy_rec.python.inference.predictor import Predictor
|
|
13
|
+
|
|
14
|
+
# Configure logging before anything is logged. A module-level
# logging.warning() on an unconfigured root logger installs a default
# handler, which would turn a later basicConfig() call into a no-op and lose
# the INFO level and the format below.
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

# The embed op library is optional: failures are logged, not raised.
try:
  import tensorflow as tf
  tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libembed_op.so'))
except Exception as ex:
  logging.warning('exception: %s' % str(ex))
|
|
22
|
+
|
|
23
|
+
def _extract_feature(line_str, rtp_separator, rtp_fea_id, separator,
                     label_ids, join_features):
  """Extract the feature column from one input line, dropping label fields.

  Args:
    line_str: a raw input line; surrounding whitespace is stripped.
    rtp_separator: separator between the columns of the line.
    rtp_fea_id: index of the feature column (negative counts from the end).
    separator: separator between the fields inside the feature column.
    label_ids: container of field positions to drop.
    join_features: if True, re-join the kept fields with `separator`;
      otherwise return them as a list.

  Returns:
    the kept fields, as a string or a list depending on join_features.
  """
  columns = line_str.strip().split(rtp_separator)
  fields = columns[rtp_fea_id].split(separator)
  kept = [x for fid, x in enumerate(fields) if fid not in label_ids]
  if join_features:
    return separator.join(kept)
  return kept


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model directory')
  parser.add_argument(
      '--input_path', type=str, default=None, help='input feature path')
  parser.add_argument('--save_path', type=str, default=None, help='save path')
  parser.add_argument(
      '--cmp_res_path', type=str, default=None, help='compare result path')
  parser.add_argument(
      '--cmp_key', type=str, default='probs', help='compare key')
  parser.add_argument(
      '--rtp_fea_id',
      type=int,
      default=-1,
      help='rtp feature column index, default to the last column')
  parser.add_argument('--tol', type=float, default=1e-5, help='tolerance')
  parser.add_argument(
      '--label_id',
      nargs='*',
      type=int,
      help='the label column, which is to be excluded')
  # Defaults match the help text; an empty separator would make
  # str.split() raise ValueError.
  parser.add_argument(
      '--separator',
      type=str,
      default='\x02',
      help='separator between features, default to \\u0002')
  parser.add_argument(
      '--rtp_separator',
      type=str,
      default='\x01',
      help='separator, default to \\u0001')
  args = parser.parse_args()

  if not args.saved_model_dir:
    logging.error('saved_model_dir is not set')
    sys.exit(1)

  if not args.input_path:
    logging.error('input_path is not set')
    sys.exit(1)

  if args.label_id is None:
    args.label_id = []

  logging.info('input_path: ' + args.input_path)
  # save_path is optional (None by default); format it instead of using
  # string concatenation, which would raise TypeError.
  logging.info('save_path: %s' % args.save_path)
  logging.info('separator: ' + args.separator)

  predictor = Predictor(args.saved_model_dir)
  if len(predictor.input_names) == 1:
    assert len(
        args.label_id
    ) == 0, 'label_id should not be set if rtp feature format is used.'

  # A single 'features' input takes the joined string; otherwise each sample
  # is passed as a list of fields.
  join_features = 'features' in predictor.input_names
  label_ids = set(args.label_id)
  with open(args.input_path, 'r') as fin:
    batch_input = [
        _extract_feature(line_str, args.rtp_separator, args.rtp_fea_id,
                         args.separator, label_ids, join_features)
        for line_str in fin
    ]
  output = predictor.predict(batch_input)

  if args.save_path:
    with open(args.save_path, 'w') as fout:
      for one in output:
        fout.write(str(one) + '\n')

  if args.cmp_res_path:
    logging.info('compare result path: ' + args.cmp_res_path)
    logging.info('compare key: ' + args.cmp_key)
    logging.info('tolerance: ' + str(args.tol))
    with open(args.cmp_res_path, 'r') as fin:
      for line_id, line_str in enumerate(fin):
        line_pred = json.loads(line_str.strip())
        assert line_id < len(output), \
            'line[%d]: no prediction available (only %d outputs)' % (
                line_id, len(output))
        diff = np.abs(line_pred[args.cmp_key] - output[line_id][args.cmp_key])
        assert diff < args.tol, 'line[%d]: %.8f' % (line_id, diff)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
from kafka import KafkaConsumer
|
|
9
|
+
from kafka.structs import TopicPartition
|
|
10
|
+
|
|
11
|
+
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)
|
|
13
|
+
|
|
14
|
+
def _record_file_name(key, record_id):
  """Build a safe file name for saving one Kafka record.

  Args:
    key: the record key as delivered by the consumer (bytes, str or None).
    record_id: sequential id of the record, used when the key is unusable.

  Returns:
    a bare file name with no directory components.
  """
  if isinstance(key, bytes):
    key = key.decode('utf-8', 'replace')
  # Keys come from the topic and are not trusted: keep only the final path
  # component so a key such as '../x' cannot escape the save directory.
  name = os.path.basename(key) if key else ''
  if name in ('', '.', '..'):
    name = 'record_%d' % record_id
  return name


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--servers', type=str, default='localhost:9092')
  parser.add_argument('--topic', type=str, default=None)
  parser.add_argument('--group', type=str, default='consumer')
  parser.add_argument('--partitions', type=str, default=None)
  parser.add_argument('--timeout', type=float, default=float('inf'))
  parser.add_argument('--save_dir', type=str, default=None)
  args = parser.parse_args()

  if args.topic is None:
    logging.error('--topic is not set')
    sys.exit(1)

  servers = args.servers.split(',')
  consumer = KafkaConsumer(
      group_id=args.group,
      bootstrap_servers=servers,
      consumer_timeout_ms=args.timeout * 1000)

  if args.partitions is not None:
    partitions = [int(x) for x in args.partitions.split(',')]
  else:
    partitions = consumer.partitions_for_topic(args.topic)
    if partitions is None:
      # partitions_for_topic can return None when topic metadata is unknown.
      logging.error('no partitions found for topic %s' % args.topic)
      sys.exit(1)
  logging.info('partitions: %s' % partitions)

  topics = [
      TopicPartition(topic=args.topic, partition=part_id)
      for part_id in partitions
  ]
  consumer.assign(topics)
  consumer.seek_to_beginning()

  for record_id, x in enumerate(consumer):
    value = x.value if x.value is not None else b''
    logging.info('%d: key=%s\toffset=%d\ttimestamp=%d\tlen=%d' %
                 (record_id, x.key, x.offset, x.timestamp, len(value)))
    if args.save_dir is not None:
      save_path = os.path.join(args.save_dir,
                               _record_file_name(x.key, record_id))
      with open(save_path, 'wb') as fout:
        fout.write(value)
|