easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,739 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
import time
|
|
10
|
+
from collections import OrderedDict
|
|
11
|
+
|
|
12
|
+
import tensorflow as tf
|
|
13
|
+
from tensorflow.python.client import session as tf_session
|
|
14
|
+
from tensorflow.python.eager import context
|
|
15
|
+
from tensorflow.python.framework import ops
|
|
16
|
+
from tensorflow.python.ops import variables
|
|
17
|
+
from tensorflow.python.platform import gfile
|
|
18
|
+
from tensorflow.python.saved_model import signature_constants
|
|
19
|
+
from tensorflow.python.training import basic_session_run_hooks
|
|
20
|
+
from tensorflow.python.training import saver
|
|
21
|
+
|
|
22
|
+
from easy_rec.python.builders import optimizer_builder
|
|
23
|
+
from easy_rec.python.compat import optimizers
|
|
24
|
+
from easy_rec.python.compat import sync_replicas_optimizer
|
|
25
|
+
from easy_rec.python.compat.early_stopping import custom_early_stop_hook
|
|
26
|
+
from easy_rec.python.compat.early_stopping import deadline_stop_hook
|
|
27
|
+
from easy_rec.python.compat.early_stopping import find_early_stop_var
|
|
28
|
+
from easy_rec.python.compat.early_stopping import oss_stop_hook
|
|
29
|
+
from easy_rec.python.compat.early_stopping import stop_if_no_decrease_hook
|
|
30
|
+
from easy_rec.python.compat.early_stopping import stop_if_no_increase_hook
|
|
31
|
+
from easy_rec.python.compat.ops import GraphKeys
|
|
32
|
+
from easy_rec.python.input.input import Input
|
|
33
|
+
from easy_rec.python.layers.utils import _tensor_to_tensorinfo
|
|
34
|
+
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
|
|
35
|
+
from easy_rec.python.protos.train_pb2 import DistributionStrategy
|
|
36
|
+
from easy_rec.python.utils import constant
|
|
37
|
+
from easy_rec.python.utils import embedding_utils
|
|
38
|
+
from easy_rec.python.utils import estimator_utils
|
|
39
|
+
from easy_rec.python.utils import hvd_utils
|
|
40
|
+
from easy_rec.python.utils import pai_util
|
|
41
|
+
from easy_rec.python.utils.multi_optimizer import MultiOptimizer
|
|
42
|
+
|
|
43
|
+
from easy_rec.python.compat.embedding_parallel_saver import EmbeddingParallelSaver # NOQA
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
import horovod.tensorflow as hvd
|
|
47
|
+
except Exception:
|
|
48
|
+
hvd = None
|
|
49
|
+
|
|
50
|
+
try:
|
|
51
|
+
from sparse_operation_kit import experiment as sok
|
|
52
|
+
from easy_rec.python.compat import sok_optimizer
|
|
53
|
+
except Exception:
|
|
54
|
+
sok = None
|
|
55
|
+
|
|
56
|
+
# Under TF 2.x the code below is written against the TF1-style API, so rebind
# the module-level `tf` name to the `tf.compat.v1` namespace.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

# Estimator's constructor calls `_assert_members_are_not_overridden`, which
# guards against subclasses overriding Estimator members. EasyRecEstimator
# below deliberately overrides `train` and `evaluate`, so the check is
# replaced with a no-op that simply returns its argument.
tf.estimator.Estimator._assert_members_are_not_overridden = lambda self: self
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class EasyRecEstimator(tf.estimator.Estimator):
|
|
63
|
+
|
|
64
|
+
def __init__(self, pipeline_config, model_cls, run_config, params):
  """Build the estimator from an EasyRec pipeline configuration.

  Args:
    pipeline_config: an `EasyRecConfig` proto (checked by the assert below);
      its `model_dir` becomes the estimator's model directory.
    model_cls: model class that the model_fn instantiates with the model
      config, feature configs, features and labels.
    run_config: run configuration, forwarded as the base Estimator's `config`.
    params: forwarded unchanged as the base Estimator's `params`.
  """
  self._pipeline_config = pipeline_config
  self._model_cls = model_cls
  assert isinstance(self._pipeline_config, EasyRecConfig)

  estimator_kwargs = dict(
      model_fn=self._model_fn,
      model_dir=pipeline_config.model_dir,
      config=run_config,
      params=params)
  super(EasyRecEstimator, self).__init__(**estimator_kwargs)
|
|
74
|
+
|
|
75
|
+
def evaluate(self,
             input_fn,
             steps=None,
             hooks=None,
             checkpoint_path=None,
             name=None):
  """Evaluate after letting the input restore its read offsets.

  Supports offset restore for streaming inputs such as datahub/kafka. The
  input creator attached to `input_fn` is asked to restore from
  `checkpoint_path`, and then the standard Estimator evaluation runs.

  Args:
    input_fn: input function; must carry an `input_creator` attribute that
      provides `restore(checkpoint_path)`.
    steps, hooks, checkpoint_path, name: passed through unchanged to
      `Estimator.evaluate`.

  Returns:
    The result of the base `Estimator.evaluate`.
  """
  # NOTE(review): when checkpoint_path is None, restore() receives None while
  # the base class evaluates the latest checkpoint in model_dir; confirm the
  # input creators treat None as "nothing to restore".
  creator = input_fn.input_creator
  creator.restore(checkpoint_path)
  parent = super(EasyRecEstimator, self)
  return parent.evaluate(input_fn, steps, hooks, checkpoint_path, name)
|
|
85
|
+
|
|
86
|
+
def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
  """Train after letting the input restore its read offsets.

  Supports offset restore for streaming inputs such as datahub/kafka. The
  input creator attached to `input_fn` restores from:
    * the latest checkpoint in `model_dir`, if there is one; otherwise
    * `train_config.fine_tune_checkpoint`, if it is set. When that path is a
      directory it is resolved to the latest checkpoint inside it, and the
      resolved path is written back into `train_config.fine_tune_checkpoint`.

  Args:
    input_fn: input function; must carry an `input_creator` attribute that
      provides `restore(checkpoint_path)`.
    hooks, steps, max_steps, saving_listeners: passed through unchanged to
      `Estimator.train`.

  Returns:
    The result of the base `Estimator.train`.

  Raises:
    ValueError: if `fine_tune_checkpoint` is a directory that contains no
      checkpoint.
  """
  checkpoint_path = estimator_utils.latest_checkpoint(self.model_dir)
  if checkpoint_path is not None:
    input_fn.input_creator.restore(checkpoint_path)
  elif self.train_config.HasField('fine_tune_checkpoint'):
    fine_tune_ckpt = self.train_config.fine_tune_checkpoint
    if fine_tune_ckpt.endswith('/') or gfile.IsDirectory(fine_tune_ckpt +
                                                         '/'):
      ckpt_dir = fine_tune_ckpt
      fine_tune_ckpt = estimator_utils.latest_checkpoint(ckpt_dir)
      if fine_tune_ckpt is None:
        # Fail fast. Otherwise None would be assigned to the proto string
        # field below and passed to restore(), leaving fine-tuning without
        # any usable checkpoint.
        raise ValueError(
            'fine_tune_checkpoint[%s] is a directory but contains no checkpoint'
            % ckpt_dir)
      print(
          'fine_tune_checkpoint[%s] is directory, will use the latest checkpoint: %s'
          % (ckpt_dir, fine_tune_ckpt))
      # Persist the resolved path so later readers of train_config get a
      # concrete checkpoint instead of a directory.
      self.train_config.fine_tune_checkpoint = fine_tune_ckpt
    input_fn.input_creator.restore(fine_tune_ckpt)
  return super(EasyRecEstimator, self).train(input_fn, hooks, steps,
                                             max_steps, saving_listeners)
|
|
108
|
+
|
|
109
|
+
@property
def feature_configs(self):
  """Feature configs declared in the pipeline config.

  Returns:
    `feature_configs` when it is non-empty; otherwise
    `feature_config.features` when that is non-empty.

  Raises:
    AssertionError: if neither is configured. This is raised explicitly
      rather than through an `assert` statement, so the check still runs
      under `python -O` (where the property would otherwise return None).
      The exception type callers may already handle is unchanged.
  """
  cfg = self._pipeline_config
  if len(cfg.feature_configs) > 0:
    return cfg.feature_configs
  if cfg.feature_config and len(cfg.feature_config.features) > 0:
    return cfg.feature_config.features
  raise AssertionError(
      'One of feature_configs and feature_config.features must be configured.')
|
|
118
|
+
|
|
119
|
+
@property
def model_config(self):
  """The `model_config` section of the pipeline config."""
  pipeline = self._pipeline_config
  return pipeline.model_config
|
|
122
|
+
|
|
123
|
+
@property
def eval_config(self):
  """The `eval_config` section of the pipeline config."""
  pipeline = self._pipeline_config
  return pipeline.eval_config
|
|
126
|
+
|
|
127
|
+
@property
def train_config(self):
  """The `train_config` section of the pipeline config."""
  pipeline = self._pipeline_config
  return pipeline.train_config
|
|
130
|
+
|
|
131
|
+
@property
def incr_save_config(self):
  """`train_config.incr_save_config` when explicitly set, else None."""
  train_cfg = self.train_config
  if not train_cfg.HasField('incr_save_config'):
    return None
  return train_cfg.incr_save_config
|
|
135
|
+
|
|
136
|
+
@property
def export_config(self):
  """The `export_config` section of the pipeline config."""
  pipeline = self._pipeline_config
  return pipeline.export_config
|
|
139
|
+
|
|
140
|
+
@property
def embedding_parallel(self):
  """Whether the configured distribution strategy shards embeddings.

  True iff `train_config.train_distribute` is `SokStrategy` or
  `EmbeddingParallelStrategy`.
  """
  strategy = self.train_config.train_distribute
  return (strategy == DistributionStrategy.SokStrategy or
          strategy == DistributionStrategy.EmbeddingParallelStrategy)
|
|
145
|
+
|
|
146
|
+
@property
def saver_cls(self):
  """Saver class to use for checkpoints.

  Under embedding parallelism this is the extended `EmbeddingParallelSaver`,
  which saves the sharded embeddings. Otherwise it is the stock `saver.Saver`.
  """
  if self.embedding_parallel:
    return EmbeddingParallelSaver
  return saver.Saver
|
|
154
|
+
|
|
155
|
+
def _train_model_fn(self, features, labels, run_config):
|
|
156
|
+
tf.keras.backend.set_learning_phase(1)
|
|
157
|
+
model = self._model_cls(
|
|
158
|
+
self.model_config,
|
|
159
|
+
self.feature_configs,
|
|
160
|
+
features,
|
|
161
|
+
labels,
|
|
162
|
+
is_training=True)
|
|
163
|
+
predict_dict = model.build_predict_graph()
|
|
164
|
+
loss_dict = model.build_loss_graph()
|
|
165
|
+
|
|
166
|
+
regularization_losses = tf.get_collection(
|
|
167
|
+
tf.GraphKeys.REGULARIZATION_LOSSES)
|
|
168
|
+
if regularization_losses:
|
|
169
|
+
regularization_losses = [
|
|
170
|
+
reg_loss.get() if hasattr(reg_loss, 'get') else reg_loss
|
|
171
|
+
for reg_loss in regularization_losses
|
|
172
|
+
]
|
|
173
|
+
regularization_losses = tf.add_n(
|
|
174
|
+
regularization_losses, name='regularization_loss')
|
|
175
|
+
loss_dict['regularization_loss'] = regularization_losses
|
|
176
|
+
|
|
177
|
+
variational_dropout_loss = tf.get_collection('variational_dropout_loss')
|
|
178
|
+
if variational_dropout_loss:
|
|
179
|
+
variational_dropout_loss = tf.add_n(
|
|
180
|
+
variational_dropout_loss, name='variational_dropout_loss')
|
|
181
|
+
loss_dict['variational_dropout_loss'] = variational_dropout_loss
|
|
182
|
+
|
|
183
|
+
loss = tf.add_n(list(loss_dict.values()))
|
|
184
|
+
loss_dict['total_loss'] = loss
|
|
185
|
+
for key in loss_dict:
|
|
186
|
+
tf.summary.scalar(key, loss_dict[key], family='loss')
|
|
187
|
+
|
|
188
|
+
if Input.DATA_OFFSET in features:
|
|
189
|
+
task_index, task_num = estimator_utils.get_task_index_and_num()
|
|
190
|
+
data_offset_var = tf.get_variable(
|
|
191
|
+
name=Input.DATA_OFFSET,
|
|
192
|
+
dtype=tf.string,
|
|
193
|
+
shape=[task_num],
|
|
194
|
+
collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
|
|
195
|
+
trainable=False)
|
|
196
|
+
update_offset = tf.assign(data_offset_var[task_index],
|
|
197
|
+
features[Input.DATA_OFFSET])
|
|
198
|
+
ops.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
|
|
199
|
+
else:
|
|
200
|
+
data_offset_var = None
|
|
201
|
+
|
|
202
|
+
# update op, usually used for batch-norm
|
|
203
|
+
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
|
204
|
+
if update_ops:
|
|
205
|
+
# register for increment update, such as batchnorm moving_mean and moving_variance
|
|
206
|
+
global_vars = {x.name: x for x in tf.global_variables()}
|
|
207
|
+
for x in update_ops:
|
|
208
|
+
if isinstance(x, ops.Operation) and x.inputs[0].name in global_vars:
|
|
209
|
+
ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
|
|
210
|
+
global_vars[x.inputs[0].name])
|
|
211
|
+
update_op = tf.group(*update_ops, name='update_barrier')
|
|
212
|
+
with tf.control_dependencies([update_op]):
|
|
213
|
+
loss = tf.identity(loss, name='total_loss')
|
|
214
|
+
|
|
215
|
+
# build optimizer
|
|
216
|
+
if len(self.train_config.optimizer_config) == 1:
|
|
217
|
+
optimizer_config = self.train_config.optimizer_config[0]
|
|
218
|
+
optimizer, learning_rate = optimizer_builder.build(optimizer_config)
|
|
219
|
+
tf.summary.scalar('learning_rate', learning_rate[0])
|
|
220
|
+
else:
|
|
221
|
+
optimizer_config = self.train_config.optimizer_config
|
|
222
|
+
all_opts = []
|
|
223
|
+
for opti_id, tmp_config in enumerate(optimizer_config):
|
|
224
|
+
with tf.name_scope('optimizer_%d' % opti_id):
|
|
225
|
+
opt, learning_rate = optimizer_builder.build(tmp_config)
|
|
226
|
+
tf.summary.scalar('learning_rate', learning_rate[0])
|
|
227
|
+
all_opts.append(opt)
|
|
228
|
+
grouped_vars = model.get_grouped_vars(len(all_opts))
|
|
229
|
+
assert len(grouped_vars) == len(optimizer_config), \
|
|
230
|
+
'the number of var group(%d) != the number of optimizers(%d)' \
|
|
231
|
+
% (len(grouped_vars), len(optimizer_config))
|
|
232
|
+
optimizer = MultiOptimizer(all_opts, grouped_vars)
|
|
233
|
+
|
|
234
|
+
if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
|
|
235
|
+
optimizer = sok_optimizer.OptimizerWrapper(optimizer)
|
|
236
|
+
|
|
237
|
+
hooks = []
|
|
238
|
+
if estimator_utils.has_hvd():
|
|
239
|
+
assert not self.train_config.sync_replicas, \
|
|
240
|
+
'sync_replicas should not be set when using horovod'
|
|
241
|
+
bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
|
|
242
|
+
hooks.append(bcast_hook)
|
|
243
|
+
|
|
244
|
+
# for distributed and synced training
|
|
245
|
+
if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
|
|
246
|
+
logging.info('sync_replicas: num_worker_replias = %d' %
|
|
247
|
+
run_config.num_worker_replicas)
|
|
248
|
+
if pai_util.is_on_pai():
|
|
249
|
+
optimizer = tf.train.SyncReplicasOptimizer(
|
|
250
|
+
optimizer,
|
|
251
|
+
replicas_to_aggregate=run_config.num_worker_replicas,
|
|
252
|
+
total_num_replicas=run_config.num_worker_replicas,
|
|
253
|
+
sparse_accumulator_type=self.train_config.sparse_accumulator_type)
|
|
254
|
+
else:
|
|
255
|
+
optimizer = sync_replicas_optimizer.SyncReplicasOptimizer(
|
|
256
|
+
optimizer,
|
|
257
|
+
replicas_to_aggregate=run_config.num_worker_replicas,
|
|
258
|
+
total_num_replicas=run_config.num_worker_replicas)
|
|
259
|
+
hooks.append(
|
|
260
|
+
optimizer.make_session_run_hook(run_config.is_chief, num_tokens=0))
|
|
261
|
+
|
|
262
|
+
# add barrier for no strategy case
|
|
263
|
+
if run_config.num_worker_replicas > 1 and \
|
|
264
|
+
self.train_config.train_distribute == DistributionStrategy.NoStrategy:
|
|
265
|
+
hooks.append(
|
|
266
|
+
estimator_utils.ExitBarrierHook(run_config.num_worker_replicas,
|
|
267
|
+
run_config.is_chief, self.model_dir))
|
|
268
|
+
|
|
269
|
+
if self.export_config.enable_early_stop:
|
|
270
|
+
eval_dir = os.path.join(self._model_dir, 'eval_val')
|
|
271
|
+
logging.info('will use early stop, eval_events_dir=%s' % eval_dir)
|
|
272
|
+
if self.export_config.HasField('early_stop_func'):
|
|
273
|
+
hooks.append(
|
|
274
|
+
custom_early_stop_hook(
|
|
275
|
+
self,
|
|
276
|
+
eval_dir=eval_dir,
|
|
277
|
+
custom_stop_func=self.export_config.early_stop_func,
|
|
278
|
+
custom_stop_func_params=self.export_config.early_stop_params))
|
|
279
|
+
elif self.export_config.metric_bigger:
|
|
280
|
+
hooks.append(
|
|
281
|
+
stop_if_no_increase_hook(
|
|
282
|
+
self,
|
|
283
|
+
self.export_config.best_exporter_metric,
|
|
284
|
+
self.export_config.max_check_steps,
|
|
285
|
+
eval_dir=eval_dir))
|
|
286
|
+
else:
|
|
287
|
+
hooks.append(
|
|
288
|
+
stop_if_no_decrease_hook(
|
|
289
|
+
self,
|
|
290
|
+
self.export_config.best_exporter_metric,
|
|
291
|
+
self.export_config.max_check_steps,
|
|
292
|
+
eval_dir=eval_dir))
|
|
293
|
+
|
|
294
|
+
if self.train_config.enable_oss_stop_signal:
|
|
295
|
+
hooks.append(oss_stop_hook(self))
|
|
296
|
+
|
|
297
|
+
if self.train_config.HasField('dead_line'):
|
|
298
|
+
hooks.append(deadline_stop_hook(self, self.train_config.dead_line))
|
|
299
|
+
|
|
300
|
+
summaries = ['global_gradient_norm']
|
|
301
|
+
if self.train_config.summary_model_vars:
|
|
302
|
+
summaries.extend(['gradient_norm', 'gradients'])
|
|
303
|
+
|
|
304
|
+
gradient_clipping_by_norm = self.train_config.gradient_clipping_by_norm
|
|
305
|
+
if gradient_clipping_by_norm <= 0:
|
|
306
|
+
gradient_clipping_by_norm = None
|
|
307
|
+
|
|
308
|
+
gradient_multipliers = None
|
|
309
|
+
if self.train_config.optimizer_config[0].HasField(
|
|
310
|
+
'embedding_learning_rate_multiplier'):
|
|
311
|
+
gradient_multipliers = {
|
|
312
|
+
var: self.train_config.optimizer_config[0]
|
|
313
|
+
.embedding_learning_rate_multiplier
|
|
314
|
+
for var in tf.trainable_variables()
|
|
315
|
+
if 'embedding_weights:' in var.name or
|
|
316
|
+
'/embedding_weights/part_' in var.name
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
# optimize loss
|
|
320
|
+
# colocate_gradients_with_ops=True means to compute gradients
|
|
321
|
+
# on the same device on which op is processes in forward process
|
|
322
|
+
all_train_vars = []
|
|
323
|
+
if len(self.train_config.freeze_gradient) > 0:
|
|
324
|
+
for one_var in tf.trainable_variables():
|
|
325
|
+
is_freeze = False
|
|
326
|
+
for x in self.train_config.freeze_gradient:
|
|
327
|
+
if re.search(x, one_var.name) is not None:
|
|
328
|
+
logging.info('will freeze gradients of %s' % one_var.name)
|
|
329
|
+
is_freeze = True
|
|
330
|
+
break
|
|
331
|
+
if not is_freeze:
|
|
332
|
+
all_train_vars.append(one_var)
|
|
333
|
+
else:
|
|
334
|
+
all_train_vars = tf.trainable_variables()
|
|
335
|
+
|
|
336
|
+
if self.embedding_parallel:
|
|
337
|
+
logging.info('embedding_parallel is enabled')
|
|
338
|
+
|
|
339
|
+
train_op = optimizers.optimize_loss(
|
|
340
|
+
loss=loss,
|
|
341
|
+
global_step=tf.train.get_global_step(),
|
|
342
|
+
learning_rate=None,
|
|
343
|
+
clip_gradients=gradient_clipping_by_norm,
|
|
344
|
+
optimizer=optimizer,
|
|
345
|
+
gradient_multipliers=gradient_multipliers,
|
|
346
|
+
variables=all_train_vars,
|
|
347
|
+
summaries=summaries,
|
|
348
|
+
colocate_gradients_with_ops=True,
|
|
349
|
+
not_apply_grad_after_first_step=run_config.is_chief and
|
|
350
|
+
self._pipeline_config.data_config.chief_redundant,
|
|
351
|
+
name='', # Preventing scope prefix on all variables.
|
|
352
|
+
incr_save=(self.incr_save_config is not None),
|
|
353
|
+
embedding_parallel=self.embedding_parallel)
|
|
354
|
+
|
|
355
|
+
# online evaluation
|
|
356
|
+
metric_update_op_dict = None
|
|
357
|
+
if self.eval_config.eval_online:
|
|
358
|
+
metric_update_op_dict = {}
|
|
359
|
+
metric_dict = model.build_metric_graph(self.eval_config)
|
|
360
|
+
for k, v in metric_dict.items():
|
|
361
|
+
metric_update_op_dict['%s/batch' % k] = v[1]
|
|
362
|
+
if isinstance(v[1], tf.Tensor):
|
|
363
|
+
tf.summary.scalar('%s/batch' % k, v[1])
|
|
364
|
+
train_op = tf.group([train_op] + list(metric_update_op_dict.values()))
|
|
365
|
+
if estimator_utils.is_chief():
|
|
366
|
+
hooks.append(
|
|
367
|
+
estimator_utils.OnlineEvaluationHook(
|
|
368
|
+
metric_dict=metric_dict, output_dir=self.model_dir))
|
|
369
|
+
|
|
370
|
+
if self.train_config.HasField('fine_tune_checkpoint'):
|
|
371
|
+
fine_tune_ckpt = self.train_config.fine_tune_checkpoint
|
|
372
|
+
logging.warning('will restore from %s' % fine_tune_ckpt)
|
|
373
|
+
fine_tune_ckpt_var_map = self.train_config.fine_tune_ckpt_var_map
|
|
374
|
+
force_restore = self.train_config.force_restore_shape_compatible
|
|
375
|
+
restore_hook = model.restore(
|
|
376
|
+
fine_tune_ckpt,
|
|
377
|
+
include_global_step=False,
|
|
378
|
+
ckpt_var_map_path=fine_tune_ckpt_var_map,
|
|
379
|
+
force_restore_shape_compatible=force_restore)
|
|
380
|
+
if restore_hook is not None:
|
|
381
|
+
hooks.append(restore_hook)
|
|
382
|
+
|
|
383
|
+
# logging
|
|
384
|
+
logging_dict = OrderedDict()
|
|
385
|
+
logging_dict['step'] = tf.train.get_global_step()
|
|
386
|
+
logging_dict['lr'] = learning_rate[0]
|
|
387
|
+
logging_dict.update(loss_dict)
|
|
388
|
+
if metric_update_op_dict is not None:
|
|
389
|
+
logging_dict.update(metric_update_op_dict)
|
|
390
|
+
|
|
391
|
+
log_step_count_steps = self.train_config.log_step_count_steps
|
|
392
|
+
logging_hook = basic_session_run_hooks.LoggingTensorHook(
|
|
393
|
+
logging_dict,
|
|
394
|
+
every_n_iter=log_step_count_steps,
|
|
395
|
+
formatter=estimator_utils.tensor_log_format_func)
|
|
396
|
+
hooks.append(logging_hook)
|
|
397
|
+
|
|
398
|
+
if self.train_config.train_distribute in [
|
|
399
|
+
DistributionStrategy.CollectiveAllReduceStrategy,
|
|
400
|
+
DistributionStrategy.MirroredStrategy,
|
|
401
|
+
DistributionStrategy.MultiWorkerMirroredStrategy
|
|
402
|
+
]:
|
|
403
|
+
# for multi worker strategy, we could not replace the
|
|
404
|
+
# inner CheckpointSaverHook, so just use it.
|
|
405
|
+
scaffold = tf.train.Scaffold()
|
|
406
|
+
else:
|
|
407
|
+
var_list = (
|
|
408
|
+
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
|
|
409
|
+
tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
|
|
410
|
+
|
|
411
|
+
# exclude data_offset_var
|
|
412
|
+
var_list = [x for x in var_list if x != data_offset_var]
|
|
413
|
+
# early_stop flag will not be saved in checkpoint
|
|
414
|
+
# and could not be restored from checkpoint
|
|
415
|
+
early_stop_var = find_early_stop_var(var_list)
|
|
416
|
+
var_list = [x for x in var_list if x != early_stop_var]
|
|
417
|
+
|
|
418
|
+
initialize_var_list = [
|
|
419
|
+
x for x in var_list if 'WorkQueue' not in str(type(x))
|
|
420
|
+
]
|
|
421
|
+
|
|
422
|
+
# incompatiable shape restore will not be saved in checkpoint
|
|
423
|
+
# but must be able to restore from checkpoint
|
|
424
|
+
incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')
|
|
425
|
+
|
|
426
|
+
local_init_ops = [tf.train.Scaffold.default_local_init_op()]
|
|
427
|
+
if data_offset_var is not None and estimator_utils.is_chief():
|
|
428
|
+
local_init_ops.append(tf.initializers.variables([data_offset_var]))
|
|
429
|
+
if early_stop_var is not None and estimator_utils.is_chief():
|
|
430
|
+
local_init_ops.append(tf.initializers.variables([early_stop_var]))
|
|
431
|
+
if len(incompatiable_shape_restore) > 0:
|
|
432
|
+
local_init_ops.append(
|
|
433
|
+
tf.initializers.variables(incompatiable_shape_restore))
|
|
434
|
+
|
|
435
|
+
scaffold = tf.train.Scaffold(
|
|
436
|
+
saver=self.saver_cls(
|
|
437
|
+
var_list=var_list,
|
|
438
|
+
sharded=True,
|
|
439
|
+
max_to_keep=self.train_config.keep_checkpoint_max,
|
|
440
|
+
save_relative_paths=True),
|
|
441
|
+
local_init_op=tf.group(local_init_ops),
|
|
442
|
+
ready_for_local_init_op=tf.report_uninitialized_variables(
|
|
443
|
+
var_list=initialize_var_list))
|
|
444
|
+
# saver hook
|
|
445
|
+
saver_hook = estimator_utils.CheckpointSaverHook(
|
|
446
|
+
checkpoint_dir=self.model_dir,
|
|
447
|
+
save_secs=self._config.save_checkpoints_secs,
|
|
448
|
+
save_steps=self._config.save_checkpoints_steps,
|
|
449
|
+
scaffold=scaffold,
|
|
450
|
+
write_graph=self.train_config.write_graph,
|
|
451
|
+
data_offset_var=data_offset_var,
|
|
452
|
+
increment_save_config=self.incr_save_config)
|
|
453
|
+
if estimator_utils.is_chief() or self.embedding_parallel:
|
|
454
|
+
hooks.append(saver_hook)
|
|
455
|
+
if estimator_utils.is_chief():
|
|
456
|
+
hooks.append(
|
|
457
|
+
basic_session_run_hooks.StepCounterHook(
|
|
458
|
+
every_n_steps=log_step_count_steps, output_dir=self.model_dir))
|
|
459
|
+
|
|
460
|
+
# profiling hook
|
|
461
|
+
if self.train_config.is_profiling and estimator_utils.is_chief():
|
|
462
|
+
profile_hook = tf.train.ProfilerHook(
|
|
463
|
+
save_steps=log_step_count_steps, output_dir=self.model_dir)
|
|
464
|
+
hooks.append(profile_hook)
|
|
465
|
+
|
|
466
|
+
return tf.estimator.EstimatorSpec(
|
|
467
|
+
mode=tf.estimator.ModeKeys.TRAIN,
|
|
468
|
+
loss=loss,
|
|
469
|
+
predictions=predict_dict,
|
|
470
|
+
train_op=train_op,
|
|
471
|
+
scaffold=scaffold,
|
|
472
|
+
training_hooks=hooks)
|
|
473
|
+
|
|
474
|
+
def _eval_model_fn(self, features, labels, run_config):
  """Build the EVAL-mode EstimatorSpec.

  The model is constructed with ``is_training=False``. All loss terms are
  summed into ``total_loss``, and a streaming ``tf.metrics.mean`` is added
  for every loss term (keyed ``loss/loss/<name>``) alongside the metrics
  returned by ``build_metric_graph``.

  Args:
    features: input feature dict passed to the model constructor.
    labels: labels passed to the model constructor.
    run_config: unused in this method.

  Returns:
    A ``tf.estimator.EstimatorSpec`` in EVAL mode with a scaffold whose saver
    covers GLOBAL_VARIABLES + SAVEABLE_OBJECTS.
  """
  tf.keras.backend.set_learning_phase(0)
  t_begin = time.time()
  eval_model = self._model_cls(
      self.model_config,
      self.feature_configs,
      features,
      labels,
      is_training=False)
  predictions = eval_model.build_predict_graph()
  losses = eval_model.build_loss_graph()
  total_loss = tf.add_n(list(losses.values()))
  losses['total_loss'] = total_loss

  eval_metric_ops = eval_model.build_metric_graph(self.eval_config)
  for name, tensor in losses.items():
    # prefix places eval loss metrics in the same family as the train loss
    eval_metric_ops['loss/loss/' + name] = tf.metrics.mean(tensor)
  tf.logging.info('metric_dict keys: %s' % eval_metric_ops.keys())

  saver_vars = (
      ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
      ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))

  # NOTE(review): an initializer op (not a readiness tensor) is passed as
  # ready_for_local_init_op; presumably relied upon so that running the
  # readiness check (re)initializes the metric variables — confirm.
  metric_init_op = tf.variables_initializer(
      ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))
  scaffold = tf.train.Scaffold(
      saver=self.saver_cls(
          var_list=saver_vars, sharded=True, save_relative_paths=True),
      ready_for_local_init_op=metric_init_op)
  tf.logging.info('eval graph construct finished. Time %.3fs' %
                  (time.time() - t_begin))
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=total_loss,
      scaffold=scaffold,
      predictions=predictions,
      eval_metric_ops=eval_metric_ops)
|
|
513
|
+
|
|
514
|
+
def _distribute_eval_model_fn(self, features, labels, run_config):
  """Build the EVAL-mode EstimatorSpec used for distributed evaluation.

  Like the single-process eval path, this builds the model with
  ``is_training=False``, sums the loss terms into ``total_loss`` and adds a
  streaming mean metric per loss term (``loss/loss/<name>``). The scaffold's
  saver restores every global variable except those in the METRIC_VARIABLES
  collection.

  Args:
    features: input feature dict passed to the model constructor.
    labels: labels passed to the model constructor.
    run_config: unused in this method.

  Returns:
    A ``tf.estimator.EstimatorSpec`` in EVAL mode.
  """
  tf.keras.backend.set_learning_phase(0)
  start = time.time()
  model = self._model_cls(
      self.model_config,
      self.feature_configs,
      features,
      labels,
      is_training=False)
  predict_dict = model.build_predict_graph()
  loss_dict = model.build_loss_graph()
  loss = tf.add_n(list(loss_dict.values()))
  loss_dict['total_loss'] = loss
  metric_dict = model.build_metric_graph(self.eval_config)
  for loss_key, loss_tensor in loss_dict.items():
    # add key-prefix to make loss metric key in the same family of train loss
    metric_dict['loss/loss/' + loss_key] = tf.metrics.mean(loss_tensor)
  tf.logging.info('metric_dict keys: %s' % metric_dict.keys())

  end = time.time()
  tf.logging.info('eval graph construct finished. Time %.3fs' % (end - start))

  global_variables = tf.global_variables()
  metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
  # NOTE(review): an initializer op (not a readiness tensor) is passed as
  # ready_for_local_init_op; presumably relied upon so that running the
  # readiness check (re)initializes the metric variables — confirm.
  model_ready_for_local_init_op = tf.variables_initializer(metric_variables)

  # Restore every global variable except the metric accumulators. Unlike a
  # plain set difference, this keeps tf.global_variables() order so the
  # saver is built deterministically; duplicates are still dropped.
  # NOTE(review): unlike the single-process eval path this uses tf.train.Saver
  # and does not include SAVEABLE_OBJECTS — confirm this is intended.
  seen = set(metric_variables)
  remain_variables = []
  for var in global_variables:
    if var not in seen:
      seen.add(var)
      remain_variables.append(var)
  cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
  scaffold = tf.train.Scaffold(
      saver=cur_saver, ready_for_local_init_op=model_ready_for_local_init_op)
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=loss,
      predictions=predict_dict,
      eval_metric_ops=metric_dict,
      scaffold=scaffold)
|
|
566
|
+
|
|
567
|
+
def _export_model_fn(self, features, labels, run_config, params):
  """Build the PREDICT-mode EstimatorSpec used for export / serving.

  Builds the predict graph and collects the serving outputs: the default
  outputs, plus the feature outputs and RTP outputs when enabled in
  export_config. It then:
    * registers asset files in ASSET_FILEPATHS: the train pipeline.config if
      present, the configured asset files (or ``params['asset_files']``), and
      fg.json when ``fg_json_path`` is set;
    * re-populates the DENSE_UPDATE_VARIABLES collection from the JSON file
      of the same name in the model dir, if that file exists;
    * attaches a sharded saver over GLOBAL_VARIABLES + SAVEABLE_OBJECTS.

  A leading '!' on an asset / fg path is stripped before use.

  Args:
    features: feature dict from the serving input receiver.
    labels: unused; the model is always built with ``labels=None``.
    run_config: unused in this method.
    params: estimator params. An optional ``'asset_files'`` mapping of
      asset name -> file path is used only when export_config.asset_files
      is empty.

  Returns:
    A ``tf.estimator.EstimatorSpec`` in PREDICT mode with ``export_outputs``
    under the default serving signature key.
  """
  tf.keras.backend.set_learning_phase(0)
  model = self._model_cls(
      self.model_config,
      self.feature_configs,
      features,
      labels=None,
      is_training=False)
  model.build_predict_graph()

  export_config = self._pipeline_config.export_config
  outputs = {}
  logging.info('building default outputs')
  outputs.update(model.build_output_dict())
  if export_config.export_features:
    logging.info('building output features')
    outputs.update(model.build_feature_output_dict())
  if export_config.export_rtp_outputs:
    logging.info('building RTP outputs')
    outputs.update(model.build_rtp_output_dict())

  for out_name, out_tensor in outputs.items():
    tf.logging.info(
        'output %s shape: %s type: %s' %
        (out_name, out_tensor.get_shape().as_list(), out_tensor.dtype))
  export_outputs = {
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          tf.estimator.export.PredictOutput(outputs)
  }

  # save train pipeline.config for debug purpose
  pipeline_path = os.path.join(self._model_dir, 'pipeline.config')
  if gfile.Exists(pipeline_path):
    ops.add_to_collection(
        tf.GraphKeys.ASSET_FILEPATHS,
        tf.constant(pipeline_path, dtype=tf.string, name='pipeline.config'))
  else:
    # go through logging like the rest of the file instead of bare print
    logging.warning('train pipeline_path(%s) does not exist', pipeline_path)

  # restore DENSE_UPDATE_VARIABLES collection
  dense_train_var_path = os.path.join(self.model_dir,
                                      constant.DENSE_UPDATE_VARIABLES)
  if gfile.Exists(dense_train_var_path):
    with gfile.GFile(dense_train_var_path, 'r') as fin:
      var_name_to_id_map = json.load(fin)
    # add the variables in ascending order of their recorded id
    var_name_id_lst = sorted(var_name_to_id_map.items(), key=lambda x: x[1])
    all_vars = {x.op.name: x for x in tf.global_variables()}
    for var_name, _ in var_name_id_lst:
      assert var_name in all_vars, 'dense_train_var[%s] is not found' % var_name
      ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
                            all_vars[var_name])

  # add more asset files
  if export_config.asset_files:
    for asset_file in export_config.asset_files:
      if asset_file.startswith('!'):
        asset_file = asset_file[1:]
      _, asset_name = os.path.split(asset_file)
      ops.add_to_collection(
          ops.GraphKeys.ASSET_FILEPATHS,
          tf.constant(asset_file, dtype=tf.string, name=asset_name))
  elif 'asset_files' in params:
    for asset_name in params['asset_files']:
      asset_file = params['asset_files'][asset_name]
      ops.add_to_collection(
          tf.GraphKeys.ASSET_FILEPATHS,
          tf.constant(asset_file, dtype=tf.string, name=asset_name))

  if self._pipeline_config.HasField('fg_json_path'):
    fg_path = self._pipeline_config.fg_json_path
    # startswith avoids an IndexError when fg_json_path is set but empty
    if fg_path.startswith('!'):
      fg_path = fg_path[1:]
    ops.add_to_collection(
        tf.GraphKeys.ASSET_FILEPATHS,
        tf.constant(fg_path, dtype=tf.string, name='fg.json'))

  var_list = (
      ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
      ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))

  scaffold = tf.train.Scaffold(
      saver=self.saver_cls(
          var_list=var_list, sharded=True, save_relative_paths=True))

  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.PREDICT,
      loss=None,
      scaffold=scaffold,
      predictions=outputs,
      export_outputs=export_outputs)
|
|
660
|
+
|
|
661
|
+
def _model_fn(self, features, labels, mode, config, params):
  """Estimator ``model_fn``: apply global settings, then dispatch on ``mode``.

  Before dispatching, this:
    * stores ``mode`` and the TRAIN mode key in ``os.environ`` under
      'tf.estimator.mode' and 'tf.estimator.ModeKeys.TRAIN';
    * sets ``os.environ['place_embedding_on_cpu'] = 'True'`` when
      ``feature_config.embedding_on_cpu`` is enabled;
    * when ``fg_json_path`` is non-empty, writes the fg config and the input
      feature info into the RTP graph collections;
    * calls ``embedding_utils.set_embedding_parallel()`` when
      ``self.embedding_parallel`` is truthy.

  Returns:
    The EstimatorSpec built for TRAIN, EVAL or PREDICT, or None for any
    other mode.
  """
  mode_keys = tf.estimator.ModeKeys
  os.environ['tf.estimator.mode'] = mode
  os.environ['tf.estimator.ModeKeys.TRAIN'] = mode_keys.TRAIN

  pipeline_config = self._pipeline_config
  if pipeline_config.feature_config.embedding_on_cpu:
    os.environ['place_embedding_on_cpu'] = 'True'
  if pipeline_config.fg_json_path:
    EasyRecEstimator._write_rtp_fg_config_to_col(
        fg_config_path=pipeline_config.fg_json_path)
    EasyRecEstimator._write_rtp_inputs_to_col(features)

  if self.embedding_parallel:
    embedding_utils.set_embedding_parallel()

  builders = {
      mode_keys.TRAIN:
          lambda: self._train_model_fn(features, labels, config),
      mode_keys.EVAL:
          lambda: self._eval_model_fn(features, labels, config),
      mode_keys.PREDICT:
          lambda: self._export_model_fn(features, labels, config, params),
  }
  build_spec = builders.get(mode)
  if build_spec is None:
    return None
  return build_spec()
|
|
680
|
+
|
|
681
|
+
@staticmethod
|
|
682
|
+
def _write_rtp_fg_config_to_col(fg_config=None, fg_config_path=None):
|
|
683
|
+
"""Write RTP config to RTP-specified graph collections.
|
|
684
|
+
|
|
685
|
+
Args:
|
|
686
|
+
fg_config: JSON-dict RTP config. If set, fg_config_path will be ignored.
|
|
687
|
+
fg_config_path: path to the RTP config file.
|
|
688
|
+
"""
|
|
689
|
+
if fg_config is None:
|
|
690
|
+
if fg_config_path.startswith('!'):
|
|
691
|
+
fg_config_path = fg_config_path[1:]
|
|
692
|
+
with gfile.GFile(fg_config_path, 'r') as f:
|
|
693
|
+
fg_config = json.load(f)
|
|
694
|
+
col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FG_CONF)
|
|
695
|
+
if len(col) == 0:
|
|
696
|
+
col.append(json.dumps(fg_config))
|
|
697
|
+
else:
|
|
698
|
+
col[0] = json.dumps(fg_config)
|
|
699
|
+
|
|
700
|
+
@staticmethod
def _write_rtp_inputs_to_col(features):
  """Record input-node information in the RTP feature-node graph collection.

  Each feature value is converted with ``_tensor_to_tensorinfo``. The
  resulting name -> info mapping is stored, JSON-serialized, as the single
  element of ``GraphKeys.RANK_SERVICE_FEATURE_NODE``, replacing any previous
  value.

  Args:
    features: the feature dictionary used as model input.
  """
  node_infos = {
      name: _tensor_to_tensorinfo(value) for name, value in features.items()
  }
  collection = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FEATURE_NODE)
  if collection:
    collection[0] = json.dumps(node_infos)
  else:
    collection.append(json.dumps(node_infos))
|
|
716
|
+
|
|
717
|
+
def export_checkpoint(self,
                      export_path=None,
                      serving_input_receiver_fn=None,
                      checkpoint_path=None,
                      mode=tf.estimator.ModeKeys.PREDICT):
  """Rebuild the graph for ``mode`` and re-save a checkpoint to export_path.

  The checkpoint at ``checkpoint_path`` is restored into a freshly built
  graph and saved again with that graph's saver. When ``checkpoint_path`` is
  falsy, the latest checkpoint in the model dir is used. The saver is the
  scaffold's saver, or a sharded ``Saver`` when the scaffold has none.

  Args:
    export_path: save path passed to ``Saver.save``.
    serving_input_receiver_fn: zero-argument callable returning the input
      receiver. It is always called, so it must be provided.
    checkpoint_path: checkpoint to restore; defaults to the latest one.
    mode: ModeKeys value passed to the model_fn (PREDICT by default).

  Raises:
    ValueError: if no checkpoint can be located.
  """
  with context.graph_mode():
    source_ckpt = checkpoint_path or estimator_utils.latest_checkpoint(
        self._model_dir)
    if not source_ckpt:
      raise ValueError("Couldn't find trained model at %s." % self._model_dir)
    with ops.Graph().as_default():
      receiver = serving_input_receiver_fn()
      spec = self._call_model_fn(
          features=receiver.features,
          labels=getattr(receiver, 'labels', None),
          mode=mode,
          config=self.config)
      with tf_session.Session(config=self._session_config) as sess:
        ckpt_saver = spec.scaffold.saver
        if not ckpt_saver:
          ckpt_saver = saver.Saver(sharded=True)
        ckpt_saver.restore(sess, source_ckpt)
        ckpt_saver.save(sess, export_path)
|