easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
import psutil
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow.python.summary import summary_iterator
|
|
10
|
+
|
|
11
|
+
# Pick the TF1-style gfile module; under TF2 it lives in tf.compat.v1.
gfile = tf.compat.v1.gfile if tf.__version__ >= '2.0' else tf.gfile
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_all_eval_result(event_file_pattern):
  """Collect every eval result recorded in the matching event files.

  Args:
    event_file_pattern: absolute glob pattern matching the event files.

  Returns:
    A list with one dict per summary event that carries at least one scalar
    (`simple_value`) summary. Each dict maps the summary tags to their values
    and also holds a 'global_step' entry with the event's step.
  """
  all_eval_result = []
  for event_file in gfile.Glob(event_file_pattern):
    for event in summary_iterator.summary_iterator(event_file):
      if not event.HasField('summary'):
        continue
      event_eval_result = {'global_step': event.step}
      for value in event.summary.value:
        if value.HasField('simple_value'):
          event_eval_result[value.tag] = value.simple_value
      # 'global_step' is always present, so keep only events that recorded
      # at least one scalar metric besides it.
      if len(event_eval_result) >= 2:
        all_eval_result.append(event_eval_result)
  return all_eval_result
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def save_eval_metrics(model_dir, metric_save_path, has_evaluator=True):
  """Save evaluation metrics.

  Collects every eval summary event of the job and writes the results to
  `metric_save_path`, one JSON object per line. When TF_CONFIG is set only
  one task collects the results: the evaluator if `has_evaluator` is True,
  otherwise the master/chief. Without TF_CONFIG the current process collects
  them. Nothing is written when no result is found.

  Args:
    model_dir: train model directory; events are read from
      `model_dir/eval_val/` if it exists, otherwise from `model_dir/eval/`.
    metric_save_path: metric saving path
    has_evaluator: evaluation is done on a separate evaluator, not on master.

  Raises:
    AssertionError: if the collecting task finds neither eval directory, or
      if `has_evaluator` is False and the cluster has no master or chief.
  """

  def _get_eval_event_file_pattern():
    eval_dir = os.path.join(model_dir, 'eval_val/')
    if not gfile.Exists(eval_dir):
      eval_dir = os.path.join(model_dir, 'eval/')
    # explicit raise: a bare assert would be stripped under `python -O`
    if not gfile.Exists(eval_dir):
      raise AssertionError('neither eval_val nor eval exists in %s' %
                           model_dir)
    event_file_pattern = os.path.join(eval_dir, '*.tfevents.*')
    logging.info('event_file_pattern: %s', event_file_pattern)
    return event_file_pattern

  collect_here = True
  if 'TF_CONFIG' in os.environ:
    # check whether evaluator exists
    tf_config = json.loads(os.environ['TF_CONFIG'])
    logging.info('tf_config = %s', json.dumps(tf_config))
    logging.info('model_dir = %s', model_dir)
    if has_evaluator:
      collect_here = tf_config['task']['type'] == 'evaluator'
    elif 'master' in tf_config['cluster'] or 'chief' in tf_config['cluster']:
      collect_here = tf_config['task']['type'] in ['master', 'chief']
    else:
      raise AssertionError(
          'invalid cluster config, could not find master or chief or evaluator')

  all_eval_res = []
  if collect_here:
    all_eval_res = get_all_eval_result(_get_eval_event_file_pattern())

  logging.info('all_eval_res num = %d', len(all_eval_res))
  if all_eval_res:
    with gfile.GFile(metric_save_path, 'w') as fout:
      for eval_res in all_eval_res:
        fout.write(json.dumps(eval_res) + '\n')
    logging.info('save all evaluation result to %s', metric_save_path)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def kill_old_proc(tmp_dir, platform='pai'):
  """Terminate leftover HPO processes (and, on emr, yarn jobs).

  Every local process other than the current one whose command line contains
  all substrings of one of the platform's patterns receives
  `psutil.Process.terminate()`. When `platform == 'emr'`, yarn applications
  named 'easy_rec_hpo' are killed as well.

  Args:
    tmp_dir: directory where the temporary yarn job list (`yarn_job.txt`) is
      written; only used when platform is 'emr'.
    platform: 'pai' selects the pai patterns; any other value selects the emr
      patterns.
  """
  curr_pid = os.getpid()
  # Each entry lists substrings that must ALL occur in the command line.
  if platform == 'pai':
    patterns = [('easy_rec.python.hpo.pai_hpo', 'python'),
                ('client/experiment_main.py', 'python')]
  else:
    patterns = [('easy_rec.python.hpo.emr_hpo', 'python'),
                ('client/experiment_main.py', 'python'),
                ('el_submit', 'easy_rec_hpo')]

  for p in psutil.process_iter():
    try:
      if p.pid == curr_pid:
        continue
      cmd = ' '.join(p.cmdline())
      # any() so a process matching several patterns is terminated only once
      if any(all(key in cmd for key in keys) for keys in patterns):
        logging.info('will kill: [%d] %s' % (p.pid, cmd))
        p.terminate()
    except Exception:
      # Best effort: processes may exit or be inaccessible while scanning.
      pass

  if platform == 'emr':
    # clear easy_rec_hpo yarn jobs
    yarn_job_file = os.path.join(tmp_dir, 'yarn_job.txt')
    os.system(
        "yarn application -list | awk '{ if ($2 == \"easy_rec_hpo\") print $1 }' > %s"
        % yarn_job_file)
    with open(yarn_job_file, 'r') as fin:
      # skip blank lines so no empty id reaches `yarn application -kill`
      yarn_job_arr = sorted(
          set(line_str.strip() for line_str in fin if line_str.strip()))
    if yarn_job_arr:
      logging.info('will kill the easy_rec_hpo yarn jobs: %s' %
                   ','.join(yarn_job_arr))
      os.system('yarn application -kill %s' % ' '.join(yarn_job_arr))
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# -*- encoding: utf-8 -*-
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
from tensorflow.python.framework import ops
|
|
6
|
+
from tensorflow.python.training import session_run_hook
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.utils import constant
|
|
9
|
+
|
|
10
|
+
# from horovod.tensorflow.compression import Compression
|
|
11
|
+
try:
|
|
12
|
+
from horovod.tensorflow.functions import broadcast_variables
|
|
13
|
+
except Exception:
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
# Expose the TF1-style API (tf.global_variables, tf.device, ...) under TF2 too.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BroadcastGlobalVariablesHook(session_run_hook.SessionRunHook):
  """SessionRunHook that will broadcast all global variables from root rank to all other processes during initialization.

  This is necessary to ensure consistent initialization of all workers when
  training is started with random weights or restored from a checkpoint.
  Global variables whose `name` appears in the `constant.EmbeddingParallel`
  collection are not broadcast.
  """  # noqa: E501

  def __init__(self, root_rank, device=''):
    """Construct a new BroadcastGlobalVariablesHook that will broadcast all global variables from root rank to all other processes during initialization.

    Args:
      root_rank:
        Rank that will send data, other ranks will receive data.
      device:
        Device to be used for broadcasting. Uses GPU by default
        if Horovod was built with HOROVOD_GPU_OPERATIONS.
    """  # noqa: E501
    super(BroadcastGlobalVariablesHook, self).__init__()
    self.root_rank = root_rank
    self.bcast_op = None
    self.device = device

  def begin(self):
    """Build the broadcast op for the current default graph.

    The op is (re)built only if none exists yet or the existing one belongs
    to a different graph.
    """
    embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
    bcast_vars = []
    for x in tf.global_variables():
      # NOTE(review): embedding-parallel variables are presumably sharded
      # per worker, so they must not be overwritten by the root rank's copy.
      if x.name in embed_para_vars:
        continue
      bcast_vars.append(x)
      logging.info('will broadcast variable: name=%s shape=%s', x.name,
                   x.get_shape())
    if self.bcast_op is None or self.bcast_op.graph != tf.get_default_graph():
      with tf.device(self.device):
        self.bcast_op = broadcast_variables(bcast_vars, self.root_rank)

  def after_create_session(self, session, coord):
    """Run the broadcast op once the session has been created."""
    session.run(self.bcast_op)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
