easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,941 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from google.protobuf import text_format
|
|
8
|
+
from tensorflow.python.framework import ops
|
|
9
|
+
from tensorflow.python.platform.gfile import GFile
|
|
10
|
+
# from tensorflow.python.saved_model import constants
|
|
11
|
+
from tensorflow.python.saved_model import signature_constants
|
|
12
|
+
from tensorflow.python.saved_model.loader_impl import SavedModelLoader
|
|
13
|
+
|
|
14
|
+
from easy_rec.python.utils import conditional
|
|
15
|
+
from easy_rec.python.utils import constant
|
|
16
|
+
from easy_rec.python.utils import embedding_utils
|
|
17
|
+
from easy_rec.python.utils import proto_util
|
|
18
|
+
|
|
19
|
+
# NOTE(review): string key 'embedding_initializers'; presumably the name of a
# graph collection holding embedding initializer ops — confirm against users.
EMBEDDING_INITIALIZERS = 'embedding_initializers'
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MetaGraphEditor:
|
|
23
|
+
|
|
24
|
+
def __init__(self,
             lookup_lib_path,
             saved_model_dir,
             redis_url=None,
             redis_passwd=None,
             redis_timeout=0,
             redis_cache_names=None,
             oss_path=None,
             oss_endpoint=None,
             oss_ak=None,
             oss_sk=None,
             oss_timeout=0,
             meta_graph_def=None,
             norm_name_to_ids=None,
             incr_update_params=None,
             debug_dir=''):
  """Load a serving meta graph and prepare it for embedding-lookup edits.

  Args:
    lookup_lib_path: path of the custom op library loaded with
      `tf.load_op_library`; its ops (e.g. `kv_lookup`) are used to build the
      replacement embedding lookups.
    saved_model_dir: directory of a SavedModel with the 'serve' tag; when
      empty, `meta_graph_def` is used instead.
    redis_url: redis address forwarded to the redis lookup op.
    redis_passwd: redis password forwarded to the redis lookup op.
    redis_timeout: redis timeout forwarded to the redis lookup op.
    redis_cache_names: embedding names whose redis lookups are cached;
      defaults to an empty list.
    oss_path: oss setting stored for later use (presumably by the oss
      lookup op).
    oss_endpoint: oss setting, stored like `oss_path`.
    oss_ak: oss setting, stored like `oss_path`.
    oss_sk: oss setting, stored like `oss_path`.
    oss_timeout: oss setting, stored like `oss_path`.
    meta_graph_def: MetaGraphDef to edit when `saved_model_dir` is not set;
      it is imported (devices cleared) into a freshly reset default graph.
    norm_name_to_ids: optional mapping from normalized embedding names to
      integer ids; when empty, `find_lookup_inputs` builds one.
    incr_update_params: stored as-is for incremental-update handling.
    debug_dir: when non-empty, enables verbose mode; debug dumps are written
      under this directory.
  """
  # A `[]` default would be a single list shared by every instance; create a
  # fresh one per call instead (callers passing a list are unaffected).
  if redis_cache_names is None:
    redis_cache_names = []
  self._lookup_op = tf.load_op_library(lookup_lib_path)
  self._debug_dir = debug_dir
  self._verbose = debug_dir != ''
  if saved_model_dir:
    tags = ['serve']
    loader = SavedModelLoader(saved_model_dir)
    # Loads the graph into the default graph; the returned saver is unused.
    loader.load_graph(tf.get_default_graph(), tags, None)
    meta_graph_def = loader.get_meta_graph_def_from_tags(tags)
  else:
    assert meta_graph_def, 'either saved_model_dir or meta_graph_def must be set'
    tf.reset_default_graph()
    from tensorflow.python.framework import meta_graph
    meta_graph.import_scoped_meta_graph_with_return_elements(
        meta_graph_def, clear_devices=True)
  self._meta_graph_version = meta_graph_def.meta_info_def.meta_graph_version
  self._signature_def = meta_graph_def.signature_def[
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

  if self._verbose:
    debug_out_path = os.path.join(self._debug_dir, 'meta_graph_raw.txt')
    with GFile(debug_out_path, 'w') as fout:
      fout.write(text_format.MessageToString(meta_graph_def, as_utf8=True))
  self._meta_graph_def = meta_graph_def
  self._old_node_num = len(self._meta_graph_def.graph_def.node)
  # node snapshot / flags, filled by init_graph_node_clear_flags
  self._all_graph_nodes = None
  self._all_graph_node_flags = None
  self._restore_tensor_node = None
  self._restore_shard_node = None
  self._restore_all_node = []
  # per-feature lookup metadata, filled by find_lookup_inputs / add_lookup_op
  self._lookup_outs = None
  self._feature_names = None
  self._embed_names = None
  self._embed_name_to_ids = norm_name_to_ids
  self._is_cache_from_redis = []
  self._redis_cache_names = redis_cache_names
  self._embed_ids = None
  self._embed_dims = None
  self._embed_sizes = None
  self._embed_is_kv = None
  self._embed_combiners = None
  self._redis_url = redis_url
  self._redis_passwd = redis_passwd
  self._redis_timeout = redis_timeout
  self._oss_path = oss_path
  self._oss_endpoint = oss_endpoint
  self._oss_ak = oss_ak
  self._oss_sk = oss_sk
  self._oss_timeout = oss_timeout

  self._incr_update_params = incr_update_params

  # increment update placeholders
  self._embedding_update_inputs = {}
  self._embedding_update_outputs = {}

  self._dense_update_inputs = {}
  self._dense_update_outputs = {}
|
|
97
|
+
|
|
98
|
+
@property
def sparse_update_inputs(self):
  """Incremental-update placeholders for sparse (embedding) parameters.

  Returns:
    the dict held in `self._embedding_update_inputs` (empty after __init__).
  """
  update_inputs = self._embedding_update_inputs
  return update_inputs
|
|
101
|
+
|
|
102
|
+
@property
def sparse_update_outputs(self):
  """Outputs paired with the sparse incremental-update placeholders.

  Returns:
    the dict held in `self._embedding_update_outputs` (empty after __init__).
  """
  update_outputs = self._embedding_update_outputs
  return update_outputs
|
|
105
|
+
|
|
106
|
+
@property
def dense_update_inputs(self):
  """Incremental-update placeholders for dense parameters.

  Returns:
    the dict held in `self._dense_update_inputs` (empty after __init__).
  """
  update_inputs = self._dense_update_inputs
  return update_inputs
|
|
109
|
+
|
|
110
|
+
@property
def dense_update_outputs(self):
  """Outputs paired with the dense incremental-update placeholders.

  Returns:
    the dict held in `self._dense_update_outputs` (empty after __init__).
  """
  update_outputs = self._dense_update_outputs
  return update_outputs
|
|
113
|
+
|
|
114
|
+
@property
def graph_def(self):
  """GraphDef embedded in the wrapped MetaGraphDef."""
  wrapped_meta_graph = self._meta_graph_def
  return wrapped_meta_graph.graph_def
|
|
117
|
+
|
|
118
|
+
@property
def signature_def(self):
  """Default serving SignatureDef captured from the meta graph in __init__."""
  serving_signature = self._signature_def
  return serving_signature
|
|
121
|
+
|
|
122
|
+
@property
def meta_graph_version(self):
  """`meta_info_def.meta_graph_version` read from the meta graph in __init__."""
  version = self._meta_graph_version
  return version
|
|
125
|
+
|
|
126
|
+
def init_graph_node_clear_flags(self):
  """Snapshot the graph nodes and reset one flag per node to True.

  Sets `self._all_graph_nodes` to a new list of the current graph_def nodes
  and `self._all_graph_node_flags` to a parallel list of `True` values.
  NOTE(review): the flags' meaning is defined by code outside this chunk
  (presumably "keep this node") — confirm.
  """
  nodes = list(self._meta_graph_def.graph_def.node)
  self._all_graph_nodes = nodes
  self._all_graph_node_flags = [True] * len(nodes)
|
|
130
|
+
|
|
131
|
+
def _get_share_embed_name(self, x, embed_names):
|
|
132
|
+
"""Map share embedding tensor names to embed names.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
x: string, embedding tensor names, such as:
|
|
136
|
+
input_layer_1/shared_embed_1/field16_shared_embedding
|
|
137
|
+
input_layer_1/shared_embed_2/field17_shared_embedding
|
|
138
|
+
input_layer/shared_embed_wide/field15_shared_embedding
|
|
139
|
+
input_layer/shared_embed_wide_1/field16_shared_embedding
|
|
140
|
+
embed_names: all the optional embedding_names
|
|
141
|
+
Return:
|
|
142
|
+
one element in embed_names, such as:
|
|
143
|
+
input_layer_1/shared_embed
|
|
144
|
+
input_layer_1/shared_embed
|
|
145
|
+
input_layer/shared_embed_wide
|
|
146
|
+
input_layer/shared_embed_wide
|
|
147
|
+
"""
|
|
148
|
+
assert x.endswith('_shared_embedding')
|
|
149
|
+
name_toks = x.split('/')
|
|
150
|
+
name_toks = name_toks[:-1]
|
|
151
|
+
tmp = name_toks[-1]
|
|
152
|
+
tmp = tmp.split('_')
|
|
153
|
+
try:
|
|
154
|
+
int(tmp[-1])
|
|
155
|
+
name_toks[-1] = '_'.join(tmp[:-1])
|
|
156
|
+
except Exception:
|
|
157
|
+
pass
|
|
158
|
+
tmp_name = '/'.join(name_toks[1:])
|
|
159
|
+
sel_embed_name = ''
|
|
160
|
+
for embed_name in embed_names:
|
|
161
|
+
tmp_toks = embed_name.split('/')
|
|
162
|
+
tmp_toks = tmp_toks[1:]
|
|
163
|
+
embed_name_sub = '/'.join(tmp_toks)
|
|
164
|
+
if tmp_name == embed_name_sub:
|
|
165
|
+
assert not sel_embed_name, 'confusions encountered: %s %s' % (
|
|
166
|
+
x, ','.join(embed_names))
|
|
167
|
+
sel_embed_name = embed_name
|
|
168
|
+
assert sel_embed_name, '%s not find in shared_embeddings: %s' % (
|
|
169
|
+
tmp_name, ','.join(embed_names))
|
|
170
|
+
return sel_embed_name
|
|
171
|
+
|
|
172
|
+
def _find_embed_combiners(self, norm_embed_names):
  """Find embedding lookup combiner methods.

  Graph nodes are scanned in order:
    * SparseSegmentSum / SparseSegmentMean / SparseSegmentSqrtN set the
      combiner to 'sum' / 'mean' / 'sqrtn';
    * a two-input RealDiv whose inputs both contain 'SegmentSum' (tag
      features with weights and combiner == mean) sets it to 'mean';
    * a SegmentSum sets 'sum' only if no combiner is recorded yet for that
      embedding, so it never overwrites a RealDiv result.

  Args:
    norm_embed_names: normalized embedding names
  Return:
    list: combiner methods for each features: sum, mean, sqrtn, ordered like
      norm_embed_names.
  Raises:
    KeyError: if no combiner node was found for one of norm_embed_names.
  """
  combiner_map = {
      'SparseSegmentSum': 'sum',
      'SparseSegmentMean': 'mean',
      'SparseSegmentSqrtN': 'sqrtn'
  }
  embed_combiners = {}
  for node in self._meta_graph_def.graph_def.node:
    if node.op in combiner_map:
      norm_name, _ = proto_util.get_norm_embed_name(node.name)
      embed_combiners[norm_name] = combiner_map[node.op]
    elif node.op == 'RealDiv' and len(node.input) == 2:
      # for tag feature with weights, and combiner == mean
      if 'SegmentSum' in node.input[0] and 'SegmentSum' in node.input[1]:
        norm_name, _ = proto_util.get_norm_embed_name(node.name)
        embed_combiners[norm_name] = 'mean'
    elif node.op == 'SegmentSum':
      norm_name, _ = proto_util.get_norm_embed_name(node.name)
      # avoid overwrite RealDiv results
      embed_combiners.setdefault(norm_name, 'sum')
  return [embed_combiners[x] for x in norm_embed_names]
|
|
208
|
+
|
|
209
|
+
def _find_lookup_indices_values_shapes(self):
  """Find the sparse inputs of each embedding lookup in the graph.

  Uses the specific `_embedding_weights/SparseReshape` nodes to find out the
  lookup inputs per normalized feature name:
    * indices: the SparseReshape input whose output shape has 2 dims;
    * shapes: the SparseReshape input whose output shape has 1 dim;
    * values: input[0] of the Identity node whose name contains
      `_embedding_weights/SparseReshape`.
  Inputs containing `_embedding_weights/Cast` are ignored.

  Return:
    (indices, values, shapes): three dicts from normalized feature name to
      tensor name, exactly as written in the node inputs (possibly ':N').
  Raises:
    AssertionError: if the output shape of a SparseReshape input cannot be
      found in the graph.
  """
  graph_def = self._meta_graph_def.graph_def
  # Index nodes by name once instead of rescanning the whole graph for every
  # input; setdefault keeps the first node of a name, as a linear scan does.
  node_by_name = {}
  for node in graph_def.node:
    node_by_name.setdefault(node.name, node)

  def _get_output_shape(input_name):
    out_id = 0
    if ':' in input_name:
      node_name, out_id = input_name.split(':')
      out_id = int(out_id)
    else:
      node_name = input_name
    src_node = node_by_name.get(node_name)
    if src_node is None:
      return None
    return src_node.attr['_output_shapes'].list.shape[out_id]

  indices = {}
  values = {}
  shapes = {}
  for node in graph_def.node:
    if '_embedding_weights/SparseReshape' not in node.name:
      continue
    if node.op == 'SparseReshape':
      fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
      for tmp_input in node.input:
        # ignored inputs need no shape lookup
        if '_embedding_weights/Cast' in tmp_input:
          continue
        tmp_shape = _get_output_shape(tmp_input)
        assert tmp_shape is not None, \
            'fail to find output shape of %s' % tmp_input
        if len(tmp_shape.dim) == 2:
          indices[fea_name] = tmp_input
        elif len(tmp_shape.dim) == 1:
          shapes[fea_name] = tmp_input
    elif node.op == 'Identity':
      fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
      values[fea_name] = node.input[0]
  return indices, values, shapes
|
|
246
|
+
|
|
247
|
+
def _find_lookup_weights(self):
  """Find the weight tensor of each weighted embedding lookup.

  Candidate nodes have both '_weighted_by_' and 'GatherV2' in their name,
  exactly three inputs, and no input containing 'SparseReshape'. Such a node,
    input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_[0-9]
  has the inputs
    input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/Reshape_1
    DeserializeSparse_1 (this is the weight)
    input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_4/axis
  and the input whose name lacks '_weighted_by_' is taken as the weight (the
  last such input wins if there are several).

  Return:
    dict from normalized feature name to weight tensor name.
  """
  weights = {}
  for node in self._meta_graph_def.graph_def.node:
    is_candidate = '_weighted_by_' in node.name and 'GatherV2' in node.name
    if not is_candidate:
      continue
    if any('SparseReshape' in inp for inp in node.input):
      continue
    if len(node.input) != 3:
      continue
    fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
    for inp in node.input:
      if '_weighted_by_' not in inp:
        weights[fea_name] = inp
  return weights
|
|
270
|
+
|
|
271
|
+
def _find_embed_names_and_dims(self, norm_embed_names):
  """Resolve the embedding variable of each lookup feature.

  Embedding variables are the 'VariableV2' / 'KvVarHandleOp' nodes whose name
  contains 'embedding_weights'. For each, the last dimension of its `shape`
  attr is the embedding dim and the product of the other dimensions is the
  embedding size.

  Args:
    norm_embed_names: normalized feature (lookup) names.
  Return:
    (names, dims, sizes, is_kv): four lists aligned one-to-one with
      norm_embed_names. Shared embeddings are mapped to their variable name
      via `_get_share_embed_name`; is_kv is 1 for 'KvVarHandleOp' variables
      and 0 otherwise.
  Raises:
    ValueError: if a feature matches no embedding variable.
    AssertionError: if a variable name cannot be normalized, or a shared
      embedding cannot be resolved unambiguously.
  """
  # get embedding dimensions from Variables
  embed_dims = {}
  embed_sizes = {}
  embed_is_kv = {}
  for node in self._meta_graph_def.graph_def.node:
    if 'embedding_weights' not in node.name or node.op not in (
        'VariableV2', 'KvVarHandleOp'):
      continue
    shape_dims = node.attr['shape'].shape.dim
    embed_dim = shape_dims[-1].size
    embed_size = 1
    for dim in shape_dims[:-1]:
      embed_size = embed_size * dim.size
    embed_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
    assert embed_name is not None,\
        'fail to get_norm_embed_name(%s)' % node.name
    embed_dims[embed_name] = embed_dim
    embed_sizes[embed_name] = embed_size
    embed_is_kv[embed_name] = 1 if node.op == 'KvVarHandleOp' else 0

  # get all embedding dimensions, note that some embeddings
  # are shared by multiple inputs, so the names should be
  # transformed
  all_embed_names = []
  all_embed_dims = []
  all_embed_sizes = []
  all_embed_is_kv = []
  for x in norm_embed_names:
    if x in embed_dims:
      embed_name = x
    elif x.endswith('_shared_embedding'):
      embed_name = self._get_share_embed_name(x, embed_dims.keys())
    else:
      # Silently skipping x would shift every later entry and misalign these
      # lists with the per-feature lookup inputs and combiners.
      raise ValueError('no embedding variable found for lookup feature %s' %
                       x)
    all_embed_names.append(embed_name)
    all_embed_dims.append(embed_dims[embed_name])
    all_embed_sizes.append(embed_sizes[embed_name])
    all_embed_is_kv.append(embed_is_kv[embed_name])
  return all_embed_names, all_embed_dims, all_embed_sizes, all_embed_is_kv
|
|
311
|
+
|
|
312
|
+
def find_lookup_inputs(self):
  """Collect the inputs of every embedding lookup and per-feature metadata.

  Besides returning the lookup input tensors, this sets the following, each
  aligned with the returned lists (one entry per feature):
  self._embed_combiners, self._embed_names, self._embed_dims,
  self._embed_sizes, self._embed_is_kv, self._embed_ids,
  self._is_cache_from_redis and self._feature_names. When no
  norm_name_to_ids mapping was given, self._embed_name_to_ids is built from
  the sorted unique embedding names.

  Return:
    (indices, values, shapes, weights): four lists of tensors from the
      default graph; weights holds [] for features without weights.
  Raises:
    AssertionError: if indices, values and shapes were not all found for
      exactly the same features.
  """
  logging.info('Extract embedding_lookup inputs')

  indices, values, shapes = self._find_lookup_indices_values_shapes()
  weights = self._find_lookup_weights()

  # Each feature needs all three sparse components; report the feature name
  # instead of failing later with a bare KeyError.
  for fea in values:
    assert fea in indices and fea in shapes, \
        'incomplete embedding_lookup inputs for %s' % fea
  for fea in shapes:
    assert fea in values, 'embedding_lookup values not found for %s' % fea

  graph = tf.get_default_graph()

  def _get_tensor_by_name(tensor_name):
    if ':' not in tensor_name:
      tensor_name = tensor_name + ':0'
    return graph.get_tensor_by_name(tensor_name)

  # One fixed feature order shared by every per-feature list below.
  feature_names = list(values.keys())
  lookup_input_values = []
  lookup_input_indices = []
  lookup_input_shapes = []
  lookup_input_weights = []
  for fea in feature_names:
    logging.info('Lookup Input[%s]: indices=%s values=%s shapes=%s', fea,
                 indices[fea], values[fea], shapes[fea])
    lookup_input_values.append(_get_tensor_by_name(values[fea]))
    lookup_input_indices.append(_get_tensor_by_name(indices[fea]))
    lookup_input_shapes.append(_get_tensor_by_name(shapes[fea]))
    if fea in weights:
      lookup_input_weights.append(_get_tensor_by_name(weights[fea]))
    else:
      lookup_input_weights.append([])

  # get embedding combiners
  self._embed_combiners = self._find_embed_combiners(feature_names)

  # get embedding dimensions
  (self._embed_names, self._embed_dims, self._embed_sizes,
   self._embed_is_kv) = self._find_embed_names_and_dims(feature_names)

  if not self._embed_name_to_ids:
    # sorted() makes the ids deterministic: the iteration order of a set of
    # strings can differ between processes (str hash randomization).
    embed_name_uniq = sorted(set(self._embed_names))
    self._embed_name_to_ids = {
        t: tid for tid, t in enumerate(embed_name_uniq)
    }
  self._embed_ids = [
      int(self._embed_name_to_ids[x]) for x in self._embed_names
  ]

  self._is_cache_from_redis = [
      proto_util.is_cache_from_redis(x, self._redis_cache_names)
      for x in self._embed_names
  ]

  # normalized feature names
  self._feature_names = feature_names

  return (lookup_input_indices, lookup_input_values, lookup_input_shapes,
          lookup_input_weights)
|
|
370
|
+
|
|
371
|
+
def add_lookup_op(self, lookup_input_indices, lookup_input_values,
                  lookup_input_shapes, lookup_input_weights):
  """Attach one redis `kv_lookup` op per feature and re-export the graph.

  The four input lists are parallel (one entry per feature). Entries of
  `lookup_input_values` whose dtype is tf.int32 are replaced in place by an
  int64 cast. The first output of every `kv_lookup` call is stored, in
  feature order, in `self._lookup_outs`.

  Returns:
    the MetaGraphDef exported from the default graph after the ops are added.
  """
  logging.info('add custom lookup operation to lookup embeddings from redis')
  num_features = len(lookup_input_values)
  self._lookup_outs = [None] * num_features

  # widen int32 ids to int64; done in place so the caller's list sees it
  for pos, ids in enumerate(lookup_input_values):
    if ids.dtype == tf.int32:
      lookup_input_values[pos] = tf.to_int64(ids)

  for lo in range(num_features):
    # every per-feature argument is passed as the one-element slice [lo:hi]
    hi = lo + 1
    lookup_res = self._lookup_op.kv_lookup(
        lookup_input_indices[lo:hi],
        lookup_input_values[lo:hi],
        lookup_input_shapes[lo:hi],
        lookup_input_weights[lo:hi],
        url=self._redis_url,
        password=self._redis_passwd,
        timeout=self._redis_timeout,
        combiners=self._embed_combiners[lo:hi],
        embedding_dims=self._embed_dims[lo:hi],
        embedding_names=self._embed_ids[lo:hi],
        # NOTE(review): unlike the other per-feature arguments, `cache` is
        # passed unsliced; confirm this is what kv_lookup expects.
        cache=self._is_cache_from_redis,
        version=self._meta_graph_version)
    self._lookup_outs[lo] = lookup_res[0]

  exported = tf.train.export_meta_graph()

  if self._verbose:
    # NOTE(review): this dumps self._meta_graph_def as it stands before the
    # caller stores the exported graph in it, not `exported` itself.
    raw_path = os.path.join(self._debug_dir, 'graph_raw.txt')
    with GFile(raw_path, 'w') as fout:
      fout.write(
          text_format.MessageToString(
              self._meta_graph_def.graph_def, as_utf8=True))
  return exported
|
|
403
|
+
|
|
404
|
+
def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
                      lookup_input_shapes, lookup_input_weights):
  """Attach the fused OSS lookup/init (and optional update) ops.

  A single `oss_read_kv` op looks up the embeddings of all features; its
  outputs become `self._lookup_outs`. An `oss_init` op configured over the
  unique embedding ids 0..max(self._embed_ids) is added to the
  EMBEDDING_INITIALIZERS collection. When `self._incr_update_params` is not
  None, incremental-update endpoints are added as well:
    * a tf.int8 placeholder 'incr_update/message' feeding one
      `embedding_update` op for all sparse embeddings;
    * for every variable in the DENSE_UPDATE_VARIABLES collection, a float32
      placeholder and a tf.assign op, registered in
      self._dense_update_inputs / self._dense_update_outputs.

  The environment variable `place_embedding_on_cpu` (a literal such as
  'True', 'False', '1' or '0') wraps the id cast and the lookup in a
  /CPU:0 device scope.

  Args:
    lookup_input_indices: per-feature indices tensors.
    lookup_input_values: per-feature id tensors; int32 entries are replaced
      in place by int64 casts.
    lookup_input_shapes: per-feature shape tensors.
    lookup_input_weights: per-feature weight tensors ([] when unweighted).

  Returns:
    the MetaGraphDef exported from the default graph after the ops are added.

  Raises:
    ValueError: if self._embed_ids is empty.
  """
  logging.info('add custom lookup operation to lookup embeddings from oss')
  place_on_cpu = os.getenv('place_embedding_on_cpu')
  if place_on_cpu:
    # Parse the flag as a Python literal instead of eval()-ing arbitrary
    # environment content; accept true/yes/on spelled in any case as well.
    import ast
    try:
      place_on_cpu = bool(ast.literal_eval(place_on_cpu.strip()))
    except (ValueError, SyntaxError):
      place_on_cpu = place_on_cpu.strip().lower() in ('true', 'yes', 'on')
  else:
    place_on_cpu = False

  # NOTE(review): confirm whether the init/update ops below should also be
  # pinned to CPU when place_embedding_on_cpu is set.
  with conditional(place_on_cpu, ops.device('/CPU:0')):
    # the lookup consumes int64 ids; widen int32 ids in place
    for i in range(len(lookup_input_values)):
      if lookup_input_values[i].dtype == tf.int32:
        lookup_input_values[i] = tf.to_int64(lookup_input_values[i])

    # one fused op looks up every feature at once
    self._lookup_outs = self._lookup_op.oss_read_kv(
        lookup_input_indices,
        lookup_input_values,
        lookup_input_shapes,
        lookup_input_weights,
        osspath=self._oss_path,
        endpoint=self._oss_endpoint,
        ak=self._oss_ak,
        sk=self._oss_sk,
        timeout=self._oss_timeout,
        combiners=self._embed_combiners,
        embedding_dims=self._embed_dims,
        embedding_ids=self._embed_ids,
        embedding_is_kv=self._embed_is_kv,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/lookup')

  # oss_init is configured per unique embedding id: position k of each list
  # below describes embedding id k (ids never used keep the defaults).
  num_embeds = max(int(x) for x in self._embed_ids) + 1
  uniq_embed_ids = list(range(num_embeds))
  uniq_embed_dims = [0] * num_embeds
  uniq_embed_combiners = ['mean'] * num_embeds
  uniq_embed_is_kvs = [0] * num_embeds
  for embed_id, embed_combiner, embed_is_kv, embed_dim in zip(
      self._embed_ids, self._embed_combiners, self._embed_is_kv,
      self._embed_dims):
    uniq_embed_combiners[embed_id] = embed_combiner
    uniq_embed_is_kvs[embed_id] = embed_is_kv
    uniq_embed_dims[embed_id] = embed_dim

  lookup_init_op = self._lookup_op.oss_init(
      osspath=self._oss_path,
      endpoint=self._oss_endpoint,
      ak=self._oss_ak,
      sk=self._oss_sk,
      combiners=uniq_embed_combiners,
      embedding_dims=uniq_embed_dims,
      embedding_ids=uniq_embed_ids,
      embedding_is_kv=uniq_embed_is_kvs,
      N=num_embeds,
      shared_name='embedding_lookup_res',
      name='embedding_lookup_fused/init')

  ops.add_to_collection(EMBEDDING_INITIALIZERS, lookup_init_op)

  if self._incr_update_params is not None:
    # all sparse variables are updated by a single custom operation
    message_ph = tf.placeholder(tf.int8, [None], name='incr_update/message')
    embedding_update = self._lookup_op.embedding_update(
        message=message_ph,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/embedding_update')
    self._embedding_update_inputs['incr_update/sparse/message'] = message_ph
    self._embedding_update_outputs[
        'incr_update/sparse/embedding_update'] = embedding_update

    # dense variables are updated one by one
    dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
    for x in ops.get_collection(constant.DENSE_UPDATE_VARIABLES):
      dense_var_id = dense_name_to_ids[x.op.name]
      dense_input_name = 'incr_update/dense/%d/input' % dense_var_id
      dense_output_name = 'incr_update/dense/%d/output' % dense_var_id
      dense_update_input = tf.placeholder(
          tf.float32, x.get_shape(), name=dense_input_name)
      self._dense_update_inputs[dense_input_name] = dense_update_input
      dense_assign_op = tf.assign(x, dense_update_input)
      self._dense_update_outputs[dense_output_name] = dense_assign_op

  meta_graph_def = tf.train.export_meta_graph()

  if self._verbose:
    # NOTE(review): dumps self._meta_graph_def as it stands before the caller
    # stores the returned graph in it.
    debug_path = os.path.join(self._debug_dir, 'graph_raw.txt')
    with GFile(debug_path, 'w') as fout:
      fout.write(
          text_format.MessageToString(
              self._meta_graph_def.graph_def, as_utf8=True))
  return meta_graph_def
|
|
509
|
+
|
|
510
|
+
def bytes2str(self, x):
  """Return `x` as text.

  On Python 2 (where `bytes is str`) `x` is returned unchanged. Otherwise
  `x` is utf-8 decoded; when that fails for any reason (undecodable bytes,
  or an object without `.decode`) the result is `str(x)`.
  """
  if bytes is not str:
    try:
      return x.decode('utf-8')
    except Exception:
      # e.g. special characters stored in protobuf string fields
      return str(x)
  return x
|
|
519
|
+
|
|
520
|
+
def clear_meta_graph_embeding(self, meta_graph_def):
  """Strip embedding_weights variables and Kv ops from `meta_graph_def`.

  Modifies `meta_graph_def` in place:
    * removes entries whose text contains 'embedding_weights' from the
      'model_variables', 'trainable_variables' and 'variables' collections
      (collections that are absent are left absent);
    * removes the PAI embedding-variable ops (InitializeKvVariableOp,
      KvResourceGather, KvResourceImportV2, KvVarHandleOp,
      KvVarIsInitializedOp, ReadKvVariableOp) from
      meta_info_def.stripped_op_list;
    * removes the 'has_ev' attr from the SaveV2 op definition.
  """
  logging.info('clear meta graph embedding_weights')

  def _clear_embedding_in_meta_collect(collect_name):
    # Indexing a protobuf map inserts missing keys, so check membership
    # first instead of silently adding an empty collection.
    if collect_name not in meta_graph_def.collection_def:
      return
    bytes_list = meta_graph_def.collection_def[collect_name].bytes_list
    tmp_vals = [
        x for x in bytes_list.value
        if 'embedding_weights' not in self.bytes2str(x)
    ]
    bytes_list.ClearField('value')
    bytes_list.value.extend(tmp_vals)

  _clear_embedding_in_meta_collect('model_variables')
  _clear_embedding_in_meta_collect('trainable_variables')
  _clear_embedding_in_meta_collect('variables')

  # clear Kv(pai embedding variable) ops in meta_info_def.stripped_op_list.op
  kv_op_names = frozenset([
      'InitializeKvVariableOp', 'KvResourceGather', 'KvResourceImportV2',
      'KvVarHandleOp', 'KvVarIsInitializedOp', 'ReadKvVariableOp'
  ])
  op_list = meta_graph_def.meta_info_def.stripped_op_list
  kept_ops = [x for x in op_list.op if x.name not in kv_op_names]
  op_list.ClearField('op')
  op_list.op.extend(kept_ops)

  # SaveV2 no longer saves embedding variables, so drop its has_ev attr def
  for tmp_op in op_list.op:
    if tmp_op.name == 'SaveV2':
      for tmp_attr in tmp_op.attr:
        if tmp_attr.name == 'has_ev':
          tmp_op.attr.remove(tmp_attr)
          break
|
|
554
|
+
|
|
555
|
+
def clear_meta_collect(self, meta_graph_def):
  """Drop collections that refer to removed embedding nodes.

  A collection stored as a node_list is dropped when its first entry
  contains 'embedding_weights' but not 'easy_rec'. A collection named
  'saved_model_assets' that is not a node_list is dropped as well.
  `meta_graph_def` is modified in place.
  """
  drop_meta_collects = []
  for key in meta_graph_def.collection_def:
    val = meta_graph_def.collection_def[key]
    if val.HasField('node_list'):
      node_names = val.node_list.value
      # an empty node_list used to raise IndexError on node_names[0]
      if node_names and 'embedding_weights' in node_names[0] and \
          'easy_rec' not in node_names[0]:
        drop_meta_collects.append(key)
    elif key == 'saved_model_assets':
      drop_meta_collects.append(key)
  # pop after the scan so the map is not mutated while being iterated
  for key in drop_meta_collects:
    meta_graph_def.collection_def.pop(key)
|
|
567
|
+
|
|
568
|
+
def remove_embedding_weights_and_update_lookup_outputs(self):
  """Drop embedding_weights nodes and rewire their consumers to lookup outputs.

  Every node whose name contains '_embedding_weights' is flagged for removal
  in `self._all_graph_node_flags`. For the remaining nodes, each input that
  names such a node is normalized with `proto_util.get_norm_embed_name`.
  The input is then replaced by the name of the `self._lookup_outs` tensor
  at that feature's position in `self._feature_names`, with a trailing ':0'
  removed.

  Raises:
    ValueError: if a normalized input name is not in self._feature_names.
  """

  def _should_drop(name):
    if '_embedding_weights' in name:
      if self._verbose:
        logging.info('[SHOULD_DROP] %s' % name)
      return True
    return False

  logging.info('remove embedding_weights node in graph_def.node')
  logging.info(
      'and replace the old embedding_lookup outputs with new lookup_op outputs'
  )

  # feature name -> position in self._lookup_outs; the first occurrence wins,
  # as with list.index, but each lookup is O(1)
  feature_pos = {}
  for pos, feature_name in enumerate(self._feature_names):
    feature_pos.setdefault(feature_name, pos)

  for tid, node in enumerate(self._all_graph_nodes):
    # drop the nodes
    if _should_drop(node.name):
      self._all_graph_node_flags[tid] = False
      continue
    # index loop: node.input[i] is rewritten while walking the field
    for i in range(len(node.input)):
      if not _should_drop(node.input[i]):
        continue
      input_name, _ = proto_util.get_norm_embed_name(node.input[i],
                                                     self._verbose)
      logging.info('REPLACE:' + node.input[i] + '=>' + input_name)
      if input_name not in feature_pos:
        raise ValueError(
            'input %s of node %s normalizes to %s, which is not a lookup '
            'feature' % (node.input[i], node.name, input_name))
      new_input = self._lookup_outs[feature_pos[input_name]].name
      # graph_def inputs refer to output 0 by the bare op name
      if new_input.endswith(':0'):
        new_input = new_input[:-2]
      node.input[i] = new_input
|
|
596
|
+
|
|
597
|
+
def _drop_by_ids(self, tmp_obj, key, drop_ids):
  """Remove the elements at positions `drop_ids` from repeated field `key`.

  Args:
    tmp_obj: message-like object exposing `key` as a list-like field and a
      `ClearField(key)` method (protobuf messages at the call sites here).
    key: name of the repeated field to filter.
    drop_ids: iterable of positions to remove; out-of-range positions are
      ignored.
  """
  # a set gives O(1) membership instead of scanning drop_ids per element
  drop_set = set(drop_ids)
  keep_vals = [
      x for i, x in enumerate(getattr(tmp_obj, key)) if i not in drop_set
  ]
  tmp_obj.ClearField(key)
  getattr(tmp_obj, key).extend(keep_vals)
|
|
604
|
+
|
|
605
|
+
def clear_save_restore(self):
  """Clear save restore ops.

  Dependency chain of the restore sub-graph:
    save/restore_all needs save/restore_shard as input;
    save/restore_shard needs save/Assign_[0-N] as input;
    save/Assign_[0-N] needs save/RestoreV2 as input;
    save/RestoreV2 uses save/RestoreV2/tensor_names and
      save/RestoreV2/shape_and_slices as input.

  When save/RestoreV2/tensor_names is found among the kept nodes, the entries
  mentioning embedding_weights are removed from it, and the same positions
  are removed from save/RestoreV2/shape_and_slices and from the
  _output_shapes/dtypes of save/RestoreV2. In all cases save/restore_shard is
  recorded in self._restore_shard_node and save/restore_all* nodes are
  appended to self._restore_all_node for use by the save/Assign cleanup.
  """
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.name == 'save/RestoreV2/tensor_names':
      self._restore_tensor_node = node
      break

  # positions (in tensor_names order) of restored tensors to drop; stay
  # empty/None when tensor_names is missing so the loop below skips edits
  drop_ids = []
  keep_node_num = None
  if self._restore_tensor_node:
    names_tensor = self._restore_tensor_node.attr['value'].tensor
    for tmp_id, tmp_name in enumerate(names_tensor.string_val):
      if 'embedding_weights' in self.bytes2str(tmp_name):
        drop_ids.append(tmp_id)

    self._drop_by_ids(names_tensor, 'string_val', drop_ids)
    keep_node_num = len(names_tensor.string_val)
    logging.info(
        'update self._restore_tensor_node: string_val keep_num = %d drop_num = %d'
        % (keep_node_num, len(drop_ids)))
    names_tensor.tensor_shape.dim[0].size = keep_node_num
    self._restore_tensor_node.attr['_output_shapes'].list.shape[0].dim[
        0].size = keep_node_num

    logging.info(
        'update save/RestoreV2, drop tensor_shapes, _output_shapes, related to embedding_weights'
    )

  self._restore_shard_node = None
  for node_id, node in enumerate(self._all_graph_nodes):
    # this check previously used the stale `tid` left over from the search
    # loop above, so it tested the wrong node
    if not self._all_graph_node_flags[node_id]:
      continue
    if node.name == 'save/RestoreV2/shape_and_slices':
      if keep_node_num is not None:
        node.attr['value'].tensor.tensor_shape.dim[0].size = keep_node_num
        node.attr['_output_shapes'].list.shape[0].dim[0].size = keep_node_num
        self._drop_by_ids(node.attr['value'].tensor, 'string_val', drop_ids)
    elif node.name == 'save/RestoreV2':
      if keep_node_num is not None:
        self._drop_by_ids(node.attr['_output_shapes'].list, 'shape', drop_ids)
        self._drop_by_ids(node.attr['dtypes'].list, 'type', drop_ids)
    elif node.name == 'save/restore_shard':
      self._restore_shard_node = node
    elif node.name.startswith('save/restore_all'):
      self._restore_all_node.append(node)
|
|
659
|
+
|
|
660
|
+
def clear_save_assign(self):
  """Drop saver nodes tied to embedding_weights and re-index save/Assign.

  Among the kept nodes, the following are flagged for removal in
  self._all_graph_node_flags:
    * Assign nodes named save/Assign* whose input[0] mentions
      embedding_weights;
    * nodes named *embedding_weights/ConcatPartitions/concat*;
    * Identity nodes named */embedding_weights;
    * nodes whose name contains KvResourceImportV2;
    * Const nodes named save/Const* whose first '_class' entry mentions
      embedding_weights;
    * ReadKvVariableOp nodes.
  Each remaining save/Assign node gets input[1] pointed at the save/RestoreV2
  output whose index is the position of its input[0] in
  save/RestoreV2/tensor_names. This list is filtered beforehand by
  clear_save_restore when that runs first, as edit_graph does. Control inputs
  ('^' + name) of the dropped save/Assign and save/KvResourceImportV2 nodes
  are removed from save/restore_shard, or otherwise from save/restore_all*.

  Raises:
    ValueError: if a kept save/Assign's input[0] is not in tensor_names.
  """
  logging.info(
      'update save/Assign, drop tensor_shapes, _output_shapes, related to embedding_weights'
  )
  # edit save/Assign
  drop_save_assigns = []
  # decoded save/RestoreV2/tensor_names, built on first use and reused: it is
  # not modified in this loop
  restore_names = None
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.op == 'Assign' and 'save/Assign' in node.name and \
        'embedding_weights' in node.input[0]:
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'embedding_weights/ConcatPartitions/concat' in node.name:
      self._all_graph_node_flags[tid] = False
    elif node.name.endswith('/embedding_weights') and node.op == 'Identity':
      self._all_graph_node_flags[tid] = False
    elif 'save/KvResourceImportV2' in node.name and \
        node.op == 'KvResourceImportV2':
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'KvResourceImportV2' in node.name:
      self._all_graph_node_flags[tid] = False
    elif 'save/Const' in node.name and node.op == 'Const':
      if '_class' in node.attr and node.attr['_class'].list.s:
        const_name = node.attr['_class'].list.s[0]
        if not isinstance(const_name, str):
          const_name = const_name.decode('utf-8')
        if 'embedding_weights' in const_name:
          self._all_graph_node_flags[tid] = False
    elif 'ReadKvVariableOp' in node.name and node.op == 'ReadKvVariableOp':
      self._all_graph_node_flags[tid] = False
    elif node.op == 'Assign' and 'save/Assign' in node.name:
      # update node(save/Assign_[0-N])'s input[1] by the position of
      # node.input[0] in save/RestoreV2/tensor_names
      # the outputs of save/RestoreV2 is connected to save/Assign
      if restore_names is None:
        restore_names = [
            self.bytes2str(x)
            for x in self._restore_tensor_node.attr['value'].tensor.string_val
        ]
      tmp_id = restore_names.index(node.input[0])
      if tmp_id != 0:
        tmp_input2 = 'save/RestoreV2:%d' % tmp_id
      else:
        tmp_input2 = 'save/RestoreV2'
      if tmp_input2 != node.input[1]:
        if self._verbose:
          logging.info("update save/Assign[%s]'s input from %s to %s" %
                       (node.name, node.input[1], tmp_input2))
        node.input[1] = tmp_input2

  # save/restore_all need save/restore_shard as input
  # save/restore_shard needs save/Assign_[0-N] as input
  # save/Assign_[0-N] needs save/RestoreV2 as input
  if self._restore_shard_node:
    for tmp_input in drop_save_assigns:
      if tmp_input not in self._restore_shard_node.input:
        # previously an unguarded remove() raised ValueError here
        logging.warning('restore_shard has no input %s, skip it' % tmp_input)
        continue
      self._restore_shard_node.input.remove(tmp_input)
      if self._verbose:
        logging.info('drop restore_shard input: %s' % tmp_input)
  elif len(self._restore_all_node) > 0:
    for tmp_input in drop_save_assigns:
      for tmp_node in self._restore_all_node:
        if tmp_input in tmp_node.input:
          tmp_node.input.remove(tmp_input)
          if self._verbose:
            logging.info('drop %s input: %s' % (tmp_node.name, tmp_input))
          break
|
|
727
|
+
|
|
728
|
+
def clear_save_v2(self):
  """Clear SaveV2 ops.

  Dependency chain of the save sub-graph:
    save/Identity needs [save/MergeV2Checkpoints, save/control_dependency];
    save/MergeV2Checkpoints needs
      save/MergeV2Checkpoints/checkpoint_prefixes;
    save/MergeV2Checkpoints/checkpoint_prefixes needs
      [save/ShardedFilename, save/control_dependency];
    save/control_dependency needs save/SaveV2;
    save/SaveV2 needs [save/SaveV2/tensor_names,
      save/SaveV2/shape_and_slices].

  Inputs mentioning '/embedding_weights' are dropped from save/SaveV2,
  together with the matching dtypes and save/SaveV2/shape_and_slices
  entries. Entries mentioning 'embedding_weights' are dropped from
  save/SaveV2/tensor_names.

  Raises:
    AssertionError: if tensor_names and the SaveV2 inputs disagree on how
      many entries are dropped.
  """
  logging.info('update save/SaveV2 input shape, _output_shapes, tensor_shape')
  save_drop_ids = []
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.name == 'save/SaveV2' and node.op == 'SaveV2':
      for tmp_id, tmp_input in enumerate(node.input):
        if '/embedding_weights' in tmp_input:
          save_drop_ids.append(tmp_id)
      # node.input has leading non-tensor inputs that `dtypes` does not list;
      # diff_num shifts input positions to dtypes/shape_and_slices positions
      diff_num = len(node.input) - len(node.attr['dtypes'].list.type)
      self._drop_by_ids(node, 'input', save_drop_ids)
      save_drop_ids = [x - diff_num for x in save_drop_ids]
      self._drop_by_ids(node.attr['dtypes'].list, 'type', save_drop_ids)
      if 'has_ev' in node.attr:
        del node.attr['has_ev']
  for node in self._all_graph_nodes:
    if node.name == 'save/SaveV2/shape_and_slices' and node.op == 'Const':
      # _output_shapes # size # string_val
      node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
          save_drop_ids)
      node.attr['value'].tensor.tensor_shape.dim[0].size -= len(save_drop_ids)
      self._drop_by_ids(node.attr['value'].tensor, 'string_val',
                        save_drop_ids)
    elif node.name == 'save/SaveV2/tensor_names':
      # tensor_names may not have the same order as save/SaveV2/shape_and_slices
      tmp_drop_ids = [
          tmp_id for tmp_id, tmp_val in enumerate(
              node.attr['value'].tensor.string_val)
          if 'embedding_weights' in self.bytes2str(tmp_val)
      ]
      # both sides must lose the same number of entries; the original
      # assert compared save_drop_ids with itself and could never fail
      assert len(tmp_drop_ids) == len(save_drop_ids), (
          'save/SaveV2/tensor_names drops %d entries, save/SaveV2 drops %d' %
          (len(tmp_drop_ids), len(save_drop_ids)))
      # attr['value'].tensor.string_val # tensor_shape # size
      node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
          tmp_drop_ids)
      node.attr['value'].tensor.tensor_shape.dim[0].size -= len(tmp_drop_ids)
      self._drop_by_ids(node.attr['value'].tensor, 'string_val', tmp_drop_ids)
