easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
import copy
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from tensorflow.core.framework import graph_pb2
|
|
9
|
+
from tensorflow.python.framework import importer
|
|
10
|
+
from tensorflow.python.framework import ops
|
|
11
|
+
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
|
|
12
|
+
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
|
|
13
|
+
from tensorflow.python.saved_model import signature_constants
|
|
14
|
+
from tensorflow.python.tools import saved_model_utils
|
|
15
|
+
from tensorflow.python.training import saver as tf_saver
|
|
16
|
+
|
|
17
|
+
from easy_rec.python.utils import io_util
|
|
18
|
+
|
|
19
|
+
# Compare the major version numerically; comparing the version strings
# directly would order a release such as '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1

# get_variables_path is imported from a private TensorFlow module whose
# location differs between releases, so probe for it instead of deriving the
# location from the version number.
try:
    from tensorflow.python.saved_model.path_helpers import get_variables_path
except ImportError:
    from tensorflow.python.saved_model.utils_impl import get_variables_path
|
|
24
|
+
|
|
25
|
+
# Command-line flags of this tool; every flag is a string defaulting to ''.
FLAGS = tf.app.flags.FLAGS
# Directory that main() hands to search_pb() to locate the source saved model.
tf.app.flags.DEFINE_string('model_dir', '', '')
# Export directory passed to export() as part_dir for the 'user' part.
tf.app.flags.DEFINE_string('user_model_dir', '', '')
# Export directory passed to export() as part_dir for the 'item' part.
tf.app.flags.DEFINE_string('item_model_dir', '', '')
# When non-empty, export() copies this file to <part_dir>/assets/fg.json for
# the 'user' part.
tf.app.flags.DEFINE_string('user_fg_json_path', '', '')
# When non-empty, export() copies this file to <part_dir>/assets/fg.json for
# the 'item' part.
tf.app.flags.DEFINE_string('item_fg_json_path', '', '')

# Configure the root logger at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def search_pb(directory):
    """Locate the single directory under `directory` that holds a '.pb' file.

    Args:
      directory: root path that is walked with tf.gfile.Walk.

    Returns:
      The path of the only directory containing at least one '.pb' file.

    Raises:
      ValueError: if no directory, or more than one distinct directory,
        contains a '.pb' file.
    """
    dir_list = []
    for root, _, files in tf.gfile.Walk(directory):
        # Record each directory once, even when it holds several '.pb' files;
        # appending per file would report one saved model as multiple.
        if any(os.path.splitext(f)[1] == '.pb' for f in files):
            dir_list.append(root)

    if not dir_list:
        raise ValueError('savedmodel is not found in directory %s' % directory)
    if len(dir_list) > 1:
        raise ValueError('multiple saved model found in directory %s' % directory)

    return dir_list[0]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _node_name(name):
