easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. See the package's registry page for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow.core.protobuf import saved_model_pb2
|
|
10
|
+
from tensorflow.python.lib.io.file_io import file_exists
|
|
11
|
+
from tensorflow.python.lib.io.file_io import recursive_create_dir
|
|
12
|
+
from tensorflow.python.platform.gfile import GFile
|
|
13
|
+
|
|
14
|
+
import easy_rec
|
|
15
|
+
from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
|
|
16
|
+
|
|
17
|
+
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# Field separator of the RTP-format samples read via --test_data_path: the
# first field is the label, the remaining fields are features.
# NOTE(review): '\x02' is assumed to be the RTP field separator -- confirm
# against the sample data (splitting on '' raises ValueError).
_RTP_FIELD_SEP = '\x02'
# Maximum number of test samples fed to the edited model as a smoke test.
_MAX_TEST_SAMPLES = 32


def _parse_args():
  """Parse the command line; exit with a usage error if a path is missing."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model dir')
  parser.add_argument('--output_dir', type=str, default=None, help='output dir')
  parser.add_argument(
      '--redis_url', type=str, default='127.0.0.1:6379', help='redis url')
  parser.add_argument(
      '--redis_passwd', type=str, default='', help='redis password')
  parser.add_argument('--time_out', type=int, default=1500, help='timeout')
  parser.add_argument(
      '--test_data_path', type=str, default='', help='test data path')
  parser.add_argument('--verbose', action='store_true', default=False)
  args = parser.parse_args()
  # Both paths are passed to os.path.join / file_exists below, which fail with
  # an obscure TypeError when they are None.
  if not args.saved_model_dir:
    parser.error('--saved_model_dir is required')
  if not args.output_dir:
    parser.error('--output_dir is required')
  return args


def _load_test_features(test_data_path, max_samples=_MAX_TEST_SAMPLES):
  """Read up to ``max_samples`` feature strings from an RTP-format file.

  The first separator-delimited field of every line (the label) is dropped and
  the remaining fields are joined again with the same separator.
  """
  feature_vals = []
  with GFile(test_data_path, 'r') as fin:
    for line_str in fin:
      line_toks = line_str.strip().split(_RTP_FIELD_SEP)
      feature_vals.append(_RTP_FIELD_SEP.join(line_toks[1:]))
      if len(feature_vals) >= max_samples:
        break
  return feature_vals


def _signature_tensors(graph, tensor_infos, kind):
  """Map signature keys to tensors of ``graph``, logging every mapping."""
  tensors = {}
  for name, tensor_info in tensor_infos.items():
    logging.info('model %s: %s => %s', kind, name, tensor_info.name)
    tensors[name] = graph.get_tensor_by_name(tensor_info.name)
  return tensors


def _set_meta_graph_version(export_dir, version):
  """Overwrite meta_info_def.meta_graph_version of the first meta graph.

  The version cannot be passed through tf.saved_model.simple_save, so the
  serialized saved_model.pb in ``export_dir`` is patched in place.
  """
  saved_model_path = os.path.join(export_dir, 'saved_model.pb')
  saved_model = saved_model_pb2.SavedModel()
  with GFile(saved_model_path, 'rb') as fin:
    saved_model.ParseFromString(fin.read())
  saved_model.meta_graphs[0].meta_info_def.meta_graph_version = version
  with GFile(saved_model_path, 'wb') as fout:
    fout.write(saved_model.SerializeToString())


def main():
  """Replace the default embedding_lookup ops with self defined lookup ops.

  The data are stored in redis; a lookup retrieves the embedding vectors by
  the key ``{version}_{embed_name}_{embed_id}``.

  Example:
    python -m easy_rec.python.tools.edit_lookup_graph
      --saved_model_dir rtp_large_embedding_export/1604304644
      --output_dir ./after_edit_save
      --test_data_path data/test/rtp/xys_cxr_fg_sample_test2_with_lbl.txt
  """
  args = _parse_args()
  logging.info('saved_model_dir: %s', args.saved_model_dir)

  if not file_exists(os.path.join(args.saved_model_dir, 'saved_model.pb')):
    logging.error('saved_model.pb does not exist in %s', args.saved_model_dir)
    sys.exit(1)

  logging.info('output_dir: %s', args.output_dir)
  logging.info('redis_url: %s', args.redis_url)
  lookup_lib_path = os.path.join(easy_rec.ops_dir, 'libkv_lookup.so')
  logging.info('lookup_lib_path: %s', lookup_lib_path)

  if not file_exists(args.output_dir):
    recursive_create_dir(args.output_dir)

  meta_graph_editor = MetaGraphEditor(
      lookup_lib_path,
      args.saved_model_dir,
      args.redis_url,
      args.redis_passwd,
      args.time_out,
      meta_graph_def=None,
      debug_dir=args.output_dir if args.verbose else '')
  meta_graph_editor.edit_graph()

  meta_graph_version = meta_graph_editor.meta_graph_version
  if meta_graph_version == '':
    # Fall back to the last path component of saved_model_dir, e.g. the
    # export timestamp 1604304644 in the example above.
    export_ts = [x for x in args.saved_model_dir.split('/') if x]
    meta_graph_version = export_ts[-1]

  # import the edited graph
  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)

  embed_name_to_id_file = os.path.join(args.output_dir, 'embed_name_to_ids.txt')
  embed_name_to_ids = meta_graph_editor._embed_name_to_ids
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in embed_name_to_ids:
      fout.write('%s\t%s\n' % (tmp_norm_name, embed_name_to_ids[tmp_norm_name]))
  tf.add_to_collection(
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  graph = tf.get_default_graph()
  inputs_map = _signature_tensors(graph, meta_graph_editor.signature_def.inputs,
                                  'inputs')
  outputs_map = _signature_tensors(graph,
                                   meta_graph_editor.signature_def.outputs,
                                   'outputs')

  with tf.Session() as sess:
    saver.restore(sess,
                  os.path.join(args.saved_model_dir, 'variables', 'variables'))
    output_dir = os.path.join(args.output_dir, meta_graph_version)
    tf.saved_model.simple_save(
        sess, output_dir, inputs=inputs_map, outputs=outputs_map)
    # NOTE(review): this writes the editor's version, which may be '' even
    # though output_dir uses the fallback ``meta_graph_version`` -- confirm
    # which value downstream consumers expect.
    _set_meta_graph_version(output_dir, meta_graph_editor.meta_graph_version)
    logging.info('save output to %s', output_dir)

    if args.test_data_path:
      if 'features' not in inputs_map:
        logging.warning('model has no `features` input, skip testing with %s',
                        args.test_data_path)
      else:
        feature_vals = _load_test_features(args.test_data_path)
        out_vals = sess.run(
            outputs_map, feed_dict={inputs_map['features']: feature_vals})
        logging.info('test_data probs:%s', out_vals)


if __name__ == '__main__':
  main()
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
import faiss
|
|
10
|
+
import numpy as np
|
|
11
|
+
import tensorflow as tf
|
|
12
|
+
|
|
13
|
+
from easy_rec.python.utils import io_util
|
|
14
|
+
|
|
15
|
+
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)

# Command line flags consumed by main(); values are read through FLAGS.
_flags = tf.app.flags
_flags.DEFINE_string('tables', '', 'tables passed by pai command')
_flags.DEFINE_integer('batch_size', 1024, 'batch size')
_flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
_flags.DEFINE_string('index_output_dir', '', 'index output directory')
_flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
# IVFFlat: number of inverted lists
_flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
# HNSWFlat graph construction parameters
_flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
_flags.DEFINE_integer('hnsw_efConstruction', 200, 'hnsw efConstruction')
# non-zero: additionally build indexes for a grid of IVF/HNSW parameters
_flags.DEFINE_integer('debug', 0, 'debug index')

FLAGS = _flags.FLAGS
del _flags
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Index types accepted by --index_type.
_SUPPORTED_INDEX_TYPES = ('IVFFlat', 'HNSWFlat')


def _to_text(value):
  """Return ``value`` as text: bytes are decoded as UTF-8, others use str()."""
  if isinstance(value, bytes):
    return value.decode('utf-8')
  return str(value)


def _save_index(index, index_name):
  """Write ``index`` to ``<index_output_dir>/<index_name>``.

  The index is first written to ``index_name`` in the current working
  directory with faiss.write_index and then copied with tf.gfile (presumably
  so that remote output directories work -- TODO confirm).
  """
  faiss.write_index(index, index_name)
  with tf.gfile.GFile(
      os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
    with open(index_name, 'rb') as f_in:
      f_out.write(f_in.read())


def _read_table():
  """Read (id, embedding) records from FLAGS.tables.

  Every id is written on its own line to ``<index_output_dir>/id_mapping`` in
  the same order as the returned embeddings, so line i of id_mapping belongs
  to row i of the index. The embedding column is a comma separated list of
  floats.

  Returns:
    list of embeddings, each a list of floats.
  """
  reader = tf.python_io.TableReader(
      FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
  embeddings = []
  num_batches = 0
  try:
    with tf.gfile.GFile(
        os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w') as id_map_f:
      while True:
        try:
          records = reader.read(FLAGS.batch_size)
        except tf.python_io.OutOfRangeException:
          break
        for record in records:
          # write an id for every record (bytes or not) so that id_mapping
          # stays aligned with the embedding rows
          id_map_f.write('%s\n' % _to_text(record[0]))
          embeddings.append([float(x) for x in _to_text(record[1]).split(',')])
        num_batches += 1
        if num_batches % 100 == 0:
          logging.info('read %d embeddings.', len(embeddings))
  finally:
    reader.close()
  return embeddings


def _build_debug_indexes(embeddings):
  """Build and save one index per IVFFlat nlist / HNSWFlat (M, ef) setting."""
  dim = FLAGS.embedding_dim
  # IVFFlat
  for ivf_nlist in [100, 500, 1000, 2000]:
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, ivf_nlist,
                               faiss.METRIC_INNER_PRODUCT)
    index.train(embeddings)
    index.add(embeddings)
    _save_index(index, 'faiss_index_ivfflat_nlist%d' % ivf_nlist)

  # HNSWFlat: efConstruction sweeps powers of two from 64 to 8192
  for hnsw_M in [16, 32, 64, 128]:
    for hnsw_efConstruction in [64, 128, 256, 512, 1024, 2048, 4096, 8192]:
      if hnsw_efConstruction < hnsw_M * 2:
        continue
      index = faiss.IndexHNSWFlat(dim, hnsw_M, faiss.METRIC_INNER_PRODUCT)
      index.hnsw.efConstruction = hnsw_efConstruction
      index.add(embeddings)
      _save_index(index,
                  'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction))


def main(argv):
  """Build a faiss inner-product index over the embeddings in FLAGS.tables.

  Writes ``id_mapping`` and ``faiss_index`` to FLAGS.index_output_dir; with
  --debug also one index per IVFFlat / HNSWFlat parameter setting.

  Raises:
    NotImplementedError: if --index_type is not supported.
    ValueError: if the table is empty or the embedding size differs from
      --embedding_dim.
  """
  del argv  # unused; everything is read from FLAGS
  if FLAGS.index_type not in _SUPPORTED_INDEX_TYPES:
    # fail before reading the whole table
    raise NotImplementedError('unsupported index_type: %s' % FLAGS.index_type)

  embeddings = _read_table()
  if not embeddings:
    raise ValueError('no embeddings read from tables: %s' % FLAGS.tables)
  # faiss works on float32 vectors
  embeddings = np.array(embeddings, dtype=np.float32)
  dim = FLAGS.embedding_dim
  if embeddings.shape[1] != dim:
    raise ValueError('embedding size %d != embedding_dim %d' %
                     (embeddings.shape[1], dim))

  logging.info('Building faiss index..')
  if FLAGS.index_type == 'IVFFlat':
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, FLAGS.ivf_nlist,
                               faiss.METRIC_INNER_PRODUCT)
    logging.info('train embeddings...')
    index.train(embeddings)
  else:  # HNSWFlat
    index = faiss.IndexHNSWFlat(dim, FLAGS.hnsw_M, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = FLAGS.hnsw_efConstruction

  logging.info('build embeddings...')
  index.add(embeddings)
  _save_index(index, 'faiss_index')

  if FLAGS.debug != 0:
    _build_debug_indexes(embeddings)


if __name__ == '__main__':
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  tf.app.run()
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
from __future__ import division
|
|
2
|
+
from __future__ import print_function
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import tensorflow as tf
|
|
12
|
+
from tensorflow.python.framework.meta_graph import read_meta_graph_file
|
|
13
|
+
|
|
14
|
+
from easy_rec.python.utils import config_util
|
|
15
|
+
from easy_rec.python.utils import io_util
|
|
16
|
+
|
|
17
|
+
# Use the TF1 API surface when running on TF 2.x.
# NOTE(review): this is a lexicographic string comparison of versions.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf

# select the non-interactive Agg backend before pyplot is imported
import matplotlib  # NOQA
matplotlib.use('Agg')  # NOQA
import matplotlib.pyplot as plt  # NOQA

# Command line flags of the feature selection tool.
_flags = tf.app.flags
_flags.DEFINE_string('model_type', 'variational_dropout',
                     'feature selection model type')
_flags.DEFINE_string('config_path', '', 'feature selection model config path')
_flags.DEFINE_string('checkpoint_path', None,
                     'feature selection model checkpoint path')
_flags.DEFINE_string('output_dir', '', 'feature selection result directory')
_flags.DEFINE_integer(
    'topk', 100, 'select topk importance features for each feature group')
_flags.DEFINE_string('fg_path', '', 'fg config path')
_flags.DEFINE_bool('visualize', False,
                   'visualization feature selection result or not')
FLAGS = _flags.FLAGS
del _flags
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class VariationalDropoutFS:
|
|
41
|
+
|
|
42
|
+
def __init__(self,
             config_path,
             output_dir,
             topk,
             checkpoint_path=None,
             fg_path=None,
             visualize=False):
  """Store the feature selection settings and make sure output_dir exists.

  Args:
    config_path: pipeline config path of the variational dropout model.
    output_dir: result directory; created when it does not exist yet.
    topk: number of most important features to select per feature group.
    checkpoint_path: checkpoint to read; when None or empty the latest
      checkpoint of the config's model_dir is used.
    fg_path: fg config path.
    visualize: whether to plot the feature importance.
  """
  self._config_path = config_path
  self._output_dir = output_dir
  self._topk = topk
  self._checkpoint_path = checkpoint_path
  self._fg_path = fg_path
  self._visualize = visualize
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)
|
|
57
|
+
|
|
58
|
+
def process(self):
  """Compute, dump and optionally plot the feature importance of each group.

  For every feature group, the importance derived from the dropout ratios is
  written to csv (and plotted when visualize is set). Finally the model config
  is processed with the importance of all groups.
  """
  tf.logging.info('Loading logit_p of VariationalDropout layer ...')
  dropout_p_by_group, embedding_wise = self._feature_dim_dropout_ratio()

  importance_by_group = {}
  for group_name in dropout_p_by_group:
    dropout_p = dropout_p_by_group[group_name]
    tf.logging.info('Calculating %s feature importance ...' % group_name)
    importance = self._get_feature_importance(dropout_p, embedding_wise)
    importance_by_group[group_name] = importance

    tf.logging.info('Dump %s feature importance to csv ...' % group_name)
    self._dump_to_csv(importance, group_name)

    if not self._visualize:
      continue
    tf.logging.info('Visualizing %s feature importance ...' % group_name)
    if embedding_wise:
      self._visualize_embedding_dim_importance(dropout_p)
    self._visualize_feature_importance(importance, group_name)

  tf.logging.info('Processing model config ...')
  self._process_config(importance_by_group)
|
|
81
|
+
|
|
82
|
+
def _feature_dim_dropout_ratio(self):
|
|
83
|
+
"""Get dropout ratio of embedding-wise or feature-wise."""
|
|
84
|
+
config = config_util.get_configs_from_pipeline_file(self._config_path)
|
|
85
|
+
assert config.model_config.HasField(
|
|
86
|
+
'variational_dropout'), 'variational_dropout must be in model_config'
|
|
87
|
+
|
|
88
|
+
embedding_wise_variational_dropout = config.model_config.variational_dropout.embedding_wise_variational_dropout
|
|
89
|
+
|
|
90
|
+
if self._checkpoint_path is None or len(self._checkpoint_path) == 0:
|
|
91
|
+
checkpoint_path = tf.train.latest_checkpoint(config.model_dir)
|
|
92
|
+
else:
|
|
93
|
+
checkpoint_path = self._checkpoint_path
|
|
94
|
+
|
|
95
|
+
meta_graph_def = read_meta_graph_file(checkpoint_path + '.meta')
|
|
96
|
+
features_dimension_map = dict()
|
|
97
|
+
for col_def in meta_graph_def.collection_def[
|
|
98
|
+
'variational_dropout'].bytes_list.value:
|
|
99
|
+
name, features_dimension = json.loads(col_def)
|
|
100
|
+
name = 'all' if name == '' else name
|
|
101
|
+
features_dimension_map[name] = OrderedDict(features_dimension)
|
|
102
|
+
|
|
103
|
+
tf.logging.info('Reading checkpoint from %s ...' % checkpoint_path)
|
|
104
|
+
reader = tf.train.NewCheckpointReader(checkpoint_path)
|
|
105
|
+
|
|
106
|
+
feature_dim_dropout_p_map = {}
|
|
107
|
+
for feature_group in config.model_config.feature_groups:
|
|
108
|
+
group_name = feature_group.group_name
|
|
109
|
+
|
|
110
|
+
logit_p_name = 'logit_p' if group_name == 'all' else 'logit_p_%s' % group_name
|
|
111
|
+
try:
|
|
112
|
+
logit_p = reader.get_tensor(logit_p_name)
|
|
113
|
+
except Exception:
|
|
114
|
+
print('get `logit_p` failed, try to get `backbone/logit_p`')
|
|
115
|
+
logit_p = reader.get_tensor('backbone/' + logit_p_name)
|
|
116
|
+
feature_dims_importance = tf.sigmoid(logit_p)
|
|
117
|
+
with tf.Session() as sess:
|
|
118
|
+
feature_dims_importance = feature_dims_importance.eval(session=sess)
|
|
119
|
+
|
|
120
|
+
feature_dim_dropout_p = {}
|
|
121
|
+
if embedding_wise_variational_dropout:
|
|
122
|
+
index_end = 0
|
|
123
|
+
for feature_name, feature_dim in features_dimension_map[
|
|
124
|
+
group_name].items():
|
|
125
|
+
index_start = index_end
|
|
126
|
+
index_end = index_start + feature_dim
|
|
127
|
+
feature_dim_dropout_p[feature_name] = feature_dims_importance[
|
|
128
|
+
index_start:index_end]
|
|
129
|
+
else:
|
|
130
|
+
index = 0
|
|
131
|
+
for feature_name in features_dimension_map[group_name].keys():
|
|
132
|
+
feature_dim_dropout_p[feature_name] = feature_dims_importance[index]
|
|
133
|
+
index += 1
|
|
134
|
+
|
|
135
|
+
feature_dim_dropout_p_map[group_name] = feature_dim_dropout_p
|
|
136
|
+
return feature_dim_dropout_p_map, embedding_wise_variational_dropout
|
|
137
|
+
|
|
138
|
+
def _get_feature_importance(self, feature_dim_dropout_p,
|
|
139
|
+
embedding_wise_variational_dropout):
|
|
140
|
+
"""Calculate feature importance."""
|
|
141
|
+
if embedding_wise_variational_dropout:
|
|
142
|
+
feature_importance = {}
|
|
143
|
+
for item in feature_dim_dropout_p.items():
|
|
144
|
+
dropout_rate_mean = np.mean(item[1])
|
|
145
|
+
feature_importance[item[0]] = dropout_rate_mean
|
|
146
|
+
feature_importance = OrderedDict(
|
|
147
|
+
sorted(feature_importance.items(), key=lambda e: e[1]))
|
|
148
|
+
else:
|
|
149
|
+
feature_importance = OrderedDict(
|
|
150
|
+
sorted(feature_dim_dropout_p.items(), key=lambda e: e[1]))
|
|
151
|
+
return feature_importance
|
|
152
|
+
|
|
153
|
+
def _process_config(self, feature_importance_map):
|
|
154
|
+
"""Process model config and fg config with feature selection."""
|
|
155
|
+
excluded_features = set()
|
|
156
|
+
for group_name, feature_importance in feature_importance_map.items():
|
|
157
|
+
for i, (feature_name, _) in enumerate(feature_importance.items()):
|
|
158
|
+
if i >= self._topk:
|
|
159
|
+
excluded_features.add(feature_name)
|
|
160
|
+
|
|
161
|
+
config = config_util.get_configs_from_pipeline_file(self._config_path)
|
|
162
|
+
# keep sequence features and side-infos
|
|
163
|
+
sequence_features = set()
|
|
164
|
+
for feature_group in config.model_config.feature_groups:
|
|
165
|
+
for sequence_feature in feature_group.sequence_features:
|
|
166
|
+
for seq_att_map in sequence_feature.seq_att_map:
|
|
167
|
+
for key in seq_att_map.key:
|
|
168
|
+
sequence_features.add(key)
|
|
169
|
+
for hist_seq in seq_att_map.hist_seq:
|
|
170
|
+
sequence_features.add(hist_seq)
|
|
171
|
+
# compat with din
|
|
172
|
+
for sequence_feature in config.model_config.seq_att_groups:
|
|
173
|
+
for seq_att_map in sequence_feature.seq_att_map:
|
|
174
|
+
for key in seq_att_map.key:
|
|
175
|
+
sequence_features.add(key)
|
|
176
|
+
for hist_seq in seq_att_map.hist_seq:
|
|
177
|
+
sequence_features.add(hist_seq)
|
|
178
|
+
excluded_features = excluded_features - sequence_features
|
|
179
|
+
|
|
180
|
+
feature_configs = []
|
|
181
|
+
for feature_config in config_util.get_compatible_feature_configs(config):
|
|
182
|
+
feature_name = feature_config.feature_name if feature_config.HasField('feature_name') \
|
|
183
|
+
else feature_config.input_names[0]
|
|
184
|
+
if feature_name not in excluded_features:
|
|
185
|
+
feature_configs.append(feature_config)
|
|
186
|
+
|
|
187
|
+
if config.feature_configs:
|
|
188
|
+
config.ClearField('feature_configs')
|
|
189
|
+
config.feature_configs.extend(feature_configs)
|
|
190
|
+
else:
|
|
191
|
+
config.feature_config.ClearField('features')
|
|
192
|
+
config.feature_config.features.extend(feature_configs)
|
|
193
|
+
|
|
194
|
+
for feature_group in config.model_config.feature_groups:
|
|
195
|
+
feature_names = []
|
|
196
|
+
for feature_name in feature_group.feature_names:
|
|
197
|
+
if feature_name not in excluded_features:
|
|
198
|
+
feature_names.append(feature_name)
|
|
199
|
+
feature_group.ClearField('feature_names')
|
|
200
|
+
feature_group.feature_names.extend(feature_names)
|
|
201
|
+
config_util.save_message(
|
|
202
|
+
config,
|
|
203
|
+
os.path.join(self._output_dir, os.path.basename(self._config_path)))
|
|
204
|
+
|
|
205
|
+
if self._fg_path is not None and len(self._fg_path) > 0:
|
|
206
|
+
with tf.gfile.Open(self._fg_path) as f:
|
|
207
|
+
fg_json = json.load(f, object_pairs_hook=OrderedDict)
|
|
208
|
+
features = []
|
|
209
|
+
for feature in fg_json['features']:
|
|
210
|
+
if 'feature_name' in feature:
|
|
211
|
+
if feature['feature_name'] not in excluded_features:
|
|
212
|
+
features.append(feature)
|
|
213
|
+
else:
|
|
214
|
+
features.append(feature)
|
|
215
|
+
fg_json['features'] = features
|
|
216
|
+
with tf.gfile.Open(
|
|
217
|
+
os.path.join(self._output_dir, os.path.basename(self._fg_path)),
|
|
218
|
+
'w') as f:
|
|
219
|
+
json.dump(fg_json, f, indent=4)
|
|
220
|
+
|
|
221
|
+
def _dump_to_csv(self, feature_importance, group_name):
|
|
222
|
+
"""Dump feature importance data to a csv file."""
|
|
223
|
+
with tf.gfile.Open(
|
|
224
|
+
os.path.join(self._output_dir,
|
|
225
|
+
'feature_dropout_ratio_%s.csv' % group_name), 'w') as f:
|
|
226
|
+
df = pd.DataFrame(
|
|
227
|
+
columns=['feature_name', 'mean_drop_p'],
|
|
228
|
+
data=[list(kv) for kv in feature_importance.items()])
|
|
229
|
+
df.to_csv(f, encoding='gbk')
|
|
230
|
+
|
|
231
|
+
def _visualize_embedding_dim_importance(self, feature_dim_dropout_p):
|
|
232
|
+
"""Visualize embedding-wise importance visualization for every feature."""
|
|
233
|
+
output_dir = os.path.join(self._output_dir, 'feature_dims_importance_pics')
|
|
234
|
+
if not tf.gfile.Exists(output_dir):
|
|
235
|
+
tf.gfile.MakeDirs(output_dir)
|
|
236
|
+
|
|
237
|
+
plt.rcdefaults()
|
|
238
|
+
for feature_name, feature_dropout_p in feature_dim_dropout_p.items():
|
|
239
|
+
embedding_len = len(feature_dropout_p)
|
|
240
|
+
embedding_dims = []
|
|
241
|
+
for i in range(embedding_len):
|
|
242
|
+
embedding_dims.append('dim_' + str(i + 1))
|
|
243
|
+
y_pos = np.arange(len(embedding_dims))
|
|
244
|
+
performance_list = []
|
|
245
|
+
for i in range(0, embedding_len):
|
|
246
|
+
performance_list.append(feature_dropout_p[i])
|
|
247
|
+
fig, ax = plt.subplots()
|
|
248
|
+
b = ax.barh(
|
|
249
|
+
y_pos,
|
|
250
|
+
performance_list,
|
|
251
|
+
align='center',
|
|
252
|
+
alpha=0.4,
|
|
253
|
+
label='dropout_rate',
|
|
254
|
+
lw=1)
|
|
255
|
+
for rect in b:
|
|
256
|
+
w = rect.get_width()
|
|
257
|
+
ax.text(
|
|
258
|
+
w,
|
|
259
|
+
rect.get_y() + rect.get_height() / 2,
|
|
260
|
+
'%.4f' % w,
|
|
261
|
+
ha='left',
|
|
262
|
+
va='center')
|
|
263
|
+
plt.yticks(y_pos, embedding_dims)
|
|
264
|
+
plt.xlabel(feature_name)
|
|
265
|
+
plt.title('Dropout ratio')
|
|
266
|
+
img_path = os.path.join(output_dir, feature_name + '.png')
|
|
267
|
+
with tf.gfile.GFile(img_path, 'wb') as f:
|
|
268
|
+
plt.savefig(f, format='png')
|
|
269
|
+
|
|
270
|
+
def _visualize_feature_importance(self, feature_importance, group_name):
|
|
271
|
+
"""Draw feature importance histogram."""
|
|
272
|
+
df = pd.DataFrame(
|
|
273
|
+
columns=['feature_name', 'mean_drop_p'],
|
|
274
|
+
data=[list(kv) for kv in feature_importance.items()])
|
|
275
|
+
df['color'] = ['red' if x < 0.5 else 'green' for x in df['mean_drop_p']]
|
|
276
|
+
df.sort_values('mean_drop_p', inplace=True, ascending=False)
|
|
277
|
+
df.reset_index(inplace=True)
|
|
278
|
+
# Draw plot
|
|
279
|
+
plt.figure(figsize=(90, 200), dpi=100)
|
|
280
|
+
plt.hlines(y=df.index, xmin=0, xmax=df.mean_drop_p)
|
|
281
|
+
for x, y, tex in zip(df.mean_drop_p, df.index, df.mean_drop_p):
|
|
282
|
+
plt.text(
|
|
283
|
+
x,
|
|
284
|
+
y,
|
|
285
|
+
round(tex, 2),
|
|
286
|
+
horizontalalignment='right' if x < 0 else 'left',
|
|
287
|
+
verticalalignment='center',
|
|
288
|
+
fontdict={
|
|
289
|
+
'color': 'red' if x < 0 else 'green',
|
|
290
|
+
'size': 14
|
|
291
|
+
})
|
|
292
|
+
# Decorations
|
|
293
|
+
plt.yticks(df.index, df.feature_name, fontsize=20)
|
|
294
|
+
plt.title('Dropout Ratio', fontdict={'size': 30})
|
|
295
|
+
plt.grid(linestyle='--', alpha=0.5)
|
|
296
|
+
plt.xlim(0, 1)
|
|
297
|
+
with tf.gfile.GFile(
|
|
298
|
+
os.path.join(self._output_dir,
|
|
299
|
+
'feature_dropout_pic_%s.png' % group_name), 'wb') as f:
|
|
300
|
+
plt.savefig(f, format='png')
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
if __name__ == '__main__':
  # Drop command-line arguments the flag parser does not know about.
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  # Only the variational dropout selector is implemented; reject others early.
  if FLAGS.model_type != 'variational_dropout':
    raise ValueError('Unknown feature selection model type %s' %
                     FLAGS.model_type)
  selector = VariationalDropoutFS(
      FLAGS.config_path,
      FLAGS.output_dir,
      FLAGS.topk,
      checkpoint_path=FLAGS.checkpoint_path,
      fg_path=FLAGS.fg_path,
      visualize=FLAGS.visualize)
  selector.process()
|