easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/main.py
ADDED
|
@@ -0,0 +1,878 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
from __future__ import absolute_import
|
|
5
|
+
from __future__ import division
|
|
6
|
+
from __future__ import print_function
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import math
|
|
11
|
+
import os
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
import six
|
|
15
|
+
import tensorflow as tf
|
|
16
|
+
from tensorflow.core.protobuf import saved_model_pb2
|
|
17
|
+
|
|
18
|
+
import easy_rec
|
|
19
|
+
from easy_rec.python.builders import strategy_builder
|
|
20
|
+
from easy_rec.python.compat import estimator_train
|
|
21
|
+
from easy_rec.python.compat import exporter
|
|
22
|
+
from easy_rec.python.input.input import Input
|
|
23
|
+
from easy_rec.python.model.easy_rec_estimator import EasyRecEstimator
|
|
24
|
+
from easy_rec.python.model.easy_rec_model import EasyRecModel
|
|
25
|
+
from easy_rec.python.protos.train_pb2 import DistributionStrategy
|
|
26
|
+
from easy_rec.python.utils import config_util
|
|
27
|
+
from easy_rec.python.utils import constant
|
|
28
|
+
from easy_rec.python.utils import estimator_utils
|
|
29
|
+
from easy_rec.python.utils import fg_util
|
|
30
|
+
from easy_rec.python.utils import load_class
|
|
31
|
+
from easy_rec.python.utils.config_util import get_eval_input_path
|
|
32
|
+
from easy_rec.python.utils.config_util import get_model_dir_path
|
|
33
|
+
from easy_rec.python.utils.config_util import get_train_input_path
|
|
34
|
+
from easy_rec.python.utils.config_util import set_eval_input_path
|
|
35
|
+
from easy_rec.python.utils.export_big_model import export_big_model
|
|
36
|
+
from easy_rec.python.utils.export_big_model import export_big_model_to_oss
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
import horovod.tensorflow as hvd
|
|
40
|
+
except Exception:
|
|
41
|
+
hvd = None
|
|
42
|
+
|
|
43
|
+
# Pick the ConfigProto / GPUOptions classes matching the installed TensorFlow.
# The major version is compared numerically: comparing the raw version
# strings is lexicographic, so a major version of 10 or above would sort
# below '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  from tensorflow.core.protobuf import config_pb2

  ConfigProto = config_pb2.ConfigProto
  GPUOptions = config_pb2.GPUOptions

  # the code below is written against the TF1-style API (e.g. tf.gfile)
  tf = tf.compat.v1
else:
  GPUOptions = tf.GPUOptions
  ConfigProto = tf.ConfigProto