|
|
52
|
+
if name.startswith('^'):
|
|
53
|
+
return name[1:]
|
|
54
|
+
else:
|
|
55
|
+
return name.split(':')[0]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def extract_sub_graph(graph_def, dest_nodes, variable_protos):
    """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

    The walk starts at `dest_nodes` and follows node inputs backwards.  When a
    visited name is a key of `variable_protos`, the nodes named by that proto's
    initial_value_name, initializer_name and snapshot_name are walked as well
    and the proto's variable_name is recorded.

    Args:
      graph_def: graph_pb2.GraphDef
      dest_nodes: an iterable of output node names
      variable_protos: a dict mapping a variable node name to its proto.

    Returns:
      out: the GraphDef of the sub-graph, nodes kept in their original order.
      variables_to_keep: variables to be kept for saver.

    Raises:
      TypeError: if graph_def is not a graph_pb2.GraphDef.
      AssertionError: if a name in dest_nodes is not a node of graph_def.
    """
    # Imported locally so the block does not depend on a file-level import.
    from collections import deque

    if not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError('graph_def must be a graph_pb2.GraphDef proto.')

    edges = {}
    name_to_node_map = {}
    node_seq = {}
    for seq, node in enumerate(graph_def.node):
        n = _node_name(node.name)
        name_to_node_map[n] = node
        edges[n] = [_node_name(item) for item in node.input]
        node_seq[n] = seq

    for d in dest_nodes:
        assert d in name_to_node_map, "'%s' is not in graph" % d

    nodes_to_keep = set()
    variables_to_keep = set()
    visited = set()
    # A deque gives O(1) removal from the front; deleting index 0 of a list
    # shifts every remaining element on each step.
    next_to_visit = deque(dest_nodes)
    while next_to_visit:
        n = next_to_visit.popleft()
        # Handle every name once, so a variable's companion nodes are not
        # enqueued again each time another node lists it as an input.
        if n in visited:
            continue
        visited.add(n)

        if n in variable_protos:
            proto = variable_protos[n]
            next_to_visit.append(_node_name(proto.initial_value_name))
            next_to_visit.append(_node_name(proto.initializer_name))
            next_to_visit.append(_node_name(proto.snapshot_name))
            variables_to_keep.add(proto.variable_name)

        # Names that are not nodes of graph_def are skipped.
        if n in edges:
            nodes_to_keep.add(n)
            next_to_visit.extend(edges[n])

    # Emit the kept nodes in the order they appear in graph_def.
    nodes_to_keep_list = sorted(nodes_to_keep, key=node_seq.__getitem__)

    out = graph_pb2.GraphDef()
    out.node.extend(
        [copy.deepcopy(name_to_node_map[n]) for n in nodes_to_keep_list])
    out.library.CopyFrom(graph_def.library)
    out.versions.CopyFrom(graph_def.versions)

    return out, variables_to_keep
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def load_meta_graph_def(model_dir):
    """Read the serving MetaGraphDef of a saved model and index its contents.

    Args:
      model_dir: saved model directory.

    Returns:
      meta_graph_def: the MetaGraphDef loaded with the SERVING tag.
      variable_protos: dict from variable node name to the proto parsed from
        the variable collections; the first proto seen for a name is kept.
      input_tensor_names: dict from signature input name to tensor name,
        collected from every signature using the predict method name.
      output_tensor_names: dict from signature output name to tensor name,
        collected from the same signatures.
    """
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        model_dir, tf.saved_model.tag_constants.SERVING)

    # Collections: parse the serialized protos of the variable collections.
    variable_protos = {}
    for key, col_def in meta_graph_def.collection_def.items():
        if key not in ops.GraphKeys._VARIABLE_COLLECTIONS:
            continue
        tf.logging.info('[Collection] %s:' % key)
        for serialized in col_def.bytes_list.value:
            proto = ops.get_collection_proto_type(key)()
            proto.ParseFromString(serialized)
            tf.logging.info('%s' % proto.variable_name)
            variable_protos.setdefault(_node_name(proto.variable_name), proto)

    # Signatures: log and record the tensors of each predict signature.
    input_tensor_names = {}
    output_tensor_names = {}
    signature_map = meta_graph_def.signature_def
    for sig_key in signature_map:
        signature = signature_map[sig_key]
        if signature.method_name != tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
            continue

        tf.logging.info('[Signature] inputs:')
        for in_key in signature.inputs:
            in_info = signature.inputs[in_key]
            in_shape = [int(dim.size) for dim in in_info.tensor_shape.dim]
            tf.logging.info('"%s": %s; %s' %
                            (in_key, _TYPE_TO_STRING[in_info.dtype], in_shape))
            input_tensor_names[in_key] = in_info.name

        tf.logging.info('[Signature] outputs:')
        for out_key in signature.outputs:
            out_info = signature.outputs[out_key]
            out_shape = [int(dim.size) for dim in out_info.tensor_shape.dim]
            tf.logging.info('"%s": %s; %s' %
                            (out_key, _TYPE_TO_STRING[out_info.dtype], out_shape))
            output_tensor_names[out_key] = out_info.name

    return meta_graph_def, variable_protos, input_tensor_names, output_tensor_names
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
           output_tensor_names, part_name, part_dir):
    """Export subpart saved model.

    Keeps the signature outputs whose name contains `part_name`, extracts the
    sub-graph they depend on, restores its variables from `model_dir` and saves
    the result as a new saved model under `part_dir`.  The pipeline config is
    then copied into `part_dir`/assets, together with the fg json named by the
    matching flag when that flag is set.

    Args:
      model_dir: saved model directory.
      meta_graph_def: a MetaGraphDef.
      variable_protos: a dict of VariableDef.
      input_tensor_names: signature inputs in saved model.
      output_tensor_names: signature outputs in saved model.
      part_name: subpart model name, user or item.
      part_dir: subpart model export directory.

    Raises:
      ValueError: if no signature output name contains `part_name`.
      AssertionError: if `model_dir`/assets/pipeline.config does not exist.
    """
    output_tensor_names = {
        name: tensor_name
        for name, tensor_name in output_tensor_names.items()
        if part_name in name
    }
    # Fail early with a clear message instead of building an empty graph.
    if not output_tensor_names:
        raise ValueError(
            "no signature output contains '%s' in its name" % part_name)
    output_node_names = [
        _node_name(tensor_name) for tensor_name in output_tensor_names.values()
    ]

    inference_graph, variables_to_keep = extract_sub_graph(
        meta_graph_def.graph_def, output_node_names, variable_protos)

    tf.reset_default_graph()
    with tf.Session() as sess:
        with sess.graph.as_default():
            graph = ops.get_default_graph()
            importer.import_graph_def(inference_graph, name='')
            # Register the kept variables so the default Saver restores them.
            for name in variables_to_keep:
                variable = _from_proto_fn(variable_protos[name.split(':')[0]])
                graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
            saver = tf_saver.Saver()
            saver.restore(sess, get_variables_path(model_dir))

            builder = tf.saved_model.builder.SavedModelBuilder(part_dir)
            signature_inputs = {}
            for input_name, tensor_name in input_tensor_names.items():
                # Inputs whose tensor is not in the extracted sub-graph are
                # left out of this part's signature; other errors propagate.
                try:
                    input_tensor = graph.get_tensor_by_name(tensor_name)
                except (KeyError, ValueError):
                    tf.logging.info('ignore input: %s' % input_name)
                    continue
                signature_inputs[input_name] = (
                    tf.saved_model.utils.build_tensor_info(input_tensor))

            signature_outputs = {}
            for output_name, tensor_name in output_tensor_names.items():
                signature_outputs[output_name] = (
                    tf.saved_model.utils.build_tensor_info(
                        graph.get_tensor_by_name(tensor_name)))

            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=signature_inputs,
                    outputs=signature_outputs,
                    method_name=tf.saved_model.signature_constants
                    .PREDICT_METHOD_NAME))

            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        prediction_signature,
                })
            builder.save()

    config_path = os.path.join(model_dir, 'assets/pipeline.config')
    assert tf.gfile.Exists(config_path), '%s does not exist' % config_path
    dst_path = os.path.join(part_dir, 'assets')
    dst_config_path = os.path.join(dst_path, 'pipeline.config')
    # MakeDirs also succeeds when the assets directory already exists.
    tf.gfile.MakeDirs(dst_path)
    tf.gfile.Copy(config_path, dst_config_path)
    if part_name == 'user' and FLAGS.user_fg_json_path:
        dst_fg_path = os.path.join(dst_path, 'fg.json')
        tf.gfile.Copy(FLAGS.user_fg_json_path, dst_fg_path)
    if part_name == 'item' and FLAGS.item_fg_json_path:
        dst_fg_path = os.path.join(dst_path, 'fg.json')
        tf.gfile.Copy(FLAGS.item_fg_json_path, dst_fg_path)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def main(argv):
    """Split the saved model under FLAGS.model_dir into user and item parts.

    The meta graph is loaded once, then export() is called for the 'user' part
    (into FLAGS.user_model_dir) and for the 'item' part (into
    FLAGS.item_model_dir).

    Args:
      argv: command-line arguments supplied by tf.app.run; not used here.
    """
    saved_model_dir = search_pb(FLAGS.model_dir)

    tf.logging.info('Loading meta graph...')
    loaded = load_meta_graph_def(saved_model_dir)
    meta_graph_def, variable_protos, input_tensor_names, output_tensor_names = loaded

    # Positional arguments common to both export calls.
    shared_args = (saved_model_dir, meta_graph_def, variable_protos,
                   input_tensor_names, output_tensor_names)

    tf.logging.info('Exporting user part model...')
    export(*shared_args, part_name='user', part_dir=FLAGS.user_model_dir)

    tf.logging.info('Exporting item part model...')
    export(*shared_args, part_name='item', part_dir=FLAGS.item_model_dir)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# Script entry point.