|
|
774
|
+
|
|
775
|
+
def clear_initialize(self):
  """Clear initialization ops of embedding_weights variables.

  Graph structure being removed:
    */read (Identity) depends on [* (VariableV2)];
    */Assign depends on [*/Initializer/*, * (VariableV2)];
    e.g. */embedding_weights/part_x [, /Assign, /read] and
    */embedding_weights/part_1/Initializer/truncated_normal
      [, /shape, /mean, /stddev, /TruncatedNormal, /mul].

  Among kept nodes whose name contains 'embedding_weights', the following are
  flagged for removal in self._all_graph_node_flags: names containing
  'Initializer' or 'Assign', VariableV2 nodes, Identity nodes ending in
  '/read', and Identity nodes whose last name component is an integer,
  optionally after an 'embedding_weights_' prefix.
  """
  logging.info('Remove Initialization nodes for embedding_weights')
  prefix = 'embedding_weights_'
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if 'embedding_weights' not in node.name:
      continue
    if 'Initializer' in node.name or 'Assign' in node.name or \
        node.op == 'VariableV2':
      self._all_graph_node_flags[tid] = False
    elif node.op == 'Identity':
      if node.name.endswith('/read'):
        self._all_graph_node_flags[tid] = False
        continue
      node_tok = node.name.split('/')[-1]
      if node_tok.startswith(prefix):
        node_tok = node_tok[len(prefix):]
      try:
        int(node_tok)
      except ValueError:
        # not a numbered copy such as embedding_weights_3: keep it
        continue
      self._all_graph_node_flags[tid] = False
