easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Add embedding column for EmbeddingVariable which is only available on pai."""
|
|
4
|
+
|
|
5
|
+
from tensorflow.python.framework import dtypes
|
|
6
|
+
from tensorflow.python.framework import ops
|
|
7
|
+
from tensorflow.python.framework import sparse_tensor
|
|
8
|
+
from tensorflow.python.framework import tensor_shape
|
|
9
|
+
from tensorflow.python.ops import array_ops
|
|
10
|
+
from tensorflow.python.ops import embedding_ops
|
|
11
|
+
from tensorflow.python.ops import math_ops
|
|
12
|
+
from tensorflow.python.ops import sparse_ops
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries whose id is negative from both the ids and the weights.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights aligned with `sparse_ids`, or
      `None`.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with every entry whose id is < 0
    removed. `sparse_weights` is returned unchanged when it is `None`.
  """
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  has_weights = sparse_weights is not None
  if has_weights:
    # AND-ing with an all-true tensor filters nothing; presumably it is there
    # so the op fails if ids and weights carry different numbers of values.
    keep_mask = math_ops.logical_and(
        keep_mask,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  pruned_weights = (
      sparse_ops.sparse_retain(sparse_weights, keep_mask)
      if has_weights else sparse_weights)
  return pruned_ids, pruned_weights
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune entries with non-positive weights (<= 0) from the ids and weights.

  Only entries whose weight is strictly greater than 0 are kept, so zero
  weights are dropped as well as negative ones.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights aligned with `sparse_ids`, or
      `None` (in which case nothing is pruned).

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with the non-positive-weight entries
    removed from both tensors.
  """
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner='mean',
                                 default_id=None,
                                 name=None,
                                 partition_strategy='div',
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  Adapted from `tf.nn.safe_embedding_lookup_sparse` so that it can be used
  with PAI EmbeddingVariables.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
  may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  Invalid IDs (< 0) are always pruned from input IDs and weights. Unless
  `combiner` is 'sum', IDs with a non-positive weight are pruned as well. For
  an entry with no features, the embedding vector for `default_id` is
  returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features. When `None`,
      empty entries get the 0-vector.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
    NOTE(review): the empty-row masking (when `default_id` is None) tiles the
    mask over `shape(result)[1]` only, so it presumably assumes a single
    embedding dimension (m == 1) — confirm before using m > 1.

  Raises:
    ValueError: if `embedding_weights` is None.
  """
  if embedding_weights is None:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)

  # The converted tensor is only used to give the name scope its input values.
  # NOTE(review): for a list of partitions with different first dims this
  # conversion may fail — presumably callers pass a single (EV) variable.
  embed_tensors = [ops.convert_to_tensor(embedding_weights)]
  with ops.name_scope(name, 'embedding_lookup',
                      embed_tensors + [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim.value is None else original_rank_dim.value)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != 'sum':
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    values = sparse_ids.values
    if values.dtype != dtypes.int64:
      # `to_int64` is deprecated; it is exactly `cast` with name 'ToInt64', so
      # the explicit name keeps the generated graph identical.
      values = math_ops.cast(values, dtypes.int64, name='ToInt64')
      sparse_ids = sparse_tensor.SparseTensor(
          indices=sparse_ids.indices,
          values=values,
          dense_shape=sparse_ids.dense_shape)

    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
    return final_result
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from tensorflow.core.protobuf import saver_pb2
|
|
8
|
+
from tensorflow.python.framework import dtypes
|
|
9
|
+
from tensorflow.python.framework import ops
|
|
10
|
+
# from tensorflow.python.ops import math_ops
|
|
11
|
+
# from tensorflow.python.ops import logging_ops
|
|
12
|
+
from tensorflow.python.ops import array_ops
|
|
13
|
+
from tensorflow.python.ops import control_flow_ops
|
|
14
|
+
from tensorflow.python.ops import script_ops
|
|
15
|
+
from tensorflow.python.ops import state_ops
|
|
16
|
+
from tensorflow.python.platform import gfile
|
|
17
|
+
from tensorflow.python.training import saver
|
|
18
|
+
|
|
19
|
+
from easy_rec.python.utils import constant
|
|
20
|
+
|
|
21
|
+
# Optional dependencies for embedding-parallel training: horovod provides the
# worker rank/size (hvd.rank()/hvd.size()) used by EmbeddingParallelSaver, and
# sparse_operation_kit provides the raw ops used to export DynamicVariable
# (KV) embeddings. On any import failure the SOK-related names fall back to
# None; note that `hvd` is only bound if its own import succeeded.
try:
  import horovod.tensorflow as hvd
  from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
  from easy_rec.python.compat import dynamic_variable
except Exception:
  dynamic_variable_ops = None
  dynamic_variable = None

# Optional custom op library for restoring dense embedding shards natively.
# When it cannot be loaded (a warning is logged), load_embed_lib is None and
# _load_dense_embedding falls back to its numpy py_func implementation.
try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  logging.warning('load libload_embed.so failed: %s' % str(ex))
  load_embed_lib = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_embed_part_id(embed_file):
|
|
40
|
+
embed_file = embed_file.split('/')[-1]
|
|
41
|
+
embed_file = embed_file.split('.')[0]
|
|
42
|
+
embed_id = embed_file.split('-')[-1]
|
|
43
|
+
return int(embed_id)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class EmbeddingParallelSaver(saver.Saver):
|
|
47
|
+
|
|
48
|
+
def __init__(self,
             var_list=None,
             reshape=False,
             sharded=False,
             max_to_keep=5,
             keep_checkpoint_every_n_hours=10000.0,
             name=None,
             restore_sequentially=False,
             saver_def=None,
             builder=None,
             defer_build=False,
             allow_empty=False,
             write_version=saver_pb2.SaverDef.V2,
             pad_step_number=False,
             save_relative_paths=False,
             filename=None):
  """Split `var_list` into embedding-parallel variables and regular ones.

  - sparse_operation_kit `DynamicVariable`s (KV embeddings) are collected in
    `self._kv_vars`;
  - variables whose `name` is in the `constant.EmbeddingParallel` graph
    collection are collected in `self._embed_vars` (each worker holds its own
    shard, logged with part_id=hvd.rank());
  - everything else is handed to the base `tf.train.Saver`.

  All other arguments are forwarded unchanged to `tf.train.Saver`.

  Args:
    var_list: list of variables / saveable objects. When None, defaults to all
      global variables plus the SAVEABLE_OBJECTS collection, mirroring
      `tf.train.Saver` (previously None raised a TypeError). Dict-style
      var_lists are not supported.
  """
  self._kv_vars = []
  self._embed_vars = []
  tf_vars = []
  if var_list is None:
    # Same default set tf.train.Saver uses when var_list is None.
    var_list = (
        ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
        ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
  embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
  for var in var_list:
    if dynamic_variable is not None and isinstance(
        var, dynamic_variable.DynamicVariable):
      self._kv_vars.append(var)
    elif var.name in embed_para_vars:
      logging.info('save shard embedding %s part_id=%d part_shape=%s' %
                   (var.name, hvd.rank(), var.get_shape()))
      self._embed_vars.append(var)
    else:
      tf_vars.append(var)
  super(EmbeddingParallelSaver, self).__init__(
      tf_vars,
      reshape=reshape,
      sharded=sharded,
      max_to_keep=max_to_keep,
      keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
      name=name,
      restore_sequentially=restore_sequentially,
      saver_def=saver_def,
      builder=builder,
      defer_build=defer_build,
      allow_empty=allow_empty,
      write_version=write_version,
      pad_step_number=pad_step_number,
      save_relative_paths=save_relative_paths,
      filename=filename)
  # NOTE(review): presumably consumed by build/save overrides outside this
  # view; it is reset after the base constructor runs.
  self._is_build = False
|
|
95
|
+
|
|
96
|
+
def _has_embed_vars(self):
|
|
97
|
+
return (len(self._kv_vars) + len(self._embed_vars)) > 0
|
|
98
|
+
|
|
99
|
+
def _save_dense_embedding(self, embed_var):
  """Build a py_func op that writes this worker's dense embedding shard.

  The shard of `embed_var` held by this horovod worker is written as raw bytes
  to ``<ckpt_prefix>-embedding/embed-<var_name>-part-<rank>.bin``, with '/' in
  the variable name replaced by '__'. Worker 0 also deletes shard files whose
  part id is >= the current ``hvd.size()`` (leftovers from a run that used
  more workers).

  Args:
    embed_var: the embedding-parallel variable shard owned by this worker.

  Returns:
    A string tensor holding the path of the shard file this worker wrote.
  """
  logging.info('task[%d] save_dense_embed: %s' % (hvd.rank(), embed_var.name))

  def _save_embed(embed, filename, var_name):
    task_id = hvd.rank()
    filename = filename.decode('utf-8')
    var_name = var_name.decode('utf-8').replace('/', '__')
    embed_dir = filename + '-embedding/'
    logging.info('task[%d] save_dense_embed: %s to %s' %
                 (task_id, var_name, embed_dir))
    if not gfile.Exists(embed_dir):
      gfile.MakeDirs(embed_dir)
    embed_file = embed_dir + 'embed-' + var_name + '-part-%d.bin' % task_id
    with gfile.GFile(embed_file, 'wb') as fout:
      fout.write(embed.tobytes())

    if task_id == 0:
      # Clear stale shards. A distinct loop variable keeps `embed_file` naming
      # this worker's shard (the original loop overwrote it, so rank 0 could
      # return the path of an unrelated, possibly deleted, file).
      embed_pattern = embed_dir + 'embed-' + var_name + '-part-*.bin'
      for old_file in gfile.Glob(embed_pattern):
        if _get_embed_part_id(old_file) >= hvd.size():
          gfile.DeleteRecursively(old_file)
    # `np.object` was removed in NumPy 1.24; the builtin `object` is the same.
    return np.asarray([embed_file], order='C', dtype=object)

  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)
  save_paths = script_ops.py_func(_save_embed,
                                  [embed_var, file_name, embed_var.name],
                                  dtypes.string)
  return save_paths
|
|
131
|
+
|
|
132
|
+
def _load_dense_embedding(self, embed_var):
  """Build an op that restores this worker's dense embedding shard.

  Shard files are read from ``<ckpt_prefix>-embedding/embed-<var>-part-*.bin``.
  Rows are laid out round-robin: global row ``g`` belongs to part
  ``g % num_parts`` at local offset ``g // num_parts``. This lets shards saved
  by a different number of workers be re-partitioned onto the current
  ``hvd.size()`` workers. In the numpy fallback, rows that no shard file
  provides are left as zeros. If `load_embed_lib` is available, the native op
  is used instead; presumably it follows the same layout (not visible here).

  The restore runs after the variable's initializer, then assigns the loaded
  values.

  Args:
    embed_var: the embedding-parallel variable shard owned by this worker.

  Returns:
    The assign op that writes the restored values into `embed_var`.
  """
  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)
  embed_dim = embed_var.get_shape()[-1]
  embed_part_size = embed_var.get_shape()[0]

  def _load_embed(embed, embed_dim, embed_part_size, part_id, part_num,
                  filename, var_name):
    # `embed` (the variable's current value) is not used by the loader.
    filename = filename.decode('utf-8')
    var_name = var_name.decode('utf-8').replace('/', '__')
    embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
    embed_files = gfile.Glob(embed_pattern)

    embed_files.sort(key=_get_embed_part_id)

    logging.info('task[%d] embed_files=%s embed_dim=%d embed_part_size=%d' %
                 (part_id, ','.join(embed_files), embed_dim, embed_part_size))
    if not embed_files:
      # Otherwise the variable would silently be overwritten with zeros.
      logging.warning('task[%d] no embedding shard files match %s' %
                      (part_id, embed_pattern))

    part_embed_vals = np.zeros([embed_part_size, embed_dim], dtype=np.float32)
    part_update_cnt = 0
    num_saved_parts = len(embed_files)
    for embed_file in embed_files:
      part_id_o = _get_embed_part_id(embed_file)
      with gfile.GFile(embed_file, 'rb') as fin:
        embed_val = np.frombuffer(fin.read(), np.float32)
      embed_val = embed_val.reshape([-1, embed_dim])
      # Global row ids held by the saved shard (round-robin layout).
      embed_ids_o = part_id_o + np.arange(len(embed_val)) * num_saved_parts
      sel_ids = np.where(
          np.logical_and((embed_ids_o % part_num) == part_id,
                         embed_ids_o < embed_part_size * part_num))[0]
      part_update_cnt += len(sel_ids)
      embed_ids = embed_ids_o[sel_ids]
      # Integer division is exact for all int64 ids, unlike float `/`.
      embed_ids_n = embed_ids // part_num
      part_embed_vals[embed_ids_n] = embed_val[sel_ids]
    logging.info('task[%d] load_part_cnt=%d' % (part_id, part_update_cnt))
    return part_embed_vals

  with ops.control_dependencies([embed_var._initializer_op]):
    if load_embed_lib is not None:
      embed_val = load_embed_lib.load_embed(
          task_index=hvd.rank(),
          task_num=hvd.size(),
          embed_dim=embed_dim,
          embed_part_size=embed_part_size,
          var_name='embed-' + embed_var.name.replace('/', '__'),
          ckpt_path=file_name)
    else:
      embed_val = script_ops.py_func(_load_embed, [
          embed_var, embed_dim, embed_part_size,
          hvd.rank(),
          hvd.size(), file_name, embed_var.name
      ], dtypes.float32)
  embed_val.set_shape(embed_var.get_shape())
  return state_ops.assign(embed_var, embed_val)