if __name__ == '__main__':
    # NOTE(review): presumably drops command-line arguments that are not
    # defined in FLAGS before they are parsed; confirm against
    # io_util.filter_unknown_args.
    sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
    tf.app.run()
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
import copy
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from tensorflow.core.framework import graph_pb2
|
|
9
|
+
from tensorflow.python.framework import importer
|
|
10
|
+
from tensorflow.python.framework import ops
|
|
11
|
+
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
|
|
12
|
+
from tensorflow.python.saved_model import signature_constants
|
|
13
|
+
from tensorflow.python.saved_model.utils_impl import get_variables_path
|
|
14
|
+
from tensorflow.python.tools import saved_model_utils
|
|
15
|
+
from tensorflow.python.training import saver as tf_saver
|
|
16
|
+
|
|
17
|
+
from easy_rec.python.utils import io_util
|
|
18
|
+
|
|
19
|
+
# Command-line flags of this tool; every flag is a string defaulting to ''.
FLAGS = tf.app.flags.FLAGS
# NOTE(review): presumably the directory of the source saved model; the code
# that reads this flag is outside the visible part of the file - confirm.
tf.app.flags.DEFINE_string('model_dir', '', '')
# NOTE(review): presumably the export directories of the two sub-models
# ("trigger" and "sim"); their use is outside the visible part - confirm.
tf.app.flags.DEFINE_string('trigger_model_dir', '', '')
tf.app.flags.DEFINE_string('sim_model_dir', '', '')

