easy-cs-rec-custommodel 0.8.6 (easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/utils/dag.py
@@ -0,0 +1,192 @@
+import logging
+from collections import OrderedDict
+from collections import defaultdict
+from copy import copy
+from copy import deepcopy
+
+
+class DAG(object):
+  """Directed acyclic graph implementation."""
+
+  def __init__(self):
+    """Construct a new DAG with no nodes or edges."""
+    self.reset_graph()
+
+  def add_node(self, node_name, graph=None):
+    """Add a node if it does not exist yet, or error out."""
+    if not graph:
+      graph = self.graph
+    if node_name in graph:
+      raise KeyError('node %s already exists' % node_name)
+    graph[node_name] = set()
+
+  def add_node_if_not_exists(self, node_name, graph=None):
+    try:
+      self.add_node(node_name, graph=graph)
+    except KeyError:
+      logging.info('node %s already exist' % node_name)
+
+  def delete_node(self, node_name, graph=None):
+    """Deletes this node and all edges referencing it."""
+    if not graph:
+      graph = self.graph
+    if node_name not in graph:
+      raise KeyError('node %s does not exist' % node_name)
+    graph.pop(node_name)
+
+    for node, edges in graph.items():
+      if node_name in edges:
+        edges.remove(node_name)
+
+  def delete_node_if_exists(self, node_name, graph=None):
+    try:
+      self.delete_node(node_name, graph=graph)
+    except KeyError:
+      logging.info('node %s does not exist' % node_name)
+
+  def add_edge(self, ind_node, dep_node, graph=None):
+    """Add an edge (dependency) between the specified nodes."""
+    if not graph:
+      graph = self.graph
+    if ind_node not in graph or dep_node not in graph:
+      raise KeyError('one or more nodes do not exist in graph')
+    test_graph = deepcopy(graph)
+    test_graph[ind_node].add(dep_node)
+    is_valid, message = self.validate(test_graph)
+    if is_valid:
+      graph[ind_node].add(dep_node)
+    else:
+      raise Exception('invalid DAG')
+
+  def delete_edge(self, ind_node, dep_node, graph=None):
+    """Delete an edge from the graph."""
+    if not graph:
+      graph = self.graph
+    if dep_node not in graph.get(ind_node, []):
+      raise KeyError('this edge does not exist in graph')
+    graph[ind_node].remove(dep_node)
+
+  def rename_edges(self, old_task_name, new_task_name, graph=None):
+    """Change references to a task in existing edges."""
+    if not graph:
+      graph = self.graph
+    for node, edges in graph.items():
+
+      if node == old_task_name:
+        graph[new_task_name] = copy(edges)
+        del graph[old_task_name]
+
+      else:
+        if old_task_name in edges:
+          edges.remove(old_task_name)
+          edges.add(new_task_name)
+
+  def predecessors(self, node, graph=None):
+    """Returns a list of all predecessors of the given node."""
+    if graph is None:
+      graph = self.graph
+    return [key for key in graph if node in graph[key]]
+
+  def downstream(self, node, graph=None):
+    """Returns a list of all nodes this node has edges towards."""
+    if graph is None:
+      graph = self.graph
+    if node not in graph:
+      raise KeyError('node %s is not in graph' % node)
+    return list(graph[node])
+
+  def all_downstreams(self, node, graph=None):
+    """Returns a list of all nodes ultimately downstream of the given node in the dependency graph.
+
+    in topological order.
+    """
+    if graph is None:
+      graph = self.graph
+    nodes = [node]
+    nodes_seen = set()
+    i = 0
+    while i < len(nodes):
+      downstreams = self.downstream(nodes[i], graph)
+      for downstream_node in downstreams:
+        if downstream_node not in nodes_seen:
+          nodes_seen.add(downstream_node)
+          nodes.append(downstream_node)
+      i += 1
+    return list(
+        filter(lambda node: node in nodes_seen,
+               self.topological_sort(graph=graph)))
+
+  def all_leaves(self, graph=None):
+    """Return a list of all leaves (nodes with no downstreams)."""
+    if graph is None:
+      graph = self.graph
+    return [key for key in graph if not graph[key]]
+
+  def from_dict(self, graph_dict):
+    """Reset the graph and build it from the passed dictionary.
+
+    The dictionary takes the form of {node_name: [directed edges]}
+    """
+    self.reset_graph()
+    for new_node in graph_dict.keys():
+      self.add_node(new_node)
+    for ind_node, dep_nodes in graph_dict.items():
+      if not isinstance(dep_nodes, list):
+        raise TypeError('dict values must be lists')
+      for dep_node in dep_nodes:
+        self.add_edge(ind_node, dep_node)
+
+  def reset_graph(self):
+    """Restore the graph to an empty state."""
+    self.graph = OrderedDict()
+
+  def independent_nodes(self, graph=None):
+    """Returns a list of all nodes in the graph with no dependencies."""
+    if graph is None:
+      graph = self.graph
+
+    dependent_nodes = set(
+        node for dependents in graph.values() for node in dependents)
+    return [node for node in graph.keys() if node not in dependent_nodes]
+
+  def validate(self, graph=None):
+    """Returns (Boolean, message) of whether DAG is valid."""
+    graph = graph if graph is not None else self.graph
+    if len(self.independent_nodes(graph)) == 0:
+      return False, 'no independent nodes detected'
+    try:
+      self.topological_sort(graph)
+    except ValueError:
+      return False, 'failed topological sort'
+    return True, 'valid'
+
+  def topological_sort(self, graph=None):
+    """Returns a topological ordering of the DAG.
+
+    Raises an error if this is not possible (graph is not valid).
+    """
+    if graph is None:
+      graph = self.graph
+    result = []
+    in_degree = defaultdict(lambda: 0)
+
+    for u in graph:
+      for v in graph[u]:
+        in_degree[v] += 1
+    ready = [node for node in graph if not in_degree[node]]
+
+    while ready:
+      u = ready.pop()
+      result.append(u)
+      for v in graph[u]:
+        in_degree[v] -= 1
+        if in_degree[v] == 0:
+          ready.append(v)
+
+    if len(result) == len(graph):
+      return result
+    else:
+      raise ValueError('graph is not acyclic')
+
+  def size(self):
+    return len(self.graph)
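Note: easy_rec/python/utils/dag.py above is a small dependency-ordering utility. A minimal usage sketch of it (the node names below are illustrative, not taken from the package):

from easy_rec.python.utils.dag import DAG

dag = DAG()
# {node: [downstream nodes]}; from_dict() adds every node and validates each edge
dag.from_dict({'wide': ['logits'], 'deep': ['logits'], 'logits': []})
print(dag.topological_sort())       # e.g. ['deep', 'wide', 'logits']
print(dag.all_downstreams('wide'))  # ['logits']
print(dag.independent_nodes())      # ['wide', 'deep']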
easy_rec/python/utils/distribution_utils.py
@@ -0,0 +1,268 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import print_function
+
+import json
+import logging
+import os
+
+import tensorflow as tf
+
+from easy_rec.python.protos.train_pb2 import DistributionStrategy
+from easy_rec.python.utils import estimator_utils
+from easy_rec.python.utils.estimator_utils import chief_to_master
+from easy_rec.python.utils.estimator_utils import master_to_chief
+
+DistributionStrategyMap = {
+    '': DistributionStrategy.NoStrategy,
+    'ps': DistributionStrategy.PSStrategy,
+    'ess': DistributionStrategy.ExascaleStrategy,
+    'mirrored': DistributionStrategy.MirroredStrategy,
+    'collective': DistributionStrategy.CollectiveAllReduceStrategy
+}
+
+
+def set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
+                            distribute_strategy):
+  if distribute_strategy in [
+      DistributionStrategy.PSStrategy, DistributionStrategy.MirroredStrategy,
+      DistributionStrategy.CollectiveAllReduceStrategy,
+      DistributionStrategy.ExascaleStrategy
+  ]:
+    pipeline_config.train_config.sync_replicas = False
+  pipeline_config.train_config.train_distribute = distribute_strategy
+  pipeline_config.train_config.num_gpus_per_worker = num_gpus_per_worker
+  print('Dump pipeline_config.train_config:')
+  print(pipeline_config.train_config)
+
+
+def set_tf_config_and_get_train_worker_num(
+    ps_hosts,
+    worker_hosts,
+    task_index,
+    job_name,
+    distribute_strategy=DistributionStrategy.NoStrategy,
+    eval_method='none'):
+  logging.info(
+      'set_tf_config_and_get_train_worker_num: distribute_strategy = %d' %
+      distribute_strategy)
+  worker_hosts = worker_hosts.split(',')
+  ps_hosts = ps_hosts.split(',') if ps_hosts else []
+
+  total_worker_num = len(worker_hosts)
+  train_worker_num = total_worker_num
+
+  print('Original TF_CONFIG=%s' % os.environ.get('TF_CONFIG', ''))
+  print('worker_hosts=%s ps_hosts=%s task_index=%d job_name=%s' %
+        (','.join(worker_hosts), ','.join(ps_hosts), task_index, job_name))
+  print('eval_method=%s' % eval_method)
+
+  if distribute_strategy == DistributionStrategy.MirroredStrategy:
+    assert total_worker_num == 1, 'mirrored distribute strategy only need 1 worker'
+  elif distribute_strategy in [
+      DistributionStrategy.NoStrategy, DistributionStrategy.PSStrategy,
+      DistributionStrategy.CollectiveAllReduceStrategy,
+      DistributionStrategy.ExascaleStrategy
+  ]:
+    cluster, task_type, task_index_ = estimator_utils.parse_tf_config()
+    train_worker_num = 0
+    if eval_method == 'separate':
+      if 'evaluator' in cluster:
+        # 'evaluator' in cluster indicates user use new-style cluster content
+        if 'chief' in cluster:
+          train_worker_num += len(cluster['chief'])
+        elif 'master' in cluster:
+          train_worker_num += len(cluster['master'])
+        if 'worker' in cluster:
+          train_worker_num += len(cluster['worker'])
+        # drop evaluator to avoid hang
+        if distribute_strategy == DistributionStrategy.NoStrategy:
+          del cluster['evaluator']
+          tf_config = {
+              'cluster': cluster,
+              'task': {
+                  'type': task_type,
+                  'index': task_index_
+              }
+          }
+          os.environ['TF_CONFIG'] = json.dumps(tf_config)
+      else:
+        # backward compatibility, if user does not assign one evaluator in
+        # -Dcluster, we use first worker for chief, second for evaluation
+        train_worker_num = total_worker_num - 1
+        assert train_worker_num > 0, 'in distribution mode worker num must be greater than 1, ' \
+            'the second worker will be used as evaluator'
+        if len(worker_hosts) > 1:
+          cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[2:]}
+          if distribute_strategy != DistributionStrategy.NoStrategy:
+            cluster['evaluator'] = [worker_hosts[1]]
+          if len(ps_hosts) > 0:
+            cluster['ps'] = ps_hosts
+          if job_name == 'ps':
+            os.environ['TF_CONFIG'] = json.dumps({
+                'cluster': cluster,
+                'task': {
+                    'type': job_name,
+                    'index': task_index
+                }
+            })
+          elif job_name == 'worker':
+            if task_index == 0:
+              os.environ['TF_CONFIG'] = json.dumps({
+                  'cluster': cluster,
+                  'task': {
+                      'type': 'chief',
+                      'index': 0
+                  }
+              })
+            elif task_index == 1:
+              os.environ['TF_CONFIG'] = json.dumps({
+                  'cluster': cluster,
+                  'task': {
+                      'type': 'evaluator',
+                      'index': 0
+                  }
+              })
+            else:
+              os.environ['TF_CONFIG'] = json.dumps({
+                  'cluster': cluster,
+                  'task': {
+                      'type': job_name,
+                      'index': task_index - 2
+                  }
+              })
+    else:
+      if 'evaluator' in cluster:
+        evaluator = cluster['evaluator']
+        del cluster['evaluator']
+        # 'evaluator' in cluster indicates user use new-style cluster content
+        train_worker_num += 1
+        if 'chief' in cluster:
+          train_worker_num += len(cluster['chief'])
+        elif 'master' in cluster:
+          train_worker_num += len(cluster['master'])
+        if 'worker' in cluster:
+          train_worker_num += len(cluster['worker'])
+          cluster['worker'].append(evaluator[0])
+        else:
+          cluster['worker'] = [evaluator[0]]
+        if task_type == 'evaluator':
+          tf_config = {
+              'cluster': cluster,
+              'task': {
+                  'type': 'worker',
+                  'index': train_worker_num - 2
+              }
+          }
+        else:
+          tf_config = {
+              'cluster': cluster,
+              'task': {
+                  'type': task_type,
+                  'index': task_index_
+              }
+          }
+        os.environ['TF_CONFIG'] = json.dumps(tf_config)
+      else:
+        cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[1:]}
+        train_worker_num = len(worker_hosts)
+        if len(ps_hosts) > 0:
+          cluster['ps'] = ps_hosts
+        if job_name == 'ps':
+          os.environ['TF_CONFIG'] = json.dumps({
+              'cluster': cluster,
+              'task': {
+                  'type': job_name,
+                  'index': task_index
+              }
+          })
+        else:
+          if task_index == 0:
+            os.environ['TF_CONFIG'] = json.dumps({
+                'cluster': cluster,
+                'task': {
+                    'type': 'chief',
+                    'index': 0
+                }
+            })
+          else:
+            os.environ['TF_CONFIG'] = json.dumps({
+                'cluster': cluster,
+                'task': {
+                    'type': 'worker',
+                    'index': task_index - 1
+                }
+            })
+    if eval_method == 'none':
+      # change master to chief, will not evaluate
+      master_to_chief()
+    elif eval_method == 'master':
+      # change chief to master, will evaluate on master
+      chief_to_master()
+  else:
+    assert distribute_strategy == '', 'invalid distribute_strategy %s'\
+        % distribute_strategy
+    cluster, task_type, task_index = estimator_utils.parse_tf_config()
+  print('Final TF_CONFIG = %s' % os.environ.get('TF_CONFIG', ''))
+  tf.logging.info('TF_CONFIG %s' % os.environ.get('TF_CONFIG', ''))
+  tf.logging.info('distribute_stategy %s, train_worker_num: %d' %
+                  (distribute_strategy, train_worker_num))
+
+  # remove pai chief-worker waiting strategy
+  # which is conflicted with worker waiting strategy in easyrec
+  if 'TF_WRITE_WORKER_STATUS_FILE' in os.environ:
+    del os.environ['TF_WRITE_WORKER_STATUS_FILE']
+  return train_worker_num
+
+
+def set_tf_config_and_get_train_worker_num_on_ds():
+  if 'TF_CONFIG' not in os.environ:
+    return
+  tf_config = json.loads(os.environ['TF_CONFIG'])
+  if 'cluster' in tf_config and 'ps' in tf_config['cluster'] and (
+      'evaluator' not in tf_config['cluster']):
+    easyrec_tf_config = dict()
+    easyrec_tf_config['cluster'] = {}
+    easyrec_tf_config['task'] = {}
+    easyrec_tf_config['cluster']['ps'] = tf_config['cluster']['ps']
+    easyrec_tf_config['cluster']['chief'] = [tf_config['cluster']['worker'][0]]
+    easyrec_tf_config['cluster']['worker'] = tf_config['cluster']['worker'][2:]
+
+    if tf_config['task']['type'] == 'worker' and tf_config['task']['index'] == 0:
+      easyrec_tf_config['task']['type'] = 'chief'
+      easyrec_tf_config['task']['index'] = 0
+    elif tf_config['task']['type'] == 'worker' and tf_config['task'][
+        'index'] == 1:
+      easyrec_tf_config['task']['type'] = 'evaluator'
+      easyrec_tf_config['task']['index'] = 0
+    elif tf_config['task']['type'] == 'worker':
+      easyrec_tf_config['task']['type'] = tf_config['task']['type']
+      easyrec_tf_config['task']['index'] = tf_config['task']['index'] - 2
+    else:
+      easyrec_tf_config['task']['type'] = tf_config['task']['type']
+      easyrec_tf_config['task']['index'] = tf_config['task']['index']
+    os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)
+
+
+def set_tf_config_and_get_distribute_eval_worker_num_on_ds():
+  assert 'TF_CONFIG' in os.environ, "'TF_CONFIG' must in os.environ"
+  tf_config = json.loads(os.environ['TF_CONFIG'])
+  if 'cluster' in tf_config and 'ps' in tf_config['cluster'] and (
+      'evaluator' not in tf_config['cluster']):
+    easyrec_tf_config = dict()
+    easyrec_tf_config['cluster'] = {}
+    easyrec_tf_config['task'] = {}
+    easyrec_tf_config['cluster']['ps'] = tf_config['cluster']['ps']
+    easyrec_tf_config['cluster']['chief'] = [tf_config['cluster']['worker'][0]]
+    easyrec_tf_config['cluster']['worker'] = tf_config['cluster']['worker'][1:]
+
+    if tf_config['task']['type'] == 'worker' and tf_config['task']['index'] == 0:
+      easyrec_tf_config['task']['type'] = 'chief'
+      easyrec_tf_config['task']['index'] = 0
+    elif tf_config['task']['type'] == 'worker':
+      easyrec_tf_config['task']['type'] = tf_config['task']['type']
+      easyrec_tf_config['task']['index'] = tf_config['task']['index'] - 1
+    else:
+      easyrec_tf_config['task']['type'] = tf_config['task']['type']
+      easyrec_tf_config['task']['index'] = tf_config['task']['index']
+    os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)
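Note: the two *_on_ds helpers above rewrite a DataScience-style TF_CONFIG in place. A small illustration of that rewrite (host:port values are placeholders, not from the package):

import json
import os

# input: ps/worker cluster without an explicit evaluator, current task is worker:1
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0:2222'],
        'worker': ['worker0:2222', 'worker1:2222', 'worker2:2222']
    },
    'task': {'type': 'worker', 'index': 1}
})

# after set_tf_config_and_get_train_worker_num_on_ds():
#   cluster -> ps: ['ps0:2222'], chief: ['worker0:2222'], worker: ['worker2:2222']
#   task    -> worker:0 becomes chief:0, worker:1 becomes evaluator:0,
#              worker:i (i >= 2) becomes worker:(i - 2)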
easy_rec/python/utils/ds_util.py
@@ -0,0 +1,65 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+import subprocess
+import traceback
+
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.utils import estimator_utils
+
+
+def is_on_ds():
+  # IS_ON_PAI is set in train_eval
+  # which is the entry on DataScience platform
+  return 'IS_ON_DS' in os.environ
+
+
+def set_on_ds():
+  logging.info('set on ds environment variable: IS_ON_DS')
+  os.environ['IS_ON_DS'] = '1'
+
+
+def cache_ckpt(pipeline_config):
+  fine_tune_ckpt_path = pipeline_config.train_config.fine_tune_checkpoint
+  if not fine_tune_ckpt_path.startswith('hdfs://'):
+    # there is no need to cache if remote directories are mounted
+    return
+
+  if estimator_utils.is_ps() or estimator_utils.is_chief(
+  ) or estimator_utils.is_master():
+    tmpdir = os.path.dirname(fine_tune_ckpt_path.replace('hdfs://', ''))
+    tmpdir = os.path.join('/tmp/experiments', tmpdir)
+    logging.info('will cache fine_tune_ckpt to local dir: %s' % tmpdir)
+    if gfile.IsDirectory(tmpdir):
+      gfile.DeleteRecursively(tmpdir)
+    gfile.MakeDirs(tmpdir)
+    src_files = gfile.Glob(fine_tune_ckpt_path + '*')
+    src_files.sort()
+    data_files = [x for x in src_files if '.data-' in x]
+    meta_files = [x for x in src_files if '.data-' not in x]
+    if estimator_utils.is_ps():
+      _, _, ps_id = estimator_utils.parse_tf_config()
+      ps_id = (ps_id % len(data_files))
+      data_files = data_files[ps_id:] + data_files[:ps_id]
+      src_files = meta_files + data_files
+    else:
+      src_files = meta_files
+    for src_path in src_files:
+      _, file_name = os.path.split(src_path)
+      dst_path = os.path.join(tmpdir, os.path.basename(src_path))
+      logging.info('will copy %s to local path %s' % (src_path, dst_path))
+      try:
+        output = subprocess.check_output(
+            'hadoop fs -get %s %s' % (src_path, dst_path), shell=True)
+        logging.info('copy succeed: %s' % output)
+      except Exception:
+        logging.warning('exception: %s' % traceback.format_exc())
+    ckpt_filename = os.path.basename(fine_tune_ckpt_path)
+    fine_tune_ckpt_path = os.path.join(tmpdir, ckpt_filename)
+    pipeline_config.train_config.fine_tune_checkpoint = fine_tune_ckpt_path
+    logging.info('will restore from %s' % fine_tune_ckpt_path)
+  else:
+    # workers do not have to create the restore graph
+    pipeline_config.train_config.ClearField('fine_tune_checkpoint')
easy_rec/python/utils/embedding_utils.py
@@ -0,0 +1,73 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+from easy_rec.python.utils import constant
+from easy_rec.python.utils import proto_util
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def get_norm_name_to_ids():
+  """Get normalize embedding name(including kv variables) to ids.
+
+  Return:
+    normalized names to ids mapping.
+  """
+  norm_name_to_ids = {}
+  for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+    norm_name, part_id = proto_util.get_norm_embed_name(x[0].name)
+    norm_name_to_ids[norm_name] = 1
+
+  for tid, t in enumerate(norm_name_to_ids.keys()):
+    norm_name_to_ids[t] = str(tid)
+  return norm_name_to_ids
+
+
+def get_sparse_name_to_ids():
+  """Get embedding variable(including kv variables) name to ids mapping.
+
+  Return:
+    variable names to ids mappping.
+  """
+  norm_name_to_ids = get_norm_name_to_ids()
+  name_to_ids = {}
+  for x in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
+    norm_name, _ = proto_util.get_norm_embed_name(x[0].name)
+    name_to_ids[x[0].name] = norm_name_to_ids[norm_name]
+  return name_to_ids
+
+
+def get_dense_name_to_ids():
+  dense_train_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
+  norm_name_to_ids = {}
+  for tid, x in enumerate(dense_train_vars):
+    norm_name_to_ids[x.op.name] = tid
+  return norm_name_to_ids
+
+
+embedding_parallel = False
+
+
+def set_embedding_parallel():
+  global embedding_parallel
+  embedding_parallel = True
+
+
+def is_embedding_parallel():
+  global embedding_parallel
+  return embedding_parallel
+
+
+def sort_col_by_name():
+  return constant.SORT_COL_BY_NAME in os.environ
+
+
+def embedding_on_cpu():
+  place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
+  place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+  return place_on_cpu