|
|
186
|
+
|
|
187
|
+
def _save_kv_embedding(self, sok_var):
  """Build a graph op that dumps a key-value (SOK) embedding variable to files.

  Each horovod worker writes the (key, value) pairs exported from `sok_var`
  to `<ckpt>-embedding/embed-<var_name>-part-<rank>.key` and `.val`, where
  `<ckpt>` is the checkpoint prefix fed through
  `saver_def.filename_tensor_name` and `/` in the variable name becomes `__`.
  Worker 0 additionally deletes part files whose part id is >= the current
  worker count (e.g. left behind by a run that used more workers).

  Args:
    sok_var: the dynamic (key-value) embedding variable to export; its
      `handle`, `key_type`, `handle_dtype` and `name` are used.

  Returns:
    A string tensor holding this worker's `[key_file, val_file]` paths.
  """
  indices, values = dynamic_variable_ops.dummy_var_export(
      sok_var.handle, key_type=sok_var.key_type, dtype=sok_var.handle_dtype)
  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)

  def _save_key_vals(indices, values, filename, var_name):
    """py_func body: write this worker's key/value part files."""
    var_name = var_name.decode('utf-8').replace('/', '__')
    filename = filename.decode('utf-8')
    sok_dir = filename + '-embedding/'
    if not gfile.Exists(sok_dir):
      gfile.MakeDirs(sok_dir)
    task_id = hvd.rank()
    file_prefix = sok_dir + 'embed-' + var_name + '-part-'
    key_file = file_prefix + '%d.key' % task_id
    with gfile.GFile(key_file, 'wb') as fout:
      fout.write(indices.tobytes())
    val_file = file_prefix + '%d.val' % task_id
    with gfile.GFile(val_file, 'wb') as fout:
      fout.write(values.tobytes())

    if task_id == 0:
      # Remove parts whose id is beyond the current worker count so a later
      # load does not pick them up. Separate loop variables keep this
      # worker's own key_file / val_file intact for the return value below.
      for stale_key_file in gfile.Glob(file_prefix + '*.key'):
        if _get_embed_part_id(stale_key_file) >= hvd.size():
          gfile.DeleteRecursively(stale_key_file)
          stale_val_file = stale_key_file[:-len('.key')] + '.val'
          if gfile.Exists(stale_val_file):
            gfile.DeleteRecursively(stale_val_file)

    # `np.object` was only an alias of the builtin `object` and was removed
    # in NumPy 1.24.
    return np.asarray([key_file, val_file], order='C', dtype=object)

  save_paths = script_ops.py_func(_save_key_vals,
                                  [indices, values, file_name, sok_var.name],
                                  dtypes.string)
  return save_paths