|
|
807
|
+
|
|
808
|
+
def clear_embedding_variable(self):
  """Flag PAI embedding-variable helper nodes for removal.

  Every still-kept node whose op is ReadKvVariableOp, KvVarIsInitializedOp
  or KvVarHandleOp gets its entry in self._all_graph_node_flags set to False.
  """
  kv_ops = ('ReadKvVariableOp', 'KvVarIsInitializedOp', 'KvVarHandleOp')
  flags = self._all_graph_node_flags
  for idx, graph_node in enumerate(self._all_graph_nodes):
    if flags[idx] and graph_node.op in kv_ops:
      flags[idx] = False
|
|
817
|
+
|
|
818
|
+
def drop_dependent_nodes(self):
  """Transitively drop kept nodes whose first input is a dropped node.

  There may be nodes that depend on already dropped nodes; they are dropped
  as well, repeating until a pass drops nothing new. Only `input[0]` is
  inspected and names are compared exactly. Results are written to
  self._all_graph_node_flags.
  """
  flags = self._all_graph_node_flags
  # a set gives O(1) membership; the former list made each pass
  # O(nodes * dropped)
  drop_names = set(
      tmp_node.name for tid, tmp_node in enumerate(self._all_graph_nodes)
      if not flags[tid])
  while drop_names:
    # only nodes dropped in this pass can trigger new drops in the next one
    newly_dropped = set()
    for tid, tmp_node in enumerate(self._all_graph_nodes):
      if not flags[tid]:
        continue
      if len(tmp_node.input) > 0 and tmp_node.input[0] in drop_names:
        logging.info('drop dependent node: %s depend on %s' %
                     (tmp_node.name, tmp_node.input[0]))
        flags[tid] = False
        newly_dropped.add(tmp_node.name)
    drop_names = newly_dropped