|
|
7
|
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
|
|
8
|
+
|
|
9
|
+
# Expose the TF1-style API (tf.string_to_number, ...) under TF2 too.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_type_defaults(field_type, default_val=''):
  """Return a default value converted to the python/numpy type of `field_type`.

  Args:
    field_type: DatasetConfig field type (INT32, INT64, STRING, BOOL, FLOAT
      or DOUBLE).
    default_val: user supplied default, normally a string; '' selects the
      builtin default of the type.

  Returns:
    int for INT32, np.int64 for INT64, the value itself for STRING, bool for
    BOOL, float for FLOAT and np.float64 for DOUBLE.

  Raises:
    AssertionError: if `field_type` is not one of the types above.
  """
  type_defaults = {
      DatasetConfig.INT32: 0,
      DatasetConfig.INT64: 0,
      DatasetConfig.STRING: '',
      DatasetConfig.BOOL: False,
      DatasetConfig.FLOAT: 0.0,
      DatasetConfig.DOUBLE: 0.0
  }
  # explicit raise: a bare assert would be stripped under `python -O`
  if field_type not in type_defaults:
    raise AssertionError('invalid type: %s' % field_type)
  if default_val == '':
    default_val = type_defaults[field_type]
  if field_type == DatasetConfig.INT32:
    return int(default_val)
  elif field_type == DatasetConfig.INT64:
    return np.int64(default_val)
  elif field_type == DatasetConfig.STRING:
    return default_val
  elif field_type == DatasetConfig.BOOL:
    # default_val is the builtin False when no default was given, or a string
    # such as 'true'; str() handles both (bool has no .lower()).
    return str(default_val).lower() == 'true'
  elif field_type == DatasetConfig.FLOAT:
    return float(default_val)
  # only DatasetConfig.DOUBLE remains
  return np.float64(default_val)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def string_to_number(field, ftype, default_value, name=''):
  """Type conversion for parsing rtp fg input format.

  Empty strings in `field` are replaced by `default_value` before the
  conversion.

  Args:
    field: 1-D string tensor to be converted (the defaults are tiled over
      tf.shape(field) from a 1-element constant, which requires rank 1).
    ftype: field dtype set in DatasetConfig.
    default_value: default value for this field, used where `field` holds an
      empty string.
    name: field name, used as a suffix of the conversion op names.

  Returns:
    A tensor with the shape of `field`: int32/int64 for INT32/INT64 (parsed
    as double, then cast), float32 for FLOAT, float64 for DOUBLE, bool for
    BOOL (True only for 'True' or 'true') and the string tensor for STRING.

  Raises:
    AssertionError: if `ftype` is not a supported type.
  """
  default_vals = tf.tile(tf.constant([str(default_value)]), tf.shape(field))
  field = tf.where(tf.greater(tf.strings.length(field), 0), field, default_vals)

  if ftype in [DatasetConfig.INT32, DatasetConfig.INT64]:
    # Int type is not supported in fg.
    # If you specify INT32, INT64 in DatasetConfig, you need to perform a cast
    # at here: parse as double first (presumably so values like '1.0' from fg
    # are accepted), then cast.
    tmp_field = tf.string_to_number(
        field, tf.double, name='field_as_flt_%s' % name)
    if ftype == DatasetConfig.INT64:
      tmp_field = tf.cast(tmp_field, tf.int64)
    else:
      tmp_field = tf.cast(tmp_field, tf.int32)
  elif ftype == DatasetConfig.FLOAT:
    tmp_field = tf.string_to_number(
        field, tf.float32, name='field_as_flt_%s' % name)
  elif ftype == DatasetConfig.DOUBLE:
    tmp_field = tf.string_to_number(
        field, tf.float64, name='field_as_flt_%s' % name)
  elif ftype == DatasetConfig.BOOL:
    tmp_field = tf.logical_or(tf.equal(field, 'True'), tf.equal(field, 'true'))
  elif ftype == DatasetConfig.STRING:
    tmp_field = field
  else:
    # explicit raise: `assert False` is stripped under `python -O`, which
    # would leave tmp_field unbound
    raise AssertionError('invalid types: %s' % str(ftype))
  return tmp_field
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def np_to_tf_type(np_type):
  """Map a python/numpy scalar type to a tf dtype.

  Args:
    np_type: a python or numpy scalar type, e.g. `int`, `np.int64`, `str`.

  Returns:
    The matching tf dtype; types without a mapping fall back to tf.string.
  """
  # `np.float` is deliberately absent: it was only an alias of the builtin
  # `float` (already mapped) and was removed in NumPy 1.24, where referencing
  # it raises AttributeError.
  _types_map = {
      int: tf.int32,
      np.int32: tf.int32,
      np.int64: tf.int64,
      str: tf.string,
      np.float32: tf.float32,
      float: tf.float32,
      np.double: tf.float64
  }
  return _types_map.get(np_type, tf.string)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_tf_type_from_parquet_file(cols, parquet_file):
  """Infer a tf dtype for each column from the first row of a parquet file.

  Args:
    cols: names of the columns to inspect.
    parquet_file: local path of the parquet file.

  Returns:
    A list of tf dtypes, one per entry of `cols`. For list/array cells the
    element type is used.
  """
  # gfile not supported, read_parquet requires random access
  input_data = pd.read_parquet(parquet_file, columns=cols)
  tf_types = []
  for col in cols:
    # positional access: a restored pandas index need not contain label 0
    obj = input_data[col].iloc[0]
    if isinstance(obj, (list, np.ndarray)) and len(obj) > 0:
      data_type = type(obj[0])
    elif isinstance(obj, np.ndarray):
      # empty array: fall back to its declared element dtype
      data_type = obj.dtype.type
    elif isinstance(obj, list):
      # empty list carries no element type; treat it as string
      data_type = str
    else:
      data_type = type(obj)
    tf_types.append(np_to_tf_type(data_type))
  return tf_types
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""IO utils.
|
|
4
|
+
|
|
5
|
+
isort:skip_file
|
|
6
|
+
"""
|
|
7
|
+
import logging
|
|
8
|
+
from future import standard_library
|
|
9
|
+
standard_library.install_aliases()
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import traceback
|
|
13
|
+
from subprocess import getstatusoutput
|
|
14
|
+
|
|
15
|
+
import six
|
|
16
|
+
import tensorflow as tf
|
|
17
|
+
from six.moves import http_client
|
|
18
|
+
from six.moves import urllib
|
|
19
|
+
import json
|
|
20
|
+
if six.PY2:
|
|
21
|
+
from urllib import quote
|
|
22
|
+
else:
|
|
23
|
+
from urllib.parse import quote
|
|
24
|
+
|
|
25
|
+
# Expose the TF1-style API (tf.gfile, tf.logging) under TF2 too.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf

# Default destination directory of download_resource.
EASY_REC_RES_DIR = 'easy_rec_user_resources'
# Default max_retry of http_read.
HTTP_MAX_NUM_RETRY = 5
# Default timeout (seconds) of http_read and of http downloads in download.
HTTP_MAX_TIMEOUT = 600
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def http_read(url, timeout=HTTP_MAX_TIMEOUT, max_retry=HTTP_MAX_NUM_RETRY):
  """Read data from url with maximum retry.

  Only `http_client.IncompleteRead` is retried; any other error is logged
  and ends the read.

  Args:
    url: http url to be read; it is percent-quoted (keeping '%/:?=&').
    timeout: specifies a timeout in seconds for blocking operations.
    max_retry: http max retry times.

  Returns:
    The response body, or None if the read failed.
  """
  import contextlib

  data = None
  try:
    # quote once up front instead of on every retry
    if six.PY2:
      url = url.encode('utf-8')
    url = quote(url, safe='%/:?=&')
    for _ in range(max_retry):
      try:
        # closing() guarantees the response/connection is released
        with contextlib.closing(urllib.request.urlopen(
            url, timeout=timeout)) as response:
          data = response.read()
        break
      except http_client.IncompleteRead:
        tf.logging.warning('incomplete read exception, will retry: %s' % url)
  except Exception:
    tf.logging.error(traceback.format_exc())

  if data is None:
    tf.logging.error('http read %s failed' % url)

  return data
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def download(oss_or_url, dst_dir=''):
  """Download file.

  ``oss://`` paths are read through tf.gfile and ``http://`` / ``https://``
  urls through urllib; the content is then written to ``dst_dir/<basename>``.
  Any other path is treated as a local file: a warning is logged and the path
  is returned unchanged, without copying.

  Args:
    oss_or_url: http or oss path
    dst_dir: destination directory, created if missing; '' (or None) means
      the current working directory.

  Returns:
    dst_file: local path for the downloaded file (``oss_or_url`` itself when
      it is a local file).

  Raises:
    RuntimeError: if an http(s) download fails.
  """
  import contextlib

  _, basename = os.path.split(oss_or_url)
  # Match full scheme prefixes so local files whose names merely start with
  # "oss" or "http" are not mistaken for remote resources.
  if oss_or_url.startswith('oss://'):
    with tf.gfile.GFile(oss_or_url, 'rb') as infile:
      file_content = infile.read()
  elif oss_or_url.startswith(('http://', 'https://')):
    try:
      # closing() guarantees the connection is released on py2 and py3 alike.
      with contextlib.closing(
          urllib.request.urlopen(oss_or_url,
                                 timeout=HTTP_MAX_TIMEOUT)) as response:
        file_content = response.read()
    except Exception as e:
      raise RuntimeError('Download %s failed: %s\n %s' %
                         (oss_or_url, str(e), traceback.format_exc()))
  else:
    tf.logging.warning('skip downloading %s, seems to be a local file' %
                       oss_or_url)
    return oss_or_url

  dst_dir = dst_dir or ''
  if dst_dir and not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
  dst_file = os.path.join(dst_dir, basename)
  with tf.gfile.GFile(dst_file, 'wb') as ofile:
    ofile.write(file_content)

  return dst_file
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def create_module_dir(dst_dir):
  """Make dst_dir an importable python package.

  Creates the directory (including parents) when it does not exist, then
  (re)writes an ``__init__.py`` containing a single newline inside it.

  Args:
    dst_dir: directory to turn into a package.
  """
  if not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
  init_path = os.path.join(dst_dir, '__init__.py')
  with open(init_path, 'w') as init_file:
    init_file.write('\n')
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def download_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Download a user python resource into a package directory.

  dst_dir is first turned into a python package via create_module_dir, so the
  downloaded module can be imported from it.

  Args:
    resource_path: http or oss path of a ``.py`` file
    dst_dir: destination directory

  Returns:
    The local path reported by download().

  Raises:
    ValueError: if resource_path does not name a ``.py`` file.
  """
  create_module_dir(dst_dir)
  basename = os.path.basename(resource_path)
  if basename.endswith('.py'):
    return download(resource_path, dst_dir)
  raise ValueError('resource %s should be python file' % resource_path)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def download_and_uncompress_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Download user resource and uncompress it if necessary.

  ``.tar.gz`` archives are extracted with ``tar`` and ``.zip`` archives with
  ``unzip`` into dst_dir; ``.py`` files are only downloaded.

  Args:
    resource_path: http or oss path of a .tar.gz, .zip or .py file
    dst_dir: download destination directory, also the extraction target;
      it is turned into a python package via create_module_dir.

  Returns:
    dst_dir

  Raises:
    ValueError: if the extension is unsupported or extraction fails.
  """
  try:
    from shlex import quote as shell_quote
  except ImportError:  # Python 2
    from pipes import quote as shell_quote

  create_module_dir(dst_dir)

  _, basename = os.path.split(resource_path)
  if not basename.endswith(('.tar.gz', '.zip', '.py')):
    raise ValueError('resource %s should be tar.gz or zip or py' %
                     resource_path)

  # download() returns the path that actually holds the data: the copy in
  # dst_dir for remote resources, or the original path for local files.
  local_file = download(resource_path, dst_dir)

  # Arguments are shell-quoted because basename comes from an external path.
  if basename.endswith('.tar.gz'):
    cmd = 'tar -zxf %s -C %s' % (shell_quote(local_file), shell_quote(dst_dir))
  elif basename.endswith('.zip'):
    cmd = 'unzip %s -d %s' % (shell_quote(local_file), shell_quote(dst_dir))
  else:
    # A plain .py resource needs no extraction.
    return dst_dir

  stat, output = getstatusoutput(cmd)
  if stat != 0:
    raise ValueError('uncompress resource %s failed: %s' %
                     (resource_path, output))

  return dst_dir
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def oss_has_t_mode(target_file):
  """Test if current environment supports t-mode writes to oss.

  Returns False right away unless the TensorFlow version string contains
  'PAI'. Otherwise it probes by writing 'a' to ``target_file + '.tmp'`` in
  't' mode and removing the file again.

  Args:
    target_file: path whose ``.tmp`` sibling is used for the probe.

  Returns:
    True if the probe write and removal succeed, False otherwise.
  """
  if 'PAI' not in tf.__version__:
    return False
  # test if running on cluster
  test_file = target_file + '.tmp'
  try:
    with tf.gfile.GFile(test_file, 't') as ofile:
      ofile.write('a')
    tf.gfile.Remove(test_file)
    return True
  except Exception:
    # Best-effort probe: any failure means t-mode is unavailable. Catching
    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
    return False
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def fix_oss_dir(path):
  """Make sure that oss dir endswith /.

  Only ``oss://`` paths are touched; every other path is returned unchanged.
  """
  missing_slash = path.startswith('oss://') and not path.endswith('/')
  return path + '/' if missing_slash else path
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def save_data_to_json_path(json_path, data):
  """Serialize data as JSON and write it to json_path through tf.gfile.

  The file is opened before serialization starts. Afterwards the existence
  of json_path is checked with an assert.

  Args:
    json_path: destination path.
    data: json-serializable object.
  """
  with tf.gfile.GFile(json_path, 'w') as json_file:
    serialized = json.dumps(data)
    json_file.write(serialized)
  assert tf.gfile.Exists(json_path), 'in_save_data_to_json_path, save_failed'
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def read_data_from_json_path(json_path):
  """Load JSON content from json_path via tf.gfile.

  Args:
    json_path: path to read; may be empty or None.

  Returns:
    The decoded JSON object, or None (with an info log) when json_path is
    empty/None or does not exist.
  """
  if not json_path or not tf.gfile.Exists(json_path):
    logging.info('json_path not exists, return None')
    return None
  with tf.gfile.GFile(json_path, 'r') as json_file:
    return json.loads(json_file.read())
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def convert_tf_flags_to_argparse(flags):
  """Convert tf.app.flags.FLAGS to argparse.ArgumentParser.

  Every flag in ``flags._flags()`` becomes a ``--<name>`` option whose default
  is the flag's current value:
    * bool flags accept a bare ``--name`` (meaning True) or
      ``--name=<yes|no|true|false|t|f|y|n|1|0>``;
    * str flags keep their ``choices`` when the flag object exposes one; a
      name seen more than once in ``_flags()`` becomes an ``append`` option;
    * list / dict flags are parsed with ``ast.literal_eval``;
    * int / float flags are parsed with the matching type;
    * any other type is accepted as a raw string.

  Args:
    flags: tf.app.flags.FLAGS (anything whose ``_flags()`` returns
      {name: flag} with flag objects exposing ``name``, ``value``, ``help``
      and optionally ``choices``).

  Returns:
    argparse.ArgumentParser: configured ArgumentParser object
  """
  import argparse
  import ast

  def str2bool(v):
    if isinstance(v, bool):
      return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
      return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
      return False
    else:
      raise argparse.ArgumentTypeError('Boolean value expected.')

  # flag_name -> [multi, flag_type, default, help_str, choices]
  args = {}
  for flag in flags._flags().values():
    flag_name = flag.name
    if flag_name in args:
      # NOTE(review): _flags() can list the same flag under several keys
      # (e.g. a short name), which also lands here and marks it "multi".
      args[flag_name][0] = True
      continue
    default = flag.value
    args[flag_name] = [
        False,
        type(default), default, flag.help or '',
        flag.choices if hasattr(flag, 'choices') else None
    ]

  parser = argparse.ArgumentParser()
  for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
    option = '--' + flag_name
    if flag_type is bool:
      # Keep the flag's own default: a hard-coded False would make
      # filter_unknown_args emit --name=False for an unset default-True flag.
      parser.add_argument(
          option,
          type=str2bool,
          nargs='?',
          const=True,
          default=default,
          help=help_str)
    elif flag_type is str:
      if choices:
        parser.add_argument(
            option, type=str, choices=choices, default=default, help=help_str)
      elif multi:
        parser.add_argument(
            option, type=str, action='append', default=default, help=help_str)
      else:
        parser.add_argument(option, type=str, default=default, help=help_str)
    elif flag_type in (list, dict):
      parser.add_argument(
          option, type=ast.literal_eval, default=default, help=help_str)
    elif flag_type in (int, float):
      parser.add_argument(
          option, type=flag_type, default=default, help=help_str)
    else:
      parser.add_argument(option, type=str, default=default, help=help_str)
  return parser
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def filter_unknown_args(flags, args):
  """Filter unknown args.

  Parses ``args`` with the argparse parser built from ``flags`` and rebuilds
  an argument list from the program name (``args[0]``) plus one
  ``--key=value`` entry per parsed option. Options whose value is None or an
  empty list/dict are dropped. Undefined arguments (apart from the first
  unknown entry) and the resulting defined arguments are logged at info level.

  Args:
    flags: tf.app.flags.FLAGS
    args: argument list including the program name, e.g. sys.argv.

  Returns:
    list of str: filtered argument list.
  """
  known_args = [args[0]]
  parsed, unknown = convert_tf_flags_to_argparse(flags).parse_known_args(args)
  if len(unknown) > 1:
    logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
  for key, value in vars(parsed).items():
    is_empty_container = type(value) in (list, dict) and not value
    if value is not None and not is_empty_container:
      known_args.append('--' + key + '=' + str(value))
  logging.info('defined arguments: %s', ', '.join(known_args[1:]))
  return known_args
|