# Configure the root logger at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def search_pb(directory):
    """Locate the single directory under `directory` that holds a '.pb' file.

    Args:
      directory: root path that is walked with tf.gfile.Walk.

    Returns:
      The path of the only directory containing at least one '.pb' file.

    Raises:
      ValueError: if no directory, or more than one distinct directory,
        contains a '.pb' file.
    """
    dir_list = []
    for root, _, files in tf.gfile.Walk(directory):
        # Record each directory once, even when it holds several '.pb' files;
        # appending per file would report one saved model as multiple.
        if any(os.path.splitext(f)[1] == '.pb' for f in files):
            dir_list.append(root)

    if not dir_list:
        raise ValueError('savedmodel is not found in directory %s' % directory)
    if len(dir_list) > 1:
        raise ValueError('multiple saved model found in directory %s' % directory)

    return dir_list[0]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _node_name(name):
|
|
44
|
+
if name.startswith('^'):
|
|
45
|
+
return name[1:]
|
|
46
|
+
else:
|
|
47
|
+
return name.split(':')[0]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def extract_sub_graph(graph_def, dest_nodes, variable_protos):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Besides following node inputs, every visited node that appears in
  `variable_protos` also pulls in the nodes named by its proto's
  `initial_value_name`, `initializer_name` and `snapshot_name`.

  Args:
    graph_def: graph_pb2.GraphDef
    dest_nodes: a list includes output node names
    variable_protos: a dict mapping a variable's node name to its proto; the
      proto must expose `variable_name`, `initial_value_name`,
      `initializer_name` and `snapshot_name`.

  Returns:
    out: the GraphDef of the sub-graph, nodes kept in their original order.
    variables_to_keep: set of `variable_name` values of the variables reached,
      to be kept for saver.

  Raises:
    TypeError: if `graph_def` is not a graph_pb2.GraphDef.
    AssertionError: if a name in `dest_nodes` is not a node of `graph_def`.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError('graph_def must be a graph_pb2.GraphDef proto.')

  edges = {}
  name_to_node_map = {}
  node_seq = {}
  for seq, node in enumerate(graph_def.node):
    n = _node_name(node.name)
    name_to_node_map[n] = node
    edges[n] = [_node_name(item) for item in node.input]
    node_seq[n] = seq

  for d in dest_nodes:
    assert d in name_to_node_map, "'%s' is not in graph" % d

  nodes_to_keep = set()
  variables_to_keep = set()

  # Reachability search. Only the final sets matter, so the work list is used
  # as a stack (O(1) pop from the end) rather than popping from the front.
  pending = list(dest_nodes)
  while pending:
    n = pending.pop()
    # Skip already-kept nodes first so a variable is expanded only once.
    if n in nodes_to_keep:
      continue

    if n in variable_protos:
      proto = variable_protos[n]
      pending.append(_node_name(proto.initial_value_name))
      pending.append(_node_name(proto.initializer_name))
      pending.append(_node_name(proto.snapshot_name))
      variables_to_keep.add(proto.variable_name)

    # Names that are not nodes of this graph (e.g. an empty proto field) are
    # silently ignored.
    if n in edges:
      nodes_to_keep.add(n)
      pending.extend(edges[n])

  out = graph_pb2.GraphDef()
  # Emit the kept nodes in the order they appeared in the input graph.
  for n in sorted(nodes_to_keep, key=lambda name: node_seq[name]):
    out.node.extend([copy.deepcopy(name_to_node_map[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out, variables_to_keep
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def load_meta_graph_def(model_dir):
  """Read the serving MetaGraphDef of a saved model and summarize it.

  Variable protos are parsed out of the variable collections (first proto seen
  per variable node name wins), and the tensor names of every signature whose
  method is PREDICT are gathered. Everything found is logged via tf.logging.

  Args:
    model_dir: saved model directory.

  Returns:
    meta_graph_def: a MetaGraphDef.
    variable_protos: a dict of VariableDef, keyed by variable node name.
    input_tensor_names: signature inputs in saved model (key -> tensor name).
    output_tensor_names: signature outputs in saved model (key -> tensor name).
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(
      model_dir, tf.saved_model.tag_constants.SERVING)

  # parse collection_def in SavedModel
  variable_protos = {}
  for key, col_def in meta_graph_def.collection_def.items():
    if key not in ops.GraphKeys._VARIABLE_COLLECTIONS:
      continue
    tf.logging.info('[Collection] %s:' % key)
    for serialized in col_def.bytes_list.value:
      proto = ops.get_collection_proto_type(key)()
      proto.ParseFromString(serialized)
      tf.logging.info('%s' % proto.variable_name)
      # Keep the first proto encountered for each variable node.
      variable_protos.setdefault(_node_name(proto.variable_name), proto)

  def _collect(title, tensor_infos, sink):
    """Log each tensor's key, dtype and shape; store key -> tensor name."""
    tf.logging.info(title)
    for key in tensor_infos:
      info = tensor_infos[key]
      shape = [int(dim.size) for dim in info.tensor_shape.dim]
      tf.logging.info('"%s": %s; %s' %
                      (key, _TYPE_TO_STRING[info.dtype], shape))
      sink[key] = info.name

  # parse signature info for SavedModel
  input_tensor_names = {}
  output_tensor_names = {}
  signatures = meta_graph_def.signature_def
  for sig_name in signatures:
    signature = signatures[sig_name]
    if signature.method_name != tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
      continue
    _collect('[Signature] inputs:', signature.inputs, input_tensor_names)
    _collect('[Signature] outputs:', signature.outputs, output_tensor_names)

  return meta_graph_def, variable_protos, input_tensor_names, output_tensor_names
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
           output_tensor_names, part_name, part_dir):
  """Export one sub-part of a saved model as its own saved model.

  Only the signature outputs whose key contains `part_name` are kept; the
  graph is pruned to what those outputs need, variables are restored from the
  original model, and the result is saved with a PREDICT signature under the
  default serving key. The original `assets/pipeline.config` is copied along.

  Args:
    model_dir: saved model directory.
    meta_graph_def: a MetaGraphDef.
    variable_protos: a dict of VariableDef.
    input_tensor_names: signature inputs in saved model.
    output_tensor_names: signature outputs in saved model.
    part_name: substring selecting which signature outputs belong to the part.
    part_dir: subpart model export directory.
  """
  part_outputs = {
      key: tensor_name
      for key, tensor_name in output_tensor_names.items()
      if part_name in key
  }
  output_node_names = [_node_name(t) for t in part_outputs.values()]

  inference_graph, variables_to_keep = extract_sub_graph(
      meta_graph_def.graph_def, output_node_names, variable_protos)

  tf.reset_default_graph()
  with tf.Session() as sess, sess.graph.as_default():
    graph = ops.get_default_graph()
    importer.import_graph_def(inference_graph, name='')

    # Register the kept variables so the Saver below picks them up.
    for var_name in variables_to_keep:
      graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                              graph.get_tensor_by_name(var_name))
    tf_saver.Saver().restore(sess, get_variables_path(model_dir))

    builder = tf.saved_model.builder.SavedModelBuilder(part_dir)

    # Inputs that did not survive pruning cannot be looked up; they are
    # reported and left out of the signature.
    signature_inputs = {}
    for input_name, tensor_name in input_tensor_names.items():
      try:
        signature_inputs[input_name] = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(tensor_name))
      except Exception:
        print('ignore input: %s' % input_name)

    signature_outputs = {
        output_name: tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(tensor_name))
        for output_name, tensor_name in part_outputs.items()
    }

    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=signature_inputs,
            outputs=signature_outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        ))

    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={signature_key: prediction_signature})
    builder.save()

  # Carry the pipeline config over into the exported model's assets.
  config_path = os.path.join(model_dir, 'assets/pipeline.config')
  assert tf.gfile.Exists(config_path)
  assets_dir = os.path.join(part_dir, 'assets')
  tf.gfile.MkDir(assets_dir)
  tf.gfile.Copy(config_path, os.path.join(assets_dir, 'pipeline.config'))
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def main(argv):
  """Split the saved model under FLAGS.model_dir into trigger and sim parts.

  Args:
    argv: command-line arguments passed in by tf.app.run(); unused.
  """
  model_dir = search_pb(FLAGS.model_dir)
  tf.logging.info('Loading meta graph...')
  loaded = load_meta_graph_def(model_dir)
  meta_graph_def, variable_protos, input_tensor_names, output_tensor_names = loaded

  # (log message, output-key substring, destination directory) per sub-model,
  # exported in this order.
  parts = (
      ('Exporting trigger part model...', 'trigger_out',
       FLAGS.trigger_model_dir),
      ('Exporting sim part model...', 'sim_out', FLAGS.sim_model_dir),
  )
  for message, part_name, part_dir in parts:
    tf.logging.info(message)
    export(
        model_dir,
        meta_graph_def,
        variable_protos,
        input_tensor_names,
        output_tensor_names,
        part_name=part_name,
        part_dir=part_dir)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