|
|
838
|
+
|
|
839
|
+
def edit_graph(self):
  """Rewrite the graph so embeddings are served by the redis lookup op.

  Main entry point of the redis path. It adds the lookup ops, strips
  embedding_weights from the meta graph collections, op list and saver
  sub-graph, then keeps only the nodes still flagged in
  self._all_graph_node_flags. The result is stored in self._meta_graph_def.
  With self._verbose set, graph.txt and meta_graph.txt are written to
  self._debug_dir.
  """
  (lookup_input_indices, lookup_input_values, lookup_input_shapes,
   lookup_input_weights) = self.find_lookup_inputs()

  # add lookup op to the graph; the returned meta graph includes it
  self._meta_graph_def = self.add_lookup_op(lookup_input_indices,
                                            lookup_input_values,
                                            lookup_input_shapes,
                                            lookup_input_weights)

  meta_graph_def = self._meta_graph_def
  self.clear_meta_graph_embeding(meta_graph_def)
  self.clear_meta_collect(meta_graph_def)

  self.init_graph_node_clear_flags()
  # each step flags nodes in self._all_graph_node_flags and/or edits node
  # inputs and attrs in place; flagged nodes are removed afterwards
  cleanup_steps = (
      self.remove_embedding_weights_and_update_lookup_outputs,
      self.clear_save_restore,  # save/RestoreV2
      self.clear_save_assign,  # save/Assign
      self.clear_save_v2,  # save/SaveV2
      self.clear_initialize,
      self.clear_embedding_variable,
      self.drop_dependent_nodes,
  )
  for step in cleanup_steps:
    step()

  graph_def = self._meta_graph_def.graph_def
  graph_def.ClearField('node')
  graph_def.node.extend([
      node for tid, node in enumerate(self._all_graph_nodes)
      if self._all_graph_node_flags[tid]
  ])

  logging.info('old node number = %d' % self._old_node_num)
  logging.info('node number = %d' % len(graph_def.node))

  if not self._verbose:
    return
  graph_dump_path = os.path.join(self._debug_dir, 'graph.txt')
  with GFile(graph_dump_path, 'w') as fout:
    fout.write(text_format.MessageToString(self.graph_def, as_utf8=True))
  meta_dump_path = os.path.join(self._debug_dir, 'meta_graph.txt')
  with GFile(meta_dump_path, 'w') as fout:
    fout.write(
        text_format.MessageToString(self._meta_graph_def, as_utf8=True))