|
|
224
|
+
|
|
225
|
+
def _load_kv_embedding(self, sok_var):
  """Build a graph op that restores a key-value (SOK) embedding variable.

  When the native `load_embed_lib` is available, `load_kv_embed` is called
  with this worker's horovod rank/size. Otherwise a Python `py_func`
  fallback reads every `<ckpt>-embedding/embed-<var_name>-part-*.key` file
  and its sibling `.val` file (the layout written by `_save_kv_embedding`).
  It keeps only the rows whose `key % hvd.size() == hvd.rank()`, so each
  worker picks up its own shard from every part file.

  Args:
    sok_var: the dynamic (key-value) embedding variable to restore into; its
      `name`, `_dimension`, `handle` and `_initializer_op` are used.

  Returns:
    The assign op, created under a control dependency on `sok_var`'s
    initializer, that loads the selected keys/values into `sok_var`.
  """

  def _load_key_vals(filename, var_name):
    """Python fallback: collect this worker's keys/values from part files."""
    var_name = var_name.decode('utf-8').replace('/', '__')
    filename = filename.decode('utf-8')
    key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
    logging.info('key_file_pattern=%s filename=%s var_name=%s var=%s' %
                 (key_file_pattern, filename, var_name, str(sok_var)))
    key_files = gfile.Glob(key_file_pattern)
    logging.info('key_file_pattern=%s file_num=%d' %
                 (key_file_pattern, len(key_files)))
    embed_dim = sok_var._dimension
    task_num = hvd.size()
    task_id = hvd.rank()
    all_keys = []
    all_vals = []
    for key_file in key_files:
      # NOTE(review): keys are decoded as int64 and values as float32, while
      # _save_kv_embedding writes them with sok_var.key_type / handle_dtype
      # -- confirm those are always int64 / float32.
      with gfile.GFile(key_file, 'rb') as fin:
        tmp_keys = np.frombuffer(fin.read(), dtype=np.int64)
      tmp_ids = np.where(tmp_keys % task_num == task_id)[0]
      if len(tmp_ids) == 0:
        # This part holds no keys for this worker, but later parts still
        # may: skip it instead of aborting the whole scan.
        continue
      all_keys.append(tmp_keys.take(tmp_ids, axis=0))
      logging.info('part_keys.shape=%s %s %s' % (str(
          tmp_keys.shape), str(tmp_ids.shape), str(all_keys[-1].shape)))

      # The saver writes '...-part-N.key' next to '...-part-N.val'.
      val_file = key_file[:-len('.key')] + '.val'
      with gfile.GFile(val_file, 'rb') as fin:
        tmp_vals = np.frombuffer(
            fin.read(), dtype=np.float32).reshape([-1, embed_dim])
      all_vals.append(tmp_vals.take(tmp_ids, axis=0))
      logging.info('part_vals.shape=%s %s %s' % (str(
          tmp_vals.shape), str(tmp_ids.shape), str(all_vals[-1].shape)))

    if not all_keys:
      # No part file holds keys for this worker: return empty arrays of the
      # declared py_func dtypes instead of letting np.concatenate raise.
      return (np.zeros([0], dtype=np.int64),
              np.zeros([0, embed_dim], dtype=np.float32))

    all_keys = np.concatenate(all_keys, axis=0)
    all_vals = np.concatenate(all_vals, axis=0)

    # One random permutation applied to both arrays keeps rows paired.
    shuffle_ids = np.random.permutation(len(all_keys))
    all_keys = all_keys.take(shuffle_ids, axis=0)
    all_vals = all_vals.take(shuffle_ids, axis=0)
    return all_keys, all_vals

  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)
  if load_embed_lib is not None:
    keys, vals = load_embed_lib.load_kv_embed(
        task_index=hvd.rank(),
        task_num=hvd.size(),
        embed_dim=sok_var._dimension,
        var_name='embed-' + sok_var.name.replace('/', '__'),
        ckpt_path=file_name)
  else:
    logging.warning('libload_embed.so not loaded, will use python script_ops')
    keys, vals = script_ops.py_func(_load_key_vals, [file_name, sok_var.name],
                                    (dtypes.int64, dtypes.float32))
  with ops.control_dependencies([sok_var._initializer_op]):
    return dynamic_variable_ops.dummy_var_assign(sok_var.handle, keys, vals)