|
|
53
|
+
|
|
54
|
+
load_class.auto_import()
|
|
55
|
+
|
|
56
|
+
# when the tensorflow version is > 1.8, setting strip_default_attrs to true will cause
|
|
57
|
+
# saved_model inference to core dump, such as:
|
|
58
|
+
# [libprotobuf FATAL external/protobuf_archive/src/google/protobuf/map.h:1058]
|
|
59
|
+
# CHECK failed: it != end(): key not found: new_axis_mask
|
|
60
|
+
# so temporarily modify strip_default_attrs of _SavedModelExporter in
|
|
61
|
+
# tf.estimator.exporter to false by default
|
|
62
|
+
FinalExporter = exporter.FinalExporter
|
|
63
|
+
LatestExporter = exporter.LatestExporter
|
|
64
|
+
BestExporter = exporter.BestExporter
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_input_fn(data_config,
                  feature_configs,
                  data_path=None,
                  export_config=None,
                  check_mode=False,
                  **kwargs):
  """Create the input function consumed by the estimator.

  Args:
    data_config: dataset config; its input_type selects the Input class
    feature_configs: FeatureConfig
    data_path: input data path
    export_config: configuration for exporting models,
      only used to build input_fn when exporting models
    check_mode: forwarded unchanged to the Input constructor
    **kwargs: extra keyword arguments forwarded to the Input constructor

  Returns:
    the input_fn returned by create_input of the selected Input class
  """
  # map each InputType enum value back to its name, which is the key
  # used to look up the Input class
  type_value_to_name = dict(
      (value, name) for name, value in data_config.InputType.items())
  input_class = Input.create_class(type_value_to_name[data_config.input_type])

  worker_index, worker_count = estimator_utils.get_task_index_and_num()
  reader = input_class(
      data_config,
      feature_configs,
      data_path,
      task_index=worker_index,
      task_num=worker_count,
      check_mode=check_mode,
      **kwargs)
  return reader.create_input(export_config)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _create_estimator(pipeline_config, distribution=None, params=None):
  """Build an EasyRecEstimator together with its RunConfig.

  Args:
    pipeline_config: pipeline config; model_config, train_config and
      model_dir are read from it
    distribution: distribution strategy, passed to RunConfig as both
      train_distribute and eval_distribute
    params: optional dict of estimator params; 'log_device_placement' is
      read here and the dict is passed on to EasyRecEstimator.
      None (the default) means an empty dict.

  Returns:
    a tuple of (estimator, run_config)
  """
  # create a fresh dict per call: a `{}` default is evaluated once and would
  # be shared by every call that omits params
  if params is None:
    params = {}

  model_config = pipeline_config.model_config
  train_config = pipeline_config.train_config
  gpu_options = GPUOptions(allow_growth=True)

  logging.info(
      'train_config.train_distribute=%s[value=%d]' %
      (DistributionStrategy.Name(train_config.train_distribute),
       train_config.train_distribute))

  # set gpu options only under hvd scenes
  if hvd is not None and train_config.train_distribute in [
      DistributionStrategy.EmbeddingParallelStrategy,
      DistributionStrategy.SokStrategy, DistributionStrategy.HorovodStrategy
  ]:
    local_rnk = hvd.local_rank()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    logging.info('local_rnk=%d num_gpus=%d' % (local_rnk, len(gpus)))
    if len(gpus) > 0:
      # restrict this process to the gpu indexed by its horovod local rank
      tf.config.experimental.set_visible_devices(gpus[local_rnk], 'GPU')
      gpu_options.visible_device_list = str(local_rnk)

  session_config = ConfigProto(
      gpu_options=gpu_options,
      allow_soft_placement=True,
      log_device_placement=params.get('log_device_placement', False),
      inter_op_parallelism_threads=train_config.inter_op_parallelism_threads,
      intra_op_parallelism_threads=train_config.intra_op_parallelism_threads)

  if constant.NO_ARITHMETRIC_OPTI in os.environ:
    logging.info('arithmetic_optimization is closed to improve performance')
    session_config.graph_options.rewrite_options.arithmetic_optimization = \
        session_config.graph_options.rewrite_options.OFF

  session_config.device_filters.append('/job:ps')
  model_cls = EasyRecModel.create_class(model_config.model_class)

  save_checkpoints_steps = None
  save_checkpoints_secs = None
  if train_config.HasField('save_checkpoints_steps'):
    save_checkpoints_steps = train_config.save_checkpoints_steps
  if train_config.HasField('save_checkpoints_secs'):
    save_checkpoints_secs = train_config.save_checkpoints_secs
  # if both `save_checkpoints_steps` and `save_checkpoints_secs` are not set,
  # use the default value of save_checkpoints_steps
  if save_checkpoints_steps is None and save_checkpoints_secs is None:
    save_checkpoints_steps = train_config.save_checkpoints_steps

  run_config = tf.estimator.RunConfig(
      model_dir=pipeline_config.model_dir,
      log_step_count_steps=None,  # train_config.log_step_count_steps,
      save_summary_steps=train_config.save_summary_steps,
      save_checkpoints_steps=save_checkpoints_steps,
      save_checkpoints_secs=save_checkpoints_secs,
      keep_checkpoint_max=train_config.keep_checkpoint_max,
      train_distribute=distribution,
      eval_distribute=distribution,
      session_config=session_config)

  estimator = EasyRecEstimator(
      pipeline_config, model_cls, run_config=run_config, params=params)
  return estimator, run_config
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _create_eval_export_spec(pipeline_config, eval_data, check_mode=False):
  """Assemble the EvalSpec: eval input function, step count and exporters.

  Args:
    pipeline_config: pipeline config providing data_config, eval_config,
      export_config and the feature configs
    eval_data: eval input path handed to the eval input function
    check_mode: forwarded to the input function built for exporting

  Returns:
    a tf.estimator.EvalSpec named 'val'

  Raises:
    ValueError: if export_config.exporter_type is not one of
      'final', 'latest', 'best' or 'none'
  """
  data_config = pipeline_config.data_config
  eval_config = pipeline_config.eval_config
  export_config = pipeline_config.export_config
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

  # when num_examples is given, evaluate ceil(num_examples / batch_size)
  # steps; otherwise leave the step count unset
  eval_steps = None
  if eval_config.num_examples > 0:
    eval_steps = int(
        math.ceil(float(eval_config.num_examples) / data_config.batch_size))
    logging.info('eval_steps = %d' % eval_steps)

  extra_input_args = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    extra_input_args['fg_json_path'] = pipeline_config.fg_json_path

  # input function used by the exporters as serving_input_receiver_fn
  serving_input_fn = _get_input_fn(
      data_config,
      feature_configs,
      None,
      export_config,
      check_mode=check_mode,
      **extra_input_args)

  exporter_type = export_config.exporter_type
  if exporter_type == 'none':
    exporters = []
  elif exporter_type == 'final':
    exporters = [
        FinalExporter(name='final', serving_input_receiver_fn=serving_input_fn)
    ]
  elif exporter_type == 'latest':
    exporters = [
        LatestExporter(
            name='latest',
            serving_input_receiver_fn=serving_input_fn,
            exports_to_keep=export_config.exports_to_keep)
    ]
  elif exporter_type == 'best':
    logging.info(
        'will use BestExporter, metric is %s, the bigger the better: %d' %
        (export_config.best_exporter_metric, export_config.metric_bigger))

    def _metric_cmp_fn(best_eval_result, current_eval_result):
      # returns True when the current result beats the best one so far
      logging.info('metric: best = %s current = %s' %
                   (str(best_eval_result), str(current_eval_result)))
      metric_name = export_config.best_exporter_metric
      best_value = best_eval_result[metric_name]
      current_value = current_eval_result[metric_name]
      if export_config.metric_bigger:
        return best_value < current_value
      return best_value > current_value

    exporters = [
        BestExporter(
            name='best',
            serving_input_receiver_fn=serving_input_fn,
            compare_fn=_metric_cmp_fn,
            exports_to_keep=export_config.exports_to_keep)
    ]
  else:
    raise ValueError('Unknown exporter type %s' % export_config.exporter_type)

  eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                **extra_input_args)
  # set throttle_secs to a small number, so that we can control evaluation
  # interval steps by checkpoint saving steps
  return tf.estimator.EvalSpec(
      name='val',
      input_fn=eval_input_fn,
      steps=eval_steps,
      throttle_secs=10,
      exporters=exporters)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _check_model_dir(model_dir, continue_train):
  """Make sure model_dir exists and is usable for this run.

  Args:
    model_dir: model directory; created when it is not a directory yet
    continue_train: when False, an existing model_dir must not contain
      any model.ckpt-*.meta file

  Raises:
    AssertionError: if continue_train is False and model_dir already
      contains checkpoint meta files
  """
  if tf.gfile.IsDirectory(model_dir):
    if not continue_train:
      # a fresh run must not start on top of existing checkpoints
      assert len(tf.gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
          'model_dir[=%s] already exists and not empty(if you ' \
          'want to continue train on current model_dir please ' \
          'delete dir %s or specify --continue_train[internal use only])' % (
              model_dir, model_dir)
    return

  if continue_train:
    logging.info('%s does not exists, create it automatically' % model_dir)
  tf.gfile.MakeDirs(model_dir)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _get_ckpt_path(pipeline_config, checkpoint_path):
  """Resolve the checkpoint path to use.

  Args:
    pipeline_config: pipeline config; its model_dir is used when
      checkpoint_path is not given
    checkpoint_path: a checkpoint path, or a directory to search with
      estimator_utils.latest_checkpoint; '' or None means not given

  Returns:
    checkpoint_path itself when it is given and is not a directory;
    otherwise the result of estimator_utils.latest_checkpoint on the
    given directory or on pipeline_config.model_dir

  Raises:
    AssertionError: if checkpoint_path is not given and
      pipeline_config.model_dir is not a directory
  """
  if checkpoint_path is not None and checkpoint_path != '':
    if tf.gfile.IsDirectory(checkpoint_path):
      return estimator_utils.latest_checkpoint(checkpoint_path)
    return checkpoint_path

  model_dir = pipeline_config.model_dir
  if not tf.gfile.IsDirectory(model_dir):
    # raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would turn this into an
    # UnboundLocalError on the return below
    raise AssertionError('pipeline_config.model_dir(%s) does not exist' %
                         model_dir)

  ckpt_path = estimator_utils.latest_checkpoint(model_dir)
  logging.info('checkpoint_path is not specified, '
               'will use latest checkpoint %s from %s' %
               (ckpt_path, model_dir))
  return ckpt_path
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def train_and_evaluate(pipeline_config_path, continue_train=False):
  """Train and evaluate the EasyRec model described by a pipeline config file.

  Loads the pipeline config from pipeline_config_path and hands it to
  _train_and_evaluate_impl.

  Args:
    pipeline_config_path: path to the pipeline config file; it must exist
    continue_train: whether to restart train from an existing
      checkpoint

  Returns:
    the pipeline config loaded from pipeline_config_path

  Raises:
    AssertionError: if pipeline_config_path does not exist
  """
  config_exists = tf.gfile.Exists(pipeline_config_path)
  assert config_exists, 'pipeline_config_path not exists'

  pipeline_config = config_util.get_configs_from_pipeline_file(
      pipeline_config_path)
  _train_and_evaluate_impl(pipeline_config, continue_train)
  return pipeline_config
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _train_and_evaluate_impl(pipeline_config,
                             continue_train=False,
                             check_mode=False,
                             fit_on_eval=False,
                             fit_on_eval_steps=None):
  """Build the estimator from pipeline_config, then train (and evaluate) it.

  Args:
    pipeline_config: pipeline config object; train_config, data_config,
      model_dir and fg_json_path are read, and
      train_config.sync_replicas may be reset to False.
    continue_train: forwarded to _check_model_dir on the chief.
    check_mode: forwarded to _get_input_fn and _create_eval_export_spec.
    fit_on_eval: if True (and this task is not the evaluator), continue
      training on the eval data after the main training finished.
    fit_on_eval_steps: optional number of extra steps for the fit-on-eval
      phase; it is added to the number of already trained steps.

  Returns:
    the estimator, so that callers can run custom training with
    estimator.train.

  Raises:
    AssertionError: if neither train_config.num_steps nor
      data_config.num_epochs is set to a value > 0.
  """
  train_config = pipeline_config.train_config
  data_config = pipeline_config.data_config
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

  if train_config.train_distribute != DistributionStrategy.NoStrategy\
      and train_config.sync_replicas:
    logging.warning(
        'will set sync_replicas to False, because train_distribute[%s] != NoStrategy'
        % pipeline_config.train_config.train_distribute)
    pipeline_config.train_config.sync_replicas = False

  train_data = get_train_input_path(pipeline_config)
  eval_data = get_eval_input_path(pipeline_config)

  distribution = strategy_builder.build(train_config)
  params = {}
  if train_config.is_profiling:
    params['log_device_placement'] = True
  estimator, run_config = _create_estimator(
      pipeline_config, distribution=distribution, params=params)

  # Only the chief prepares model_dir and records config + package version.
  version_file = os.path.join(pipeline_config.model_dir, 'version')
  if estimator_utils.is_chief():
    _check_model_dir(pipeline_config.model_dir, continue_train)
    config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
    with tf.gfile.GFile(version_file, 'w') as f:
      f.write(easy_rec.__version__ + '\n')

  train_steps = None
  if train_config.HasField('num_steps') and train_config.num_steps > 0:
    train_steps = train_config.num_steps
  # Explicit raise keeps the validation alive under `python -O`.
  if train_steps is None and not data_config.num_epochs > 0:
    raise AssertionError(
        'either num_steps and num_epochs must be set to an integer > 0.')

  if train_steps and data_config.num_epochs:
    logging.info('Both num_steps and num_epochs are set.')
    is_sync = train_config.sync_replicas
    batch_size = data_config.batch_size
    epoch_str = 'sample_num * %d / %d' % (data_config.num_epochs, batch_size)
    if is_sync:
      _, worker_num = estimator_utils.get_task_index_and_num()
      epoch_str += ' / ' + str(worker_num)
    logging.info('Will train min(%d, %s) steps...' % (train_steps, epoch_str))

  input_fn_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path

  # create train input
  train_input_fn = _get_input_fn(
      data_config,
      feature_configs,
      train_data,
      check_mode=check_mode,
      **input_fn_kwargs)
  # Currently only a single Eval Spec is allowed.
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)
  # Not every TrainSpec implementation exposes saving_listeners, so read it
  # defensively once and use the same value in both training paths below.
  saving_listeners = getattr(train_spec, 'saving_listeners', None)

  embedding_parallel = train_config.train_distribute in (
      DistributionStrategy.SokStrategy,
      DistributionStrategy.EmbeddingParallelStrategy)

  if embedding_parallel:
    # train only, no evaluation is run in this mode
    estimator.train(
        input_fn=train_input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=saving_listeners)
    train_input_fn.input_creator.stop()
  else:
    # create eval spec
    eval_spec = _create_eval_export_spec(
        pipeline_config, eval_data, check_mode=check_mode)
    estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
  logging.info('Train and evaluate finish')

  if fit_on_eval and (not estimator_utils.is_evaluator()):
    tf.reset_default_graph()
    logging.info('Start continue training on eval data')
    eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                  **input_fn_kwargs)
    if fit_on_eval_steps is not None:
      # wait estimator train done to get the correct train_steps
      while not estimator_train.estimator_train_done(estimator):
        time.sleep(1)
      train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
      logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
                   (train_steps, fit_on_eval_steps))
      # max_steps is absolute, so offset by the steps already trained
      fit_on_eval_steps += train_steps
    # Do not use estimator_train.train_and_evaluate as it starts tf.Server,
    # which is redundant and reports port not available error.
    estimator.train(
        input_fn=eval_input_fn,
        max_steps=fit_on_eval_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=saving_listeners)
    logging.info('Finished training on eval data')
  # return estimator for custom training using estimator.train
  return estimator
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def evaluate(pipeline_config,
             eval_checkpoint_path='',
             eval_data_path=None,
             eval_result_filename='eval_result.txt'):
  """Evaluate a EasyRec model and write the metrics into model_dir.

  When TF_CONFIG describes a master task in a cluster that has ps tasks,
  the eval graph is built by hand and run through a MonitoredSession;
  otherwise estimator.evaluate is used.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    eval_checkpoint_path: if specified, will use this model instead of
      the latest checkpoint under pipeline_config.model_dir
    eval_data_path: eval data path, default use eval data in pipeline_config
      could be a path or a list of paths
    eval_result_filename: file name (joined onto pipeline_config.model_dir)
      where the metrics are saved as json.

  Returns:
    the evaluation result mapping (metric name -> value) as produced by
    the session run / estimator.evaluate.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  if eval_data_path is not None:
    logging.info('Evaluating on data: %s' % eval_data_path)
    set_eval_input_path(pipeline_config, eval_data_path)

  train_config = pipeline_config.train_config
  eval_data = get_eval_input_path(pipeline_config)

  server_target = None
  if 'TF_CONFIG' in os.environ:
    cluster_conf = estimator_utils.chief_to_master()
    from tensorflow.python.training import server_lib
    task_type = cluster_conf['task']['type']
    if task_type == 'ps':
      cluster = tf.train.ClusterSpec(cluster_conf['cluster'])
      ps_index = cluster_conf['task']['index']
      server = server_lib.Server(cluster, job_name='ps', task_index=ps_index)
      server.join()
    elif task_type == 'master' and 'ps' in cluster_conf['cluster']:
      cluster = tf.train.ClusterSpec(cluster_conf['cluster'])
      server = server_lib.Server(cluster, job_name='master', task_index=0)
      server_target = server.target
      print('server_target = %s' % server_target)

  distribution = strategy_builder.build(train_config)
  estimator, run_config = _create_estimator(pipeline_config, distribution)
  eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
  ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)

  if not server_target:
    # this way does not work, wait to be debugged
    # the variables are not placed to parameter server
    # with tf.device(
    #    replica_device_setter(
    #        worker_device='/job:master/task:0', cluster=cluster)):
    eval_result = estimator.evaluate(
        eval_spec.input_fn, eval_spec.steps, checkpoint_path=ckpt_path)
    eval_spec.input_fn.input_creator.stop()
  else:
    # evaluate with parameter server
    dataset = eval_spec.input_fn(mode=tf.estimator.ModeKeys.EVAL)
    input_iter = dataset.make_one_shot_iterator()
    input_feas, input_lbls = input_iter.get_next()
    from tensorflow.python.training.device_setter import replica_device_setter
    from tensorflow.python.framework.ops import device
    from tensorflow.python.training.monitored_session import MonitoredSession
    from tensorflow.python.training.monitored_session import ChiefSessionCreator
    placement = replica_device_setter(
        worker_device='/job:master/task:0', cluster=cluster)
    with device(placement):
      estimator_spec = estimator._eval_model_fn(input_feas, input_lbls,
                                                run_config)

    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=True)
    chief_sess_creator = ChiefSessionCreator(
        master=server_target,
        checkpoint_filename_with_path=ckpt_path,
        config=session_config)
    eval_metric_ops = estimator_spec.eval_metric_ops
    # each entry is indexed as [0] -> value tensor, [1] -> update op
    update_ops = [eval_metric_ops[name][1] for name in eval_metric_ops]
    metric_ops = {name: eval_metric_ops[name][0] for name in eval_metric_ops}
    update_op = tf.group(update_ops)
    with MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      # drain the eval input, then read the accumulated metric values
      exhausted = False
      while not exhausted:
        try:
          sess.run(update_op)
        except tf.errors.OutOfRangeError:
          exhausted = True
      eval_result = sess.run(metric_ops)
  logging.info('Evaluate finish')

  print('eval_result = ', eval_result)
  logging.info('eval_result = {0}'.format(eval_result))
  # write eval result to file
  eval_result_file = os.path.join(pipeline_config.model_dir,
                                  eval_result_filename)
  logging.info('save eval result to file %s' % eval_result_file)
  with tf.gfile.GFile(eval_result_file, 'w') as ofile:
    # binary values are skipped; the rest is converted through .item()
    # so that json can serialize it
    result_to_write = {
        name: eval_result[name].item()
        for name in sorted(eval_result)
        if not isinstance(eval_result[name], six.binary_type)
    }
    ofile.write(json.dumps(result_to_write, indent=2))
  return eval_result
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def distribute_evaluate(pipeline_config,
                        eval_checkpoint_path='',
                        eval_data_path=None,
                        eval_result_filename='distribute_eval_result.txt'):
  """Evaluate a EasyRec model with master/worker tasks and parameter servers.

  Every master/worker task of the cluster described by TF_CONFIG builds the
  eval graph and runs the metric update ops until its input is exhausted;
  the result is taken from an EvaluateExitBarrierHook. Only the master
  writes the result file.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    eval_checkpoint_path: if specified, will use this model instead of
      the latest checkpoint under pipeline_config.model_dir
    eval_data_path: eval data path, default use eval data in pipeline_config
      could be a path or a list of paths
    eval_result_filename: file name (joined onto pipeline_config.model_dir)
      where the master saves the metrics as json.

  Returns:
    the evaluation result taken from the exit barrier hook. An empty dict
    is returned when data_config has a sampler, or when no distributed
    session could be set up (no TF_CONFIG, or no ps tasks in the cluster).

  Raises:
    AssertionError: if the temporary result directory cannot be created.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if eval_data_path is not None:
    logging.info('Evaluating on data: %s' % eval_data_path)
    set_eval_input_path(pipeline_config, eval_data_path)
  train_config = pipeline_config.train_config
  eval_data = get_eval_input_path(pipeline_config)
  data_config = pipeline_config.data_config
  if data_config.HasField('sampler'):
    logging.warning(
        'It is not accuracy to use eval with negative sampler, recommand to use hitrate.py!'
    )
    return {}

  model_dir = get_model_dir_path(pipeline_config)
  eval_tmp_results_dir = os.path.join(model_dir, 'distribute_eval_tmp_results')
  if not tf.gfile.IsDirectory(eval_tmp_results_dir):
    logging.info('create eval tmp results dir {}'.format(eval_tmp_results_dir))
    tf.gfile.MakeDirs(eval_tmp_results_dir)
  # Explicit raise keeps the check alive under `python -O`.
  if not tf.gfile.IsDirectory(eval_tmp_results_dir):
    raise AssertionError('tmp results dir not create success.')
  # published through the environment; presumably read by the eval hook /
  # model code - TODO confirm the consumer
  os.environ['eval_tmp_results_dir'] = eval_tmp_results_dir

  server_target = None
  cur_job_name = None
  if 'TF_CONFIG' in os.environ:
    tf_config = estimator_utils.chief_to_master()

    from tensorflow.python.training import server_lib
    task_type = tf_config['task']['type']
    if task_type == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = server_lib.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif task_type in ('master', 'worker') and 'ps' in tf_config['cluster']:
      # master and worker tasks are set up identically
      cur_job_name = task_type
      cur_task_index = tf_config['task']['index']
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = server_lib.Server(
          cluster, job_name=cur_job_name, task_index=cur_task_index)
      server_target = server.target
      print('server_target = %s' % server_target)

  # Default result, so that the function still returns a dict (instead of
  # failing on an unbound local) when no distributed session is available.
  eval_result = {}
  if server_target:
    from tensorflow.python.training.device_setter import replica_device_setter
    from tensorflow.python.framework.ops import device
    from tensorflow.python.training.monitored_session import MonitoredSession
    from tensorflow.python.training.monitored_session import ChiefSessionCreator
    from tensorflow.python.training.monitored_session import WorkerSessionCreator
    from easy_rec.python.utils.estimator_utils import EvaluateExitBarrierHook
    is_master = cur_job_name == 'master'
    cur_work_device = '/job:' + cur_job_name + '/task:' + str(cur_task_index)
    cur_ps_num = len(tf_config['cluster']['ps'])
    with device(
        replica_device_setter(
            ps_tasks=cur_ps_num, worker_device=cur_work_device,
            cluster=cluster)):
      distribution = strategy_builder.build(train_config)
      estimator, run_config = _create_estimator(pipeline_config, distribution)
      eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
      ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)
      ckpt_dir = os.path.dirname(ckpt_path)
      input_iter = eval_spec.input_fn(
          mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
      input_feas, input_lbls = input_iter.get_next()
      estimator_spec = estimator._distribute_eval_model_fn(
          input_feas, input_lbls, run_config)

    session_config = ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True,
        device_filters=['/job:ps',
                        '/job:worker/task:%d' % cur_task_index])
    if is_master:
      # metric variables are initialized locally and excluded from the
      # saver, so only the remaining variables are restored from ckpt_path
      metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
      global_variables = tf.global_variables()
      remain_variables = list(
          set(global_variables).difference(set(metric_variables)))
      cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
      cur_scaffold = tf.train.Scaffold(
          saver=cur_saver,
          ready_for_local_init_op=model_ready_for_local_init_op)
      cur_sess_creator = ChiefSessionCreator(
          scaffold=cur_scaffold,
          master=server_target,
          checkpoint_filename_with_path=ckpt_path,
          config=session_config)
    else:
      cur_sess_creator = WorkerSessionCreator(
          master=server_target, config=session_config)
    eval_metric_ops = estimator_spec.eval_metric_ops
    # each entry is indexed as [0] -> value tensor, [1] -> update op
    update_ops = [eval_metric_ops[x][1] for x in eval_metric_ops]
    metric_ops = {x: eval_metric_ops[x][0] for x in eval_metric_ops}
    update_op = tf.group(update_ops)
    # +1 accounts for the master; a cluster without a 'worker' entry
    # (master + ps only) no longer raises KeyError
    cur_worker_num = len(tf_config['cluster'].get('worker', [])) + 1
    stop_grace_period_secs = 120 if is_master else 10
    cur_hooks = EvaluateExitBarrierHook(cur_worker_num, is_master, ckpt_dir,
                                        metric_ops)
    with MonitoredSession(
        session_creator=cur_sess_creator,
        hooks=[cur_hooks],
        stop_grace_period_secs=stop_grace_period_secs) as sess:
      while True:
        try:
          sess.run(update_op)
        except tf.errors.OutOfRangeError:
          break
    eval_result = cur_hooks.eval_result
  else:
    logging.warning(
        'distribute_evaluate is skipped: TF_CONFIG with master/worker and ps '
        'tasks is required, an empty eval result is returned')

  logging.info('Evaluate finish')

  # write eval result to file
  model_dir = pipeline_config.model_dir
  eval_result_file = os.path.join(model_dir, eval_result_filename)
  logging.info('save eval result to file %s' % eval_result_file)
  if cur_job_name == 'master':
    print('eval_result = ', eval_result)
    logging.info('eval_result = {0}'.format(eval_result))
    with tf.gfile.GFile(eval_result_file, 'w') as ofile:
      result_to_write = {'eval_method': 'distribute'}
      for key in sorted(eval_result):
        # skip logging binary data
        if isinstance(eval_result[key], six.binary_type):
          continue
        # convert numpy float to python float
        result_to_write[key] = eval_result[key].item()

      ofile.write(json.dumps(result_to_write))
  return eval_result
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def predict(pipeline_config, checkpoint_path='', data_path=None):
  """Run prediction with a EasyRec model on the eval input.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    checkpoint_path: if specified, will use this model instead of
      the latest checkpoint under pipeline_config.model_dir
    data_path: data path, default use eval data in pipeline_config
      could be a path or a list of paths

  Returns:
    the value returned by estimator.predict for the eval input_fn
    (the original documentation describes it as a list of dict of
    predict results).
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)

  # an explicit data path overrides the eval input of the config
  if data_path is not None:
    logging.info('Predict on data: %s' % data_path)
    set_eval_input_path(pipeline_config, data_path)

  train_config = pipeline_config.train_config
  input_path = get_eval_input_path(pipeline_config)

  estimator, _ = _create_estimator(pipeline_config,
                                   strategy_builder.build(train_config))
  eval_spec = _create_eval_export_spec(pipeline_config, input_path)
  restore_from = _get_ckpt_path(pipeline_config, checkpoint_path)

  pred_result = estimator.predict(
      eval_spec.input_fn, checkpoint_path=restore_from)
  logging.info('Predict finish')
  return pred_result
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
def export(export_dir,
           pipeline_config,
           checkpoint_path='',
           asset_files=None,
           verbose=False,
           **extra_params):
  """Export model defined in pipeline_config.

  Args:
    export_dir: base directory where the model should be exported;
      created if it does not exist
    pipeline_config: proto.EasyRecConfig instance or file path
      specify proto.EasyRecConfig
    checkpoint_path: if specified, will use this model instead of
      the latest checkpoint under pipeline_config.model_dir
    asset_files: extra files to add to assets, comma separated;
      if asset file variable in graph need to be renamed,
      specify by new_file_name:file_path. Entries starting with
      'oss:' or 'hdfs:' are always treated as plain paths; empty
      entries (e.g. from a trailing comma) are ignored.
    verbose: dumps debug information
    extra_params: keys related to write embedding to redis/oss
      (list carried over from the original documentation; the keys are
      consumed by export_big_model / export_big_model_to_oss):
      redis_url, redis_passwd, redis_threads, redis_batch_size,
      redis_timeout, redis_expire if export embedding to redis;
      redis_embedding_version: if specified, will kill export to redis
      --
      oss_path, oss_endpoint, oss_ak, oss_sk, oss_timeout,
      oss_expire, oss_write_kv, oss_embedding_version

  Returns:
    the directory where model is exported
  """
  if not tf.gfile.Exists(export_dir):
    tf.gfile.MakeDirs(export_dir)

  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

  # create estimator
  params = {'log_device_placement': verbose}
  if asset_files:
    logging.info('will add asset files: %s' % asset_files)
    asset_file_dict = {}
    for asset_file in asset_files.split(','):
      asset_file = asset_file.strip()
      if not asset_file:
        # skip empty entries instead of registering a '' -> '' asset
        continue
      if ':' not in asset_file or asset_file.startswith(('oss:', 'hdfs:')):
        # plain path: the asset is named after the file
        _, asset_name = os.path.split(asset_file)
      else:
        # new_file_name:file_path form
        asset_name, asset_file = asset_file.split(':', 1)
      asset_file_dict[asset_name] = asset_file
    params['asset_files'] = asset_file_dict
  estimator, _ = _create_estimator(pipeline_config, params=params)

  # construct serving input fn
  export_config = pipeline_config.export_config
  data_config = pipeline_config.data_config
  input_fn_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
  serving_input_fn = _get_input_fn(data_config, feature_configs, None,
                                   export_config, **input_fn_kwargs)
  ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)

  if 'oss_path' in extra_params:
    if pipeline_config.train_config.HasField('incr_save_config'):
      incr_save_config = pipeline_config.train_config.incr_save_config
      extra_params['incr_update'] = {}
      incr_save_type = incr_save_config.WhichOneof('incr_update')
      logging.info('incr_save_type=%s' % incr_save_type)
      if incr_save_type:
        extra_params['incr_update'][incr_save_type] = getattr(
            incr_save_config, incr_save_type)
    return export_big_model_to_oss(export_dir, pipeline_config, extra_params,
                                   serving_input_fn, estimator, ckpt_path,
                                   verbose)

  if 'redis_url' in extra_params:
    return export_big_model(export_dir, pipeline_config, extra_params,
                            serving_input_fn, estimator, ckpt_path, verbose)

  final_export_dir = estimator.export_savedmodel(
      export_dir_base=export_dir,
      serving_input_receiver_fn=serving_input_fn,
      checkpoint_path=ckpt_path,
      strip_default_attrs=True)

  # add export ts as version info
  saved_model = saved_model_pb2.SavedModel()
  if not isinstance(final_export_dir, (type(''), type(u''))):
    final_export_dir = final_export_dir.decode('utf-8')
  # last non-empty path component, i.e. the directory name created by
  # export_savedmodel
  export_ts = [x for x in final_export_dir.split('/') if x][-1]
  saved_pb_path = os.path.join(final_export_dir, 'saved_model.pb')
  with tf.gfile.GFile(saved_pb_path, 'rb') as fin:
    saved_model.ParseFromString(fin.read())
  saved_model.meta_graphs[0].meta_info_def.meta_graph_version = export_ts
  with tf.gfile.GFile(saved_pb_path, 'wb') as fout:
    fout.write(saved_model.SerializeToString())

  logging.info('model has been exported to %s successfully' % final_export_dir)
  return final_export_dir
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
def export_checkpoint(pipeline_config=None,
                      export_path='',
                      checkpoint_path='',
                      asset_files=None,
                      verbose=False,
                      mode=tf.estimator.ModeKeys.PREDICT):
  """Export the EasyRec model as checkpoint.

  Args:
    pipeline_config: EasyRecConfig path or instance, loaded through
      config_util.get_configs_from_pipeline_file
    export_path: forwarded to estimator.export_checkpoint
    checkpoint_path: if specified, will use this model instead of
      the latest checkpoint under pipeline_config.model_dir
    asset_files: extra asset files, handed to the estimator params as is
    verbose: stored as log_device_placement in the estimator params
    mode: forwarded to estimator.export_checkpoint
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  data_config = pipeline_config.data_config

  extra_input_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    extra_input_kwargs['fg_json_path'] = pipeline_config.fg_json_path

  # create estimator
  estimator_params = {'log_device_placement': verbose}
  if asset_files:
    logging.info('will add asset files: %s' % asset_files)
    # NOTE(review): export() parses asset_files into a name -> path dict,
    # here the raw value is passed through - confirm this is intended
    estimator_params['asset_files'] = asset_files
  estimator, _ = _create_estimator(pipeline_config, params=estimator_params)

  # construct serving input fn
  serving_input_fn = _get_input_fn(data_config, feature_configs, None,
                                   pipeline_config.export_config,
                                   **extra_input_kwargs)
  restore_from = _get_ckpt_path(pipeline_config, checkpoint_path)
  estimator.export_checkpoint(
      export_path=export_path,
      serving_input_receiver_fn=serving_input_fn,
      checkpoint_path=restore_from,
      mode=mode)

  logging.info('model checkpoint has been exported successfully')
|