|
|
890
|
+
|
|
891
|
+
def edit_graph_for_oss(self):
  """Rewrite the graph so embeddings are served by the fused OSS lookup op.

  Main entry point of the OSS path. It calls add_oss_lookup_op to build the
  lookup and stores the returned meta graph in self._meta_graph_def. It then
  strips embedding_weights from the meta graph collections, op list and saver
  sub-graph, and keeps only the nodes still flagged in
  self._all_graph_node_flags. With self._verbose set, graph.txt and
  meta_graph.txt are written to self._debug_dir.
  """
  # the main entrance
  lookup_input_indices, lookup_input_values, lookup_input_shapes,\
      lookup_input_weights = self.find_lookup_inputs()

  # add lookup op to the graph
  self._meta_graph_def = self.add_oss_lookup_op(lookup_input_indices,
                                                lookup_input_values,
                                                lookup_input_shapes,
                                                lookup_input_weights)

  # meta-graph level cleanup: collections and stripped op list
  self.clear_meta_graph_embeding(self._meta_graph_def)

  self.clear_meta_collect(self._meta_graph_def)

  # node level cleanup: the steps below flag nodes in
  # self._all_graph_node_flags and/or edit node inputs and attrs in place
  self.init_graph_node_clear_flags()

  self.remove_embedding_weights_and_update_lookup_outputs()

  # save/RestoreV2
  self.clear_save_restore()

  # save/Assign
  self.clear_save_assign()

  # save/SaveV2
  self.clear_save_v2()

  self.clear_initialize()

  self.clear_embedding_variable()

  self.drop_dependent_nodes()

  # physically remove every node whose flag was cleared
  self._meta_graph_def.graph_def.ClearField('node')
  self._meta_graph_def.graph_def.node.extend([
      x for tid, x in enumerate(self._all_graph_nodes)
      if self._all_graph_node_flags[tid]
  ])

  logging.info('old node number = %d' % self._old_node_num)
  logging.info('node number = %d' % len(self._meta_graph_def.graph_def.node))

  if self._verbose:
    debug_dump_path = os.path.join(self._debug_dir, 'graph.txt')
    with GFile(debug_dump_path, 'w') as fout:
      fout.write(text_format.MessageToString(self.graph_def, as_utf8=True))
    debug_dump_path = os.path.join(self._debug_dir, 'meta_graph.txt')
    with GFile(debug_dump_path, 'w') as fout:
      fout.write(
          text_format.MessageToString(self._meta_graph_def, as_utf8=True))
|