if __name__ == '__main__':
  # Pass sys.argv through io_util.filter_unknown_args before flag parsing
  # (presumably this drops arguments FLAGS does not define -- confirm against
  # io_util), then let tf.app.run() parse the flags and invoke main().
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  tf.app.run()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import tensorflow as tf
|
|
10
|
+
|
|
11
|
+
import easy_rec
|
|
12
|
+
from easy_rec.python.inference.predictor import Predictor
|
|
13
|
+
|
|
14
|
+
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# Load the kv_lookup custom-op shared library shipped in easy_rec.ops_dir.
# This happens at import time, before any saved model is loaded.
lookup_op_path = os.path.join(easy_rec.ops_dir, 'libkv_lookup.so')
lookup_op = tf.load_op_library(lookup_op_path)

if __name__ == '__main__':
  # Test saved model, an example:
  #
  #   python -m easy_rec.python.tools.test_saved_model
  #     --saved_model_dir after_edit_save
  #     --input_path data/test/rtp/xys_cxr_fg_sample_test2_with_lbl.txt
  #     --with_lbl

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model dir')
  parser.add_argument(
      '--input_path', type=str, default=None, help='input data path')
  parser.add_argument('--save_path', type=str, default=None, help='save path')
  parser.add_argument('--separator', type=str, default=',', help='separator')
  parser.add_argument(
      '--cmp_res_path', type=str, default=None, help='compare result path')
  parser.add_argument(
      '--cmp_key', type=str, default='probs', help='compare key')
  parser.add_argument('--tol', type=float, default=1e-5, help='tolerance')
  parser.add_argument(
      '--with_lbl',
      action='store_true',
      default=False,
      help='whether the test data has label field')
  args = parser.parse_args()

  # Fail early with a usage message instead of an opaque error further down.
  if args.saved_model_dir is None or args.input_path is None:
    parser.error('--saved_model_dir and --input_path are required')

  logging.info('saved_model_dir: %s' % args.saved_model_dir)
  logging.info('test_data_path: %s' % args.input_path)
  logging.info('test_data has lbl: %s' % args.with_lbl)

  predictor = Predictor(args.saved_model_dir)
  with open(args.input_path, 'r') as fin:
    feature_vals = []
    for line_str in fin:
      line_toks = line_str.strip().split(args.separator)
      if args.with_lbl:
        # The label is the first field; the predictor only gets the features.
        line_toks = line_toks[1:]
      feature_vals.append(args.separator.join(line_toks))
  output = predictor.predict(feature_vals, batch_size=4096)

  if args.save_path:
    with open(args.save_path, 'w') as fout:
      for one in output:
        fout.write(str(one) + '\n')

  if args.cmp_res_path:
    logging.info('compare result path: ' + args.cmp_res_path)
    logging.info('compare key: ' + args.cmp_key)
    logging.info('tolerance: ' + str(args.tol))
    # Each line of the reference file is a JSON object; line i is compared
    # with prediction i on the value stored under cmp_key.
    with open(args.cmp_res_path, 'r') as fin:
      for line_id, line_str in enumerate(fin):
        line_pred = json.loads(line_str.strip())
        diff = np.abs(line_pred[args.cmp_key] - output[line_id][args.cmp_key])
        assert diff < args.tol, 'line[%d]: %.8f' % (line_id, diff)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from google.protobuf import text_format
|
|
7
|
+
from tensorflow.core.protobuf import saved_model_pb2
|
|
8
|
+
from tensorflow.python.platform.gfile import GFile
|
|
9
|
+
|
|
10
|
+
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

if __name__ == '__main__':
  # Convert a SavedModel proto between binary and text form. The input is
  # parsed as binary when its path ends with '.pb' and as text otherwise; the
  # output is written as text when its path ends with '.pbtxt' and as binary
  # otherwise.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input', type=str, default=None, help='saved model path')
  parser.add_argument(
      '--output', type=str, default=None, help='saved model save path')
  args = parser.parse_args()

  # Validate explicitly: an `assert` would be stripped under `python -O`.
  if args.input is None or args.output is None:
    parser.error('--input and --output are required')

  logging.info('saved_model_path: %s' % args.input)

  saved_model = saved_model_pb2.SavedModel()
  if args.input.endswith('.pb'):
    with GFile(args.input, 'rb') as fin:
      saved_model.ParseFromString(fin.read())
  else:
    with GFile(args.input, 'r') as fin:
      text_format.Merge(fin.read(), saved_model)

  if args.output.endswith('.pbtxt'):
    with GFile(args.output, 'w') as fout:
      fout.write(text_format.MessageToString(saved_model, as_utf8=True))
  else:
    with GFile(args.output, 'wb') as fout:
      fout.write(saved_model.SerializeToString())
|