|
|
281
|
+
|
|
282
|
+
def build(self):
  """Build the base saver graph, then splice in embedding save/restore ops.

  Returns immediately if `self._is_built` is already set. Otherwise, after the
  base `build()`, and only when `self._has_embed_vars()` is true:

  * restore: one load op per key-value variable, then one per dense
    embedding variable, are grouped with the original restore op, and
    `saver_def.restore_op_name` is pointed at that group;
  * save: one save op per key-value variable, then one per dense embedding
    variable, are created; on horovod rank 0 the original save tensor is
    added as well. `saver_def.save_tensor_name` is then pointed at an
    identity of the filename tensor that depends on all of them.
  """
  if self._is_built:
    return
  super(EmbeddingParallelSaver, self).build()

  if self.saver_def.restore_op_name and self._has_embed_vars():
    # Embedding loads run as part of the regular restore.
    load_ops = [self._load_kv_embedding(kv_var) for kv_var in self._kv_vars]
    load_ops.extend(
        self._load_dense_embedding(dense_var)
        for dense_var in self._embed_vars)
    base_restore_op = ops.get_default_graph().get_operation_by_name(
        self.saver_def.restore_op_name)
    grouped_restore = control_flow_ops.group(load_ops + [base_restore_op])
    self.saver_def.restore_op_name = grouped_restore.name

  if self.saver_def.save_tensor_name and self._has_embed_vars():
    filename_tensor = ops.get_default_graph().get_tensor_by_name(
        self.saver_def.filename_tensor_name)
    save_deps = [self._save_kv_embedding(kv_var) for kv_var in self._kv_vars]
    save_deps.extend(
        self._save_dense_embedding(dense_var)
        for dense_var in self._embed_vars)
    base_save_tensor = ops.get_default_graph().get_tensor_by_name(
        self.saver_def.save_tensor_name)
    # Non-embedding variables are written by the first worker only.
    if hvd.rank() == 0:
      save_deps.append(base_save_tensor)
    with ops.control_dependencies(save_deps):
      new_save_tensor = array_ops.identity(filename_tensor)
    self.saver_def.save_tensor_name = new_save_tensor.name
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from tensorflow.python.estimator import run_config as run_config_lib
|
|
8
|
+
from tensorflow.python.util import compat
|
|
9
|
+
from tensorflow_estimator.python.estimator.training import _assert_eval_spec
|
|
10
|
+
from tensorflow_estimator.python.estimator.training import _TrainingExecutor
|
|
11
|
+
|
|
12
|
+
from easy_rec.python.compat.exporter import FinalExporter
|
|
13
|
+
from easy_rec.python.utils import estimator_utils
|
|
14
|
+
|
|
15
|
+
from tensorflow_estimator.python.estimator.training import _ContinuousEvalListener # NOQA
|
|
16
|
+
|
|
17
|
+
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training # NOQA
|
|
18
|
+
|
|
19
|
+
# Use the TF1-compatible API on TF2. `tf.__version__` is a string, so compare
# the parsed major version rather than strings (lexicographically '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
gfile = tf.gfile
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TrainDoneListener(_ContinuousEvalListener):
  """Continuous-evaluation listener that stops once training is finished.

  Training is considered finished when the marker file
  `<model_dir>/ESTIMATOR_TRAIN_DONE` exists. Evaluation keeps going while a
  checkpoint newer than the one just evaluated is still waiting.
  """

  def __init__(self, estimator):
    """Record the marker-file path under `estimator.model_dir`."""
    self._model_dir = estimator.model_dir
    self._train_done_file = os.path.join(self._model_dir,
                                         'ESTIMATOR_TRAIN_DONE')

  @property
  def train_done_file(self):
    """Path of the marker file whose existence means training is done."""
    return self._train_done_file

  def after_eval(self, eval_result):
    """Called after the evaluation is executed.

    Args:
      eval_result: An `_EvalResult` instance.

    Returns:
      False if you want to early stop continuous evaluation; `True` otherwise.
      Here: True while a newer checkpoint than the evaluated one exists in
      its directory; otherwise True only while the marker file is absent.
    """
    evaluated_ckpt = eval_result.checkpoint_path
    if evaluated_ckpt is None:
      return not gfile.Exists(self._train_done_file)

    ckpt_dir = os.path.dirname(evaluated_ckpt).rstrip('/') + '/'
    newest_ckpt = estimator_utils.latest_checkpoint(ckpt_dir)
    if newest_ckpt == evaluated_ckpt:
      return not gfile.Exists(self._train_done_file)

    logging.info(
        'TrainDoneListener: latest_ckpt_path[%s] != last_ckpt_path[%s]' %
        (newest_ckpt, evaluated_ckpt))
    # A newer checkpoint still needs evaluating, so do not stop yet.
    return True
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def train_and_evaluate(estimator, train_spec, eval_spec):
  """Train and evaluate `estimator` via tensorflow_estimator's `_TrainingExecutor`.

  The executor gets a `TrainDoneListener` as its continuous-eval listener.
  After the run:
    * on the evaluator, every `FinalExporter` whose `<model_dir>/export/<name>`
      directory does not exist yet is exported from the latest checkpoint;
    * on the chief, the listener's train-done marker file is written.

  Args:
    estimator: the estimator to train and evaluate.
    train_spec: training specification passed to `_TrainingExecutor`.
    eval_spec: evaluation specification; validated up front.

  Returns:
    The result of `executor.run()`, or None when the run is delegated to the
    distribute coordinator.

  Raises:
    ValueError: if this is an `evaluator` task with a task id other than 0.
  """
  _assert_eval_spec(eval_spec)  # fail fast if eval_spec is invalid.

  done_listener = TrainDoneListener(estimator)
  executor = _TrainingExecutor(
      estimator=estimator,
      train_spec=train_spec,
      eval_spec=eval_spec,
      continuous_eval_listener=done_listener)
  run_config = estimator.config

  # With `distribute_coordinator_mode` set in a distributed environment, the
  # distribute coordinator drives the training loop instead.
  if distribute_coordinator_training.should_run_distribute_coordinator(
      run_config):
    logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
    distribute_coordinator_training.train_and_evaluate(
        estimator, train_spec, eval_spec, _TrainingExecutor)
    return

  is_extra_evaluator = (
      run_config.task_type == run_config_lib.TaskType.EVALUATOR and
      run_config.task_id > 0)
  if is_extra_evaluator:
    raise ValueError(
        'For distributed training, there can only be one `evaluator` task '
        '(with task id 0). Given task id {}'.format(run_config.task_id))

  result = executor.run()

  # The evaluator may miss the final export when num_epochs is reached before
  # num_steps (or num_steps is unbounded); run any missing final export here.
  if estimator_utils.is_evaluator():
    export_root = os.path.join(
        compat.as_str_any(estimator.model_dir), compat.as_str_any('export'))
    final_exporters = [
        exp for exp in eval_spec.exporters if isinstance(exp, FinalExporter)
    ]
    for final_exporter in final_exporters:
      export_path = os.path.join(
          compat.as_str_any(export_root),
          compat.as_str_any(final_exporter.name))
      # Already exported: do not export twice.
      if gfile.IsDirectory(export_path + '/'):
        continue
      final_exporter.export(
          estimator=estimator,
          export_path=export_path,
          checkpoint_path=estimator_utils.latest_checkpoint(
              estimator.model_dir),
          eval_result=None,
          is_the_final_export=True)

  if estimator_utils.is_chief():
    with gfile.GFile(done_listener.train_done_file, 'w') as done_file:
      done_file.write('Train Done.')

  return result
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def estimator_train_done(estimator):
  """Return True if the `ESTIMATOR_TRAIN_DONE` marker exists in the model dir.

  Args:
    estimator: object whose `model_dir` holds the marker file.
  """
  return gfile.Exists(
      os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE'))
|