easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/utils/estimator_utils.py
@@ -0,0 +1,1036 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
import re
import sys
import time

import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.client import device_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.summary_io import SummaryWriterCache

from easy_rec.python.ops.incr_record import get_sparse_indices
from easy_rec.python.ops.incr_record import kv_resource_incr_gather
from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import shape_utils

from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer  # NOQA

try:
  import horovod.tensorflow as hvd
except Exception:
  hvd = None

try:
  from sparse_operation_kit import experiment as sok
except Exception:
  sok = None

try:
  from kafka import KafkaProducer, KafkaAdminClient
  from kafka.admin import NewTopic
except ImportError as ex:
  logging.warning('kafka-python is not installed: %s' % str(ex))

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
SessionRunHook = session_run_hook.SessionRunHook
CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook


def tensor_log_format_func(tensor_dict):
  prefix = ''
  if 'step' in tensor_dict:
    prefix = 'global step %s: ' % tensor_dict['step']
  stats = []
  for k in tensor_dict:
    if k == 'step':
      continue
    tensor_value = tensor_dict[k]
    stats.append('%s = %s' % (k, tensor_value))
  return prefix + ', '.join(stats)


class ExitBarrierHook(SessionRunHook):
  """ExitBarrier to make sure master and workers exit at the same time.

  After training finishes, the master still has to run evaluation and model
  export, so it exits a little later than the workers.
  """

  def __init__(self, num_worker, is_chief, model_dir):
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._queue = None
    self._signal_que = None
    self._que_size = None
    self._enque = None
    self._deque = None
    self._model_dir = model_dir
    self._send = None
    self._recv = None

  def begin(self):
    """Count the number of workers and masters, and set up the barrier queues."""
    tf.logging.info('number of workers (including master) = %d' % self._num_worker)
    with tf.device(
        tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)):
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if self._is_chief:
      self._flag_file = os.path.join(self._model_dir,
                                     'atexit_sync_' + str(int(time.time())))
      self._send = self._signal_que.enqueue([self._flag_file])
    else:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None

  def after_create_session(self, session, coord):
    """Clean up the queue after the session is created.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    may remain in the queue.
    """
    if self._is_chief:
      # clear the queue
      que_size = session.run(self._que_size)
      while que_size > 0:
        session.run(self._deque)
        que_size = session.run(self._que_size)
      logging.info('exit counter cleared: %d' % que_size)

  def end(self, session):
    """Ensure all workers and the master have enqueued an element before exiting."""
    session.run(self._enque)
    que_size = session.run(self._que_size)
    while que_size < self._num_worker:
      que_size = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other worker to exit, finished %d, total %d' %
          (que_size, self._num_worker))
    # prepare on-exit synchronization based on self._flag_file
    if self._is_chief:
      for i in range(self._num_worker - 1):
        session.run(self._send)
    else:
      self._flag_file = session.run(self._recv)

    def _check_flag_file(is_chief, flag_file):
      logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                   (is_chief, flag_file))
      if is_chief:
        with gfile.GFile(flag_file, 'w') as fout:
          fout.write('atexit time: %d' % int(time.time()))
      else:
        while not gfile.Exists(flag_file):
          time.sleep(1)

    from atexit import register
    register(
        _check_flag_file, is_chief=self._is_chief, flag_file=self._flag_file)
    logging.info('ExitBarrier passed')
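
# How the barrier above works: the shared FIFOQueue pinned to ps:0 acts as a
# counting semaphore -- each task enqueues one element in end() and polls
# queue.size() until it reaches num_worker -- while the second queue hands
# every worker the chief's flag-file path; the registered atexit handler then
# blocks non-chief workers until the chief writes that file, so all processes
# exit together.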


class EvaluateExitBarrierHook(SessionRunHook):
  """ExitBarrier to make sure master and workers exit at the same time.

  After training finishes, the master still has to run evaluation and model
  export, so it exits a little later than the workers.
  """

  def __init__(self, num_worker, is_chief, model_dir, metric_ops=None):
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._queue = None
    self._signal_que = None
    self._que_size = None
    self._enque = None
    self._deque = None
    self._model_dir = model_dir
    self._send = None
    self._recv = None
    self.metric_ops = metric_ops
    self.eval_result = None

  def begin(self):
    """Count the number of workers and masters, and set up the barrier queues."""
    tf.logging.info('number of workers (including master) = %d' % self._num_worker)
    with tf.device(
        tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)):
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if self._is_chief:
      self._flag_file = os.path.join(self._model_dir,
                                     'atexit_sync_' + str(int(time.time())))
      self._send = self._signal_que.enqueue([self._flag_file])
    else:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None

  def after_create_session(self, session, coord):
    """Clean up the queue after the session is created.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    may remain in the queue.
    """
    if self._is_chief:
      # clear the queue
      que_size = session.run(self._que_size)
      while que_size > 0:
        session.run(self._deque)
        que_size = session.run(self._que_size)
      logging.info('exit counter cleared: %d' % que_size)

  def end(self, session):
    """Ensure all workers and the master have enqueued an element before exiting."""
    session.run(self._enque)
    que_size = session.run(self._que_size)
    while que_size < self._num_worker:
      que_size = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other worker to exit, finished %d, total %d' %
          (que_size, self._num_worker))
    # prepare on-exit synchronization based on self._flag_file
    if self._is_chief:
      self.eval_result = session.run(self.metric_ops)
      for i in range(self._num_worker - 1):
        session.run(self._send)
    else:
      self._flag_file = session.run(self._recv)

    def _check_flag_file(is_chief, flag_file):
      logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                   (is_chief, flag_file))
      if is_chief:
        with gfile.GFile(flag_file, 'w') as fout:
          fout.write('atexit time: %d' % int(time.time()))
      else:
        while not gfile.Exists(flag_file):
          time.sleep(1)

    from atexit import register
    register(
        _check_flag_file, is_chief=self._is_chief, flag_file=self._flag_file)
    session.run(self.metric_ops)

    logging.info('ExitBarrier passed')


class ProgressHook(SessionRunHook):

  def __init__(self, num_steps, filename, is_chief):
    """Initializes a `ProgressHook`.

    Args:
      num_steps: total train steps
      filename: progress file name
      is_chief: is chief worker or not
    """
    self._num_steps = num_steps
    self._is_chief = is_chief
    if self._is_chief:
      self._progress_file = gfile.GFile(filename, 'w')
      self._progress_file.write('0.00\n')
      self._progress_interval = 0.01  # 1%
      self._last_progress_cnt = 0

  def before_run(self, run_context):
    if self._is_chief:
      return tf.train.SessionRunArgs([tf.train.get_global_step()])

  def after_run(
      self,
      run_context,  # pylint: disable=unused-argument
      run_values):
    if self._is_chief:
      global_step = run_values.results[0]
      curr_progress = global_step / self._num_steps
      curr_progress_cnt = int(curr_progress / self._progress_interval)
      if curr_progress_cnt >= self._last_progress_cnt + 1:
        self._progress_file.write('%.2f\n' % curr_progress)
        self._progress_file.flush()
        self._last_progress_cnt = curr_progress_cnt
        logging.info('Training Progress: %.2f' % curr_progress)

  def end(self, session):
    if self._is_chief:
      if self._last_progress_cnt < 1 / self._progress_interval:
        self._progress_file.write('1.00\n')
      self._progress_file.close()
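
# Usage sketch (illustrative only; `estimator`, `train_input_fn` and the
# progress-file path are assumptions, not part of this file): the hooks in
# this module follow the standard SessionRunHook protocol, so a caller could
# attach them roughly like
#
#   hooks = []
#   if is_chief():
#     hooks.append(ProgressHook(num_steps, '/tmp/progress.txt', True))
#   estimator.train(input_fn=train_input_fn, max_steps=num_steps, hooks=hooks)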


class CheckpointSaverHook(CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename='model.ckpt',
               scaffold=None,
               listeners=None,
               write_graph=True,
               data_offset_var=None,
               increment_save_config=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook
        saves the checkpoint.
      write_graph: whether to save graph.pbtxt.
      data_offset_var: data offset variable.
      increment_save_config: parameters for saving increment checkpoints.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    super(CheckpointSaverHook, self).__init__(
        checkpoint_dir,
        save_secs=save_secs,
        save_steps=save_steps,
        saver=saver,
        checkpoint_basename=checkpoint_basename,
        scaffold=scaffold,
        listeners=listeners)
    self._cuda_profile_start = 0
    self._cuda_profile_stop = 0
    self._steps_per_run = 1
    self._write_graph = write_graph
    self._data_offset_var = data_offset_var

    self._task_idx, self._task_num = get_task_index_and_num()

    if increment_save_config is not None:
      # environment variables are strings, so cast before doing arithmetic
      self._kafka_timeout_ms = int(os.environ.get('KAFKA_TIMEOUT', 600)) * 1000
      logging.info('KAFKA_TIMEOUT: %dms' % self._kafka_timeout_ms)
      self._kafka_max_req_size = int(
          os.environ.get('KAFKA_MAX_REQ_SIZE', 1024 * 1024 * 64))
      logging.info('KAFKA_MAX_REQ_SIZE: %d' % self._kafka_max_req_size)
      self._kafka_max_msg_size = int(
          os.environ.get('KAFKA_MAX_MSG_SIZE', 1024 * 1024 * 1024))
      logging.info('KAFKA_MAX_MSG_SIZE: %d' % self._kafka_max_msg_size)

      self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
      self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()

      with gfile.GFile(
          os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
          'w') as fout:
        json.dump(self._dense_name_to_ids, fout, indent=2)

      save_secs = increment_save_config.dense_save_secs
      save_steps = increment_save_config.dense_save_steps
      self._dense_timer = SecondOrStepTimer(
          every_secs=save_secs if save_secs > 0 else None,
          every_steps=save_steps if save_steps > 0 else None)
      save_secs = increment_save_config.sparse_save_secs
      save_steps = increment_save_config.sparse_save_steps
      self._sparse_timer = SecondOrStepTimer(
          every_secs=save_secs if save_secs > 0 else None,
          every_steps=save_steps if save_steps > 0 else None)

      self._dense_timer.update_last_triggered_step(0)
      self._sparse_timer.update_last_triggered_step(0)

      self._sparse_indices = []
      self._sparse_values = []
      sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
      for sparse_var, indice_dtype in sparse_train_vars:
        with ops.control_dependencies([tf.train.get_global_step()]):
          with ops.colocate_with(sparse_var):
            sparse_indice = get_sparse_indices(
                var_name=sparse_var.op.name, ktype=indice_dtype)
            # sparse_indice = sparse_indice.global_indices
            self._sparse_indices.append(sparse_indice)
            if 'EmbeddingVariable' in str(type(sparse_var)):
              self._sparse_values.append(
                  kv_resource_incr_gather(
                      sparse_var._handle, sparse_indice,
                      np.zeros(sparse_var.shape.as_list(), dtype=np.float32)))
              # sparse_var.sparse_read(sparse_indice)
            else:
              self._sparse_values.append(
                  array_ops.gather(sparse_var, sparse_indice))

      self._kafka_producer = None
      self._incr_save_dir = None
      if increment_save_config.HasField('kafka'):
        self._topic = increment_save_config.kafka.topic
        logging.info('increment save topic: %s' % self._topic)

        admin_clt = KafkaAdminClient(
            bootstrap_servers=increment_save_config.kafka.server,
            request_timeout_ms=self._kafka_timeout_ms,
            api_version_auto_timeout_ms=self._kafka_timeout_ms)
        if self._topic not in admin_clt.list_topics():
          admin_clt.create_topics(
              new_topics=[
                  NewTopic(
                      name=self._topic,
                      num_partitions=1,
                      replication_factor=1,
                      topic_configs={
                          'max.message.bytes': self._kafka_max_msg_size
                      })
              ],
              validate_only=False)
          logging.info('create increment save topic: %s' % self._topic)
        admin_clt.close()

        servers = increment_save_config.kafka.server.split(',')
        self._kafka_producer = KafkaProducer(
            bootstrap_servers=servers,
            max_request_size=self._kafka_max_req_size,
            api_version_auto_timeout_ms=self._kafka_timeout_ms,
            request_timeout_ms=self._kafka_timeout_ms)
      elif increment_save_config.HasField('fs'):
        fs = increment_save_config.fs
        if fs.relative:
          self._incr_save_dir = os.path.join(checkpoint_dir, fs.incr_save_dir)
        else:
          self._incr_save_dir = fs.incr_save_dir
        if not self._incr_save_dir.endswith('/'):
          self._incr_save_dir += '/'
        if not gfile.IsDirectory(self._incr_save_dir):
          gfile.MakeDirs(self._incr_save_dir)
      elif increment_save_config.HasField('datahub'):
        raise NotImplementedError('datahub increment saving is in development.')
      else:
        raise ValueError(
            'incr_update not specified correctly, must be oneof: kafka,fs')

      self._debug_save_update = increment_save_config.debug_save_update
    else:
      self._dense_timer = None
      self._sparse_timer = None

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._write_graph:
      # We write the graph and saver_def at the first call of before_run.
      # We cannot do this in begin, since we let other hooks change the graph
      # and add variables in begin; the graph is finalized after all begin
      # calls.
      tf.train.write_graph(
          tf.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, 'graph.pbtxt')
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = tf.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    # save for step 0
    self._save(session, global_step)

    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return tf.train.SessionRunArgs(self._global_step_tensor)

  def _send_dense(self, global_step, session):
    dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
    dense_train_vals = session.run(dense_train_vars)
    logging.info('global_step=%d, increment save dense variables' % global_step)

    # build msg header
    msg_num = len(dense_train_vals)
    msg_ids = [self._dense_name_to_ids[x.op.name] for x in dense_train_vars]
    # 0 means a dense update message
    msg_header = [0, msg_num, global_step]
    for msg_id, x in zip(msg_ids, dense_train_vals):
      msg_header.append(msg_id)
      msg_header.append(x.size)

    # build msg body
    bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()
    for x in dense_train_vals:
      bytes_buf += x.tobytes()

    if self._kafka_producer is not None:
      msg_key = 'dense_update_%d' % global_step
      send_res = self._kafka_producer.send(
          self._topic, bytes_buf, key=msg_key.encode('utf-8'))
      logging.info('kafka send dense: %d exception: %s' %
                   (global_step, send_res.exception))

    if self._incr_save_dir is not None:
      save_path = os.path.join(self._incr_save_dir,
                               'dense_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)
      save_flag = save_path + '.done'
      with gfile.GFile(save_flag, 'w') as fout:
        fout.write('dense_update_%d' % global_step)

    if self._debug_save_update and self._incr_save_dir is None:
      base_dir, _ = os.path.split(self._save_path)
      incr_save_dir = os.path.join(base_dir, 'incr_save/')
      if not gfile.Exists(incr_save_dir):
        gfile.MakeDirs(incr_save_dir)
      save_path = os.path.join(incr_save_dir, 'dense_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)

    logging.info(
        'global_step=%d, increment update dense variables, msg_num=%d' %
        (global_step, msg_num))
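
# Wire format produced by _send_dense: an int32 header
# [0, msg_num, global_step, id_0, size_0, ..., id_{n-1}, size_{n-1}]
# followed by the raw bytes of each dense variable value, concatenated in
# header order; the leading 0 tags the message as a dense update.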

  def _send_sparse(self, global_step, session):
    sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
    sparse_res = session.run(self._sparse_indices + self._sparse_values)
    msg_num = int(len(sparse_res) / 2)

    sel_ids = [i for i in range(msg_num) if len(sparse_res[i]) > 0]
    sparse_key_res = [sparse_res[i] for i in sel_ids]
    sparse_val_res = [sparse_res[i + msg_num] for i in sel_ids]
    sparse_train_vars = [sparse_train_vars[i][0] for i in sel_ids]

    sel_embed_ids = [
        self._sparse_name_to_ids[x.name] for x in sparse_train_vars
    ]

    msg_num = len(sel_ids)

    if msg_num == 0:
      logging.warning('there are no sparse updates, will skip this send: %d' %
                      global_step)
      return

    # build msg header
    # 1 means a sparse update message
    msg_header = [1, msg_num, global_step]
    for tmp_id, tmp_key in zip(sel_embed_ids, sparse_key_res):
      msg_header.append(tmp_id)
      msg_header.append(len(tmp_key))
    bytes_buf = np.array(msg_header, dtype=np.int32).tobytes()

    # build msg body
    for tmp_id, tmp_key, tmp_val, tmp_var in zip(sel_embed_ids, sparse_key_res,
                                                 sparse_val_res,
                                                 sparse_train_vars):
      # for non-kv embedding variables, add the partition offset to tmp_key
      if 'EmbeddingVariable' not in str(type(tmp_var)):
        if tmp_var._save_slice_info is not None:
          tmp_key += tmp_var._save_slice_info.var_offset[0]
      bytes_buf += tmp_key.tobytes()
      bytes_buf += tmp_val.tobytes()
    if self._kafka_producer is not None:
      msg_key = 'sparse_update_%d' % global_step
      send_res = self._kafka_producer.send(
          self._topic, bytes_buf, key=msg_key.encode('utf-8'))
      logging.info('kafka send sparse: %d %s' %
                   (global_step, send_res.exception))

    if self._incr_save_dir is not None:
      save_path = os.path.join(self._incr_save_dir,
                               'sparse_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)
      save_flag = save_path + '.done'
      with gfile.GFile(save_flag, 'w') as fout:
        fout.write('sparse_update_%d' % global_step)

    if self._debug_save_update and self._incr_save_dir is None:
      base_dir, _ = os.path.split(self._save_path)
      incr_save_dir = os.path.join(base_dir, 'incr_save/')
      if not gfile.Exists(incr_save_dir):
        gfile.MakeDirs(incr_save_dir)
      save_path = os.path.join(incr_save_dir, 'sparse_update_%d' % global_step)
      with gfile.GFile(save_path, 'wb') as fout:
        fout.write(bytes_buf)

    logging.info(
        'global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d'
        % (global_step, msg_num, len(bytes_buf)))
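
# Sparse counterpart: an int32 header [1, msg_num, global_step, id_i,
# num_keys_i, ...] followed, per selected embedding, by the updated keys and
# then their gathered values; for partitioned (non-KV) variables the
# partition offset is folded into the keys before serialization, so consumers
# see global row indices.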

  def after_run(self, run_context, run_values):
    super(CheckpointSaverHook, self).after_run(run_context, run_values)
    stale_global_step = run_values.results
    global_step = -1
    if self._dense_timer is not None and self._dense_timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      global_step = run_context.session.run(self._global_step_tensor)
      self._dense_timer.update_last_triggered_step(global_step)
      self._send_dense(global_step, run_context.session)

    if self._sparse_timer is not None and self._sparse_timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      if global_step < 0:
        global_step = run_context.session.run(self._global_step_tensor)

      self._sparse_timer.update_last_triggered_step(global_step)
      self._send_sparse(global_step, run_context.session)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info('Saving checkpoints for %d into %s.', step, self._save_path)

    for l in self._listeners:  # noqa: E741
      l.before_save(session, step)

    if self._data_offset_var is not None:
      save_data_offset = session.run(self._data_offset_var)
      data_offset_json = {}
      for x in save_data_offset:
        if x:
          data_offset_json.update(json.loads(x))
      save_offset_path = os.path.join(self._checkpoint_dir,
                                      'model.ckpt-%d.offset' % step)
      with gfile.GFile(save_offset_path, 'w') as fout:
        json.dump(data_offset_json, fout)

    self._get_saver().save(
        session,
        self._save_path,
        global_step=step,
        write_meta_graph=self._write_graph)

    self._summary_writer.add_session_log(
        tf.SessionLog(
            status=tf.SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    should_stop = False
    for l in self._listeners:  # noqa: E741
      if l.after_save(session, step):
        logging.info(
            'A CheckpointSaverListener requested that training be stopped. '
            'listener: {}'.format(l))
        should_stop = True
    return should_stop

  def end(self, session):
    global_step = session.run(self._global_step_tensor)
    super(CheckpointSaverHook, self).end(session)
    if self._dense_timer is not None and \
        global_step != self._dense_timer.last_triggered_step():
      self._dense_timer.update_last_triggered_step(global_step)
      self._send_dense(global_step, session)
    if self._sparse_timer is not None and \
        global_step != self._sparse_timer.last_triggered_step():
      self._sparse_timer.update_last_triggered_step(global_step)
      self._send_sparse(global_step, session)


class NumpyCheckpointRestoreHook(SessionRunHook):
  """Restore variables from a numpy checkpoint."""

  def __init__(self, ckpt_path, name2var_map):
    """Initializes a `NumpyCheckpointRestoreHook`.

    Args:
      ckpt_path: numpy checkpoint path to restore from
      name2var_map: var name in numpy ckpt to variable map
    """
    self._ckpt_path = ckpt_path
    self._name2var_map = name2var_map
    self._restore_op = None

  def begin(self):
    ckpt_data = np.load(self._ckpt_path)
    vars_not_inited = {}

    assign_ops = []
    has_shape_unmatch = False
    with tf.variable_scope('', reuse=True):
      for var_name, var in six.iteritems(self._name2var_map):
        var_shape = var.get_shape().as_list()
        if var_name in ckpt_data.keys():
          var_data = ckpt_data[var_name]
          if list(var_data.shape) == var_shape:
            assign_ops.append(var.assign(var_data))
          else:
            logging.error(
                'variable [%s] shape not match %r vs %r' %
                (var.name.split(':')[0], var_shape, list(var_data.shape)))
            has_shape_unmatch = True
        elif 'Momentum' not in var_name and 'global_step' not in var_name:
          logging.error('variable [%s] not found in ckpt' % var_name)
          vars_not_inited[var_name] = ','.join([str(s) for s in var_shape])
    self._restore_op = tf.group(assign_ops)

    with gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
      for var_name in sorted(vars_not_inited.keys()):
        f.write('%s:%s\n' % (var_name, vars_not_inited[var_name]))
    assert not has_shape_unmatch, 'some variable shapes do not match, restore failed'
    assert len(vars_not_inited.keys()) == 0, \
        'some variables were not initialized from the checkpoint, restore failed'

  def after_create_session(self, session, coord):
    assert self._restore_op is not None
    logging.info('running numpy checkpoint restore_op')
    session.run(self._restore_op)


class IncompatibleShapeRestoreHook(SessionRunHook):
  """Restore variables with incompatible shapes."""

  def __init__(self, incompatible_shape_var_map):
    """Initializes an `IncompatibleShapeRestoreHook`.

    Args:
      incompatible_shape_var_map: a mapping of variables with incompatible
        shapes, from real variable to temp variable; the real variable is the
        one used in the model, the temp variable is the one restored from the
        checkpoint.
    """
    self._incompatible_shape_var_map = incompatible_shape_var_map
    self._restore_op = None

  def begin(self):
    assign_ops = []
    for var, var_tmp in six.iteritems(self._incompatible_shape_var_map):
      assign_ops.append(
          var.assign(
              shape_utils.pad_or_clip_nd(var_tmp,
                                         var.get_shape().as_list())))
      logging.info(
          'Assign variable[%s] from shape%s to shape%s' %
          (var.name, var_tmp.get_shape().as_list(), var.get_shape().as_list()))
    self._restore_op = tf.group(assign_ops)

  def after_create_session(self, session, coord):
    assert self._restore_op is not None
    logging.info('running incompatible shape variable restore_op')
    session.run(self._restore_op)


class MultipleCheckpointsRestoreHook(SessionRunHook):
  """Restore variables from multiple checkpoints."""
  SEP = ';'

  def __init__(self, ckpt_paths):
    """Initializes a `MultipleCheckpointsRestoreHook`.

    Args:
      ckpt_paths: multiple checkpoint paths, separated by ;
    """
    self._ckpt_path_list = ckpt_paths.split(self.SEP)
    self._saver_list = []

  def begin(self):
    global_variables = tf.global_variables()
    var_names = [re.sub(':[0-9]$', '', var.name) for var in global_variables]
    restore_status = {var_name: False for var_name in var_names}
    for ckpt_path in self._ckpt_path_list:
      logging.info('read variable from %s' % ckpt_path)
      ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
      ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
      # ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
      name2var = {}
      for var in global_variables:
        var_name = re.sub(':[0-9]$', '', var.name)
        if var_name in ckpt_var2shape_map:
          if restore_status[var_name]:
            logging.warning(
                'variable %s found in more than one checkpoint, skipped %s' %
                (var_name, ckpt_path))
            continue
          name2var[var_name] = var
          restore_status[var_name] = True
      saver = tf.train.Saver(name2var)
      self._saver_list.append(saver)

    restore_check = True
    for var_name, stat in six.iteritems(restore_status):
      if not stat:
        logging.error('var %s not found in checkpoints' % var_name)
        restore_check = False

    assert restore_check, 'failed to find all variables in the provided checkpoints'

  def after_create_session(self, session, coord):
    logging.info('running multiple checkpoint restore hook')
    for saver, ckpt_path in zip(self._saver_list, self._ckpt_path_list):
      logging.info('restore checkpoint from %s' % ckpt_path)
      saver.restore(session, ckpt_path)


class OnlineEvaluationHook(SessionRunHook):

  def __init__(self, metric_dict, output_dir):
    self._metric_dict = metric_dict
    self._output_dir = output_dir
    self._summary_writer = SummaryWriterCache.get(self._output_dir)

  def end(self, session):
    metric_tensor_dict = {k: v[0] for k, v in self._metric_dict.items()}
    metric_value_dict = session.run(metric_tensor_dict)
    tf.logging.info('Eval metric: %s' % metric_value_dict)

    global_step_tensor = tf.train.get_or_create_global_step()
    global_step = session.run(global_step_tensor)

    summary = Summary()
    for k, v in metric_value_dict.items():
      summary.value.add(tag=k, simple_value=v)
    self._summary_writer.add_summary(summary, global_step=global_step)
    self._summary_writer.flush()

    eval_result_file = os.path.join(self._output_dir,
                                    'online_eval_result.txt-%s' % global_step)
    logging.info('Saving online eval result to file %s' % eval_result_file)
    with gfile.GFile(eval_result_file, 'w') as ofile:
      result_to_write = {}
      for key in sorted(metric_value_dict):
        # convert numpy float to python float
        result_to_write[key] = metric_value_dict[key].item()
      ofile.write(json.dumps(result_to_write, indent=2))


def parse_tf_config():
  tf_config_str = os.environ.get('TF_CONFIG', '')
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(tf_config_str)
    cluster = tf_config['cluster']
    task = tf_config['task']
    task_type = task['type']
    task_index = task['index']
  else:
    cluster = {}
    task_type = 'master'
    task_index = 0
  return cluster, task_type, task_index


def get_task_index_and_num():
  if hvd is not None and 'HOROVOD_RANK' in os.environ:
    return hvd.rank(), hvd.size()
  cluster, task_type, task_index = parse_tf_config()
  if 'worker' not in cluster:
    return 0, 1
  if task_type == 'evaluator':
    return 0, 1

  task_num = len(cluster['worker'])
  if 'chief' in cluster or 'master' in cluster:
    task_num += 1
    if task_type not in ['chief', 'master']:
      task_index += 1
  return task_index, task_num
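
# Worked example (illustrative hosts): with
#   TF_CONFIG = {"cluster": {"chief": ["host0:2222"],
#                            "worker": ["host1:2222", "host2:2222"]},
#                "task": {"type": "worker", "index": 1}}
# parse_tf_config() returns (cluster, 'worker', 1), and
# get_task_index_and_num() returns (2, 3): the chief takes index 0, worker
# indices shift up by one, and task_num counts the chief plus both workers.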


def get_ckpt_version(ckpt_path):
  """Get checkpoint version from ckpt_path.

  Args:
    ckpt_path: such as xx/model.ckpt-2000 or xx/model.ckpt-2000.meta

  Return:
    ckpt_version: such as 2000
  """
  _, ckpt_name = os.path.split(ckpt_path)
  ckpt_name, ext = os.path.splitext(ckpt_name)
  if ext.startswith('.ckpt-'):
    ckpt_name = ext
  toks = ckpt_name.split('-')
  return int(toks[-1])


def get_latest_checkpoint_from_checkpoint_path(checkpoint_path,
                                               ignore_ckpt_error):
  ckpt_path = None
  if checkpoint_path.endswith('/') or gfile.IsDirectory(checkpoint_path + '/'):
    checkpoint_dir = checkpoint_path
    if not checkpoint_dir.endswith('/'):
      checkpoint_dir = checkpoint_dir + '/'
    if gfile.Exists(checkpoint_dir):
      ckpt_path = latest_checkpoint(checkpoint_dir)
      if ckpt_path:
        logging.info(
            'fine_tune_checkpoint is a directory, will use the latest checkpoint: %s'
            % ckpt_path)
      else:
        assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
    else:
      assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
  elif gfile.Exists(checkpoint_path + '.index'):
    ckpt_path = checkpoint_path
    logging.info('update fine_tune_checkpoint to %s' % checkpoint_path)
  else:
    assert ignore_ckpt_error, 'fine_tune_checkpoint(%s) does not exist.' % checkpoint_path
  return ckpt_path


def latest_checkpoint(model_dir):
  """Find the latest checkpoint under a directory.

  Args:
    model_dir: model directory

  Return:
    model_path: xx/model.ckpt-2000
  """
  try:
    ckpt_metas = gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))

    if len(ckpt_metas) == 0:
      return None

    if len(ckpt_metas) > 1:
      ckpt_metas.sort(key=lambda x: get_ckpt_version(x))
    ckpt_path = os.path.splitext(ckpt_metas[-1])[0]
    return ckpt_path
  except errors_impl.NotFoundError:
    return None


def get_trained_steps(model_dir):
  ckpt_path = latest_checkpoint(model_dir)
  if ckpt_path is not None:
    return int(ckpt_path.split('-')[-1])
  else:
    return 0


def master_to_chief():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    # change master to chief
    if 'master' in tf_config['cluster']:
      tf_config['cluster']['chief'] = tf_config['cluster']['master']
      del tf_config['cluster']['master']
      if tf_config['task']['type'] == 'master':
        tf_config['task']['type'] = 'chief'
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    return tf_config
  else:
    return None


def chief_to_master():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    # change chief to master
    if 'chief' in tf_config['cluster']:
      tf_config['cluster']['master'] = tf_config['cluster']['chief']
      del tf_config['cluster']['chief']
      if tf_config['task']['type'] == 'chief':
        tf_config['task']['type'] = 'master'
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    return tf_config
  else:
    return None


def is_ps():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] == 'ps'
  return False


def is_chief():
  if has_hvd():
    return hvd.rank() == 0

  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] in ['chief', 'master']
  return True


def is_master():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] == 'master'
  return True


def is_evaluator():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] == 'evaluator'
  return False


def has_hvd():
  return hvd is not None and 'HOROVOD_RANK' in os.environ


def has_sok():
  return sok is not None and 'ENABLE_SOK' in os.environ


def init_hvd():
  if hvd is None:
    logging.error(
        'horovod is not installed: HOROVOD_WITH_TENSORFLOW=1 pip install horovod'
    )
    sys.exit(1)

  hvd.init()
  os.environ['HOROVOD_RANK'] = str(hvd.rank())


def init_sok():
  try:
    sok.init()
    os.environ['ENABLE_SOK'] = '1'
    return True
  except Exception:
    logging.warning('sok is not installed')
    return False


def get_available_gpus():
  local_device_protos = device_lib.list_local_devices()
  return [x.name for x in local_device_protos if x.device_type == 'GPU']