easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
"""Defines functions common to multiple feature column files."""
|
|
17
|
+
|
|
18
|
+
from __future__ import absolute_import
|
|
19
|
+
from __future__ import division
|
|
20
|
+
from __future__ import print_function
|
|
21
|
+
|
|
22
|
+
import six
|
|
23
|
+
from tensorflow.python.framework import dtypes
|
|
24
|
+
from tensorflow.python.framework import ops
|
|
25
|
+
from tensorflow.python.ops import array_ops
|
|
26
|
+
from tensorflow.python.ops import math_ops
|
|
27
|
+
from tensorflow.python.util import nest
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
  """Computes the per-example sequence length of a sparse tensor.

  Args:
    sp_tensor: A `SparseTensor`; column 0 of its `indices` is used as the row
      (example) id and column 1 as the position inside that row.
    num_elements: How many raw values make up one sequence entity. The raw
      length of every row is divided by this and rounded up.

  Returns:
    An int64 `Tensor` holding one length per row of `sp_tensor`.
  """
  with ops.name_scope(None, 'sequence_length') as name_scope:
    indices = sp_tensor.indices
    row_ids = indices[:, 0]
    col_ids = indices[:, 1]
    # A value at column i means its row holds at least i + 1 values.
    element_counts = col_ids + array_ops.ones_like(col_ids)
    raw_length = math_ops.segment_max(element_counts, segment_ids=row_ids)

    # Raw values are grouped `num_elements` at a time, rounding up.
    # Example: rows [[1, 2], [3]] -> column ids (0, 1, 0) -> counts (1, 2, 1)
    # -> raw lengths [2, 1]; with num_elements = 2 the result is [1, 1].
    grouped_length = math_ops.cast(
        math_ops.ceil(raw_length / num_elements), dtypes.int64)

    # If the last n rows have no ids, the segment result is n entries short
    # of the batch size, so append zeros for those rows.
    batch_size = array_ops.shape(sp_tensor)[:1]
    missing_rows = batch_size - array_ops.shape(grouped_length)[:1]
    zeros = array_ops.zeros(missing_rows, dtype=grouped_length.dtype)
    return array_ops.concat([grouped_length, zeros], axis=0, name=name_scope)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def assert_string_or_int(dtype, prefix):
  """Validates that `dtype` is the string dtype or an integer dtype.

  Args:
    dtype: The `DType` to check.
    prefix: Text placed at the front of the error message to identify the
      offending input.

  Raises:
    ValueError: if `dtype` is neither `dtypes.string` nor an integer type.
  """
  is_string = dtype == dtypes.string
  if is_string or dtype.is_integer:
    return
  raise ValueError('{} dtype must be string or integer. dtype: {}.'.format(
      prefix, dtype))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def assert_key_is_string(key):
  """Validates that `key` is a string.

  Args:
    key: The key to check.

  Raises:
    ValueError: if `key` is not an instance of `six.string_types`.
  """
  if isinstance(key, six.string_types):
    return
  message = 'key must be a string. Got: type {}. Given key: {}.'.format(
      type(key), key)
  raise ValueError(message)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def check_default_value(shape, default_value, dtype, key):
  """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is
  compatible, it casts default_value to a tuple and returns it. `key` is used
  only for error messages.

  Args:
    shape: An iterable of integers specifying the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be
      applied as the default value for every item. If an iterable of values is
      provided, the shape of the `default_value` should be equal to the given
      `shape`. Objects exposing `tolist()` (numpy arrays and numpy scalars)
      are first converted to plain Python values.
    dtype: The `DType` of the values. Float defaults are accepted only when
      `dtype.is_floating` is true; int defaults are always accepted (note that
      Python bools are `int` instances and are therefore accepted too).
    key: Column name, used only for error messages.

  Returns:
    None if `default_value` is None; otherwise a (possibly nested) tuple, or
    the scalar itself when `shape` is empty.

  Raises:
    ValueError: if `default_value` is an iterable whose shape does not match
      `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
  """
  if default_value is None:
    return None

  # Convert numpy arrays AND numpy scalars before the scalar checks. Doing it
  # afterwards (as before) sent e.g. np.int64(1) or a 0-d array to the
  # TypeError below even though the value is a valid scalar default.
  if callable(getattr(default_value, 'tolist', None)):
    default_value = default_value.tolist()

  if isinstance(default_value, int):
    return _create_tuple(shape, default_value)

  if isinstance(default_value, float) and dtype.is_floating:
    return _create_tuple(shape, default_value)

  if nest.is_sequence(default_value):
    if not _is_shape_and_default_value_compatible(default_value, shape):
      raise ValueError(
          'The shape of default_value must be equal to given shape. '
          'default_value: {}, shape: {}, key: {}'.format(
              default_value, shape, key))
    # Accept all-int lists for any dtype; lists containing floats only for
    # floating dtypes.
    flat_values = nest.flatten(default_value)
    is_list_all_int = all(isinstance(v, int) for v in flat_values)
    is_list_has_float = any(isinstance(v, float) for v in flat_values)
    if is_list_all_int:
      return _as_tuple(default_value)
    if is_list_has_float and dtype.is_floating:
      return _as_tuple(default_value)
  raise TypeError('default_value must be compatible with dtype. '
                  'default_value: {}, dtype: {}, key: {}'.format(
                      default_value, dtype, key))
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _create_tuple(shape, value):
|
|
128
|
+
"""Returns a tuple with given shape and filled with value."""
|
|
129
|
+
if shape:
|
|
130
|
+
return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
|
|
131
|
+
return value
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _as_tuple(value):
  """Recursively converts nested sequences in `value` into nested tuples.

  Anything for which `nest.is_sequence` is false is returned unchanged.
  """
  if nest.is_sequence(value):
    return tuple(_as_tuple(item) for item in value)
  return value
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _is_shape_and_default_value_compatible(default_value, shape):
  """Returns True if nested `default_value` has exactly the given `shape`.

  Args:
    default_value: A scalar or a nested sequence of values.
    shape: A sequence of dimension sizes; empty means scalar.

  Returns:
    True when every nesting level of `default_value` has the length given by
    the matching entry of `shape` and the nesting depth equals `len(shape)`.
  """
  # Incompatible when the nesting at this level disagrees with the shape:
  #   * default_value is a sequence but shape is empty, or
  #   * default_value is a scalar (not a sequence) but shape is non-empty.
  if nest.is_sequence(default_value) != bool(shape):
    return False
  if not shape:
    return True
  if len(default_value) != shape[0]:
    return False
  inner_shape = shape[1:]
  return all(
      _is_shape_and_default_value_compatible(default_value[i], inner_shape)
      for i in range(shape[0]))
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
"""Higher level ops for building layers."""
|
|
17
|
+
|
|
18
|
+
from __future__ import absolute_import
|
|
19
|
+
from __future__ import division
|
|
20
|
+
from __future__ import print_function
|
|
21
|
+
|
|
22
|
+
import functools
|
|
23
|
+
|
|
24
|
+
from tensorflow.python.framework import dtypes
|
|
25
|
+
from tensorflow.python.framework import ops
|
|
26
|
+
from tensorflow.python.ops import init_ops
|
|
27
|
+
from tensorflow.python.ops import nn
|
|
28
|
+
from tensorflow.python.ops import variable_scope
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def layer_norm(inputs,
               center=True,
               scale=True,
               activation_fn=None,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               begin_norm_axis=1,
               begin_params_axis=-1,
               scope=None,
               variance_epsilon=1e-12):
  """Adds a Layer Normalization layer.

  Based on the paper:

    "Layer Normalization"
    Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
    https://arxiv.org/abs/1607.06450.

  Can be used as a normalizer function for conv2d and fully_connected.

  Given a tensor `inputs` of rank `R`, moments are calculated and normalization
  is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering,
  if requested, is performed over axes `begin_params_axis .. R - 1`.

  By default, `begin_norm_axis = 1` and `begin_params_axis = -1`,
  meaning that normalization is performed over all but the first axis
  (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
  parameters are calculated for the rightmost axis (the `C` if `inputs` is
  `NHWC`). Scaling and recentering is performed via broadcast of the
  `beta` and `gamma` parameters with the normalized tensor.

  The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`,
  and this part of the inputs' shape must be fully defined.

  Args:
    inputs: A tensor having rank `R`. The normalization is performed over
      axes `begin_norm_axis ... R - 1` and centering and scaling parameters
      are calculated over `begin_params_axis ... R - 1`.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can
      be disabled since the scaling can be done by the next layer.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables, either a
      single value for all variables or a dict keyed by 'beta' / 'gamma'.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    begin_norm_axis: The first normalization dimension: normalization will be
      performed along dimensions `begin_norm_axis : rank(inputs)`. Must be in
      `[-rank(inputs), rank(inputs))`.
    begin_params_axis: The first parameter (beta, gamma) dimension: scale
      and centering parameters will have dimensions
      `begin_params_axis : rank(inputs)` and will be broadcast with the
      normalized inputs accordingly. Must be in
      `[-rank(inputs), rank(inputs))`.
    scope: Optional scope for `variable_scope`.
    variance_epsilon: Small float forwarded to `nn.batch_normalization` and
      added to the variance. Defaults to 1e-12, the value previously
      hard-coded here.

  Returns:
    A `Tensor` representing the output of the operation, having the same
    shape and dtype as `inputs`.

  Raises:
    ValueError: If the rank of `inputs` is not known at graph build time,
      if `begin_norm_axis` or `begin_params_axis` is outside
      `[-rank(inputs), rank(inputs))`, or if
      `inputs.shape[begin_params_axis:]` is not fully defined at graph build
      time.
  """
  with variable_scope.variable_scope(
      scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.shape
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    # Reject axes below -rank: a too-negative begin_norm_axis would stay
    # negative after normalization and yield a bogus axis list, and a
    # too-negative begin_params_axis would make the slice below silently
    # cover the whole shape.
    if begin_norm_axis < -inputs_rank or begin_params_axis < -inputs_rank:
      raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                       'must be >= -rank(inputs) (%d)' %
                       (begin_params_axis, begin_norm_axis, -inputs_rank))
    if begin_norm_axis < 0:
      begin_norm_axis = inputs_rank + begin_norm_axis
    if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
      raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                       'must be < rank(inputs) (%d)' %
                       (begin_params_axis, begin_norm_axis, inputs_rank))
    params_shape = inputs_shape[begin_params_axis:]
    if not params_shape.is_fully_defined():
      raise ValueError(
          'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
          (inputs.name, begin_params_axis, inputs_shape))
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = get_variable_collections(variables_collections, 'beta')
      beta = model_variable(
          'beta',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.zeros_initializer(),
          collections=beta_collections,
          trainable=trainable)
    if scale:
      gamma_collections = get_variable_collections(variables_collections,
                                                   'gamma')
      gamma = model_variable(
          'gamma',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.ones_initializer(),
          collections=gamma_collections,
          trainable=trainable)
    # Calculate the moments over the normalization axes.
    norm_axes = list(range(begin_norm_axis, inputs_rank))
    mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
    # Compute layer normalization using the batch_normalization function.
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=beta,
        scale=gamma,
        variance_epsilon=variance_epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return collect_named_outputs(outputs_collections, sc.name, outputs)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def get_variable_collections(variables_collections, name):
  """Resolve the collections to use for the variable called `name`.

  Args:
    variables_collections: Either a dict mapping variable names to
      collection lists, or a single value shared by every variable.
    name: Variable name used as the lookup key when a dict is given.

  Returns:
    `variables_collections[name]` (or None when the key is missing) for a
    dict; otherwise `variables_collections` unchanged.
  """
  if not isinstance(variables_collections, dict):
    # A non-dict value applies to every variable as-is.
    return variables_collections
  return variables_collections.get(name)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def collect_named_outputs(collections, alias, outputs):
  """Tag `outputs` with `alias` and add it to `collections`.

  Handy for gathering end-points or tagging tensors for summaries, e.g.:
    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection name or a list of collection names. When falsy
      (None, empty list, empty string), nothing is tagged or collected.
    alias: String appended to the aliases of `outputs`, for example
      'inception_v3/conv1'.
    outputs: The output tensor to tag and collect.

  Returns:
    `outputs` itself, so the call can be used inline.
  """
  if not collections:
    # Nothing to register into; hand the tensor straight back untouched.
    return outputs
  append_tensor_alias(outputs, alias)
  ops.add_to_collections(collections, outputs)
  return outputs
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def append_tensor_alias(tensor, alias):
  """Append an alias to the list of aliases of the tensor.

  Args:
    tensor: A `Tensor` (any object that accepts new attributes works).
    alias: String to add to the list of aliases of the tensor. A single
      trailing '/' is removed before the alias is stored.

  Returns:
    The same `tensor` object, whose `aliases` list now ends with `alias`.
  """
  # Remove one ending '/' if present. `endswith` is used instead of
  # `alias[-1]` so an empty alias does not raise IndexError.
  if alias.endswith('/'):
    alias = alias[:-1]
  if hasattr(tensor, 'aliases'):
    tensor.aliases.append(alias)
  else:
    # First alias for this tensor: create the list lazily.
    tensor.aliases = [alias]
  return tensor
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def variable(name,
             shape=None,
             dtype=None,
             initializer=None,
             regularizer=None,
             trainable=True,
             collections=None,
             caching_device=None,
             device=None,
             partitioner=None,
             custom_getter=None,
             use_resource=None):
  """Get an existing variable with these parameters or create a new one.

  Args:
    name: Name of the new or existing variable.
    shape: Shape of the new or existing variable.
    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: Initializer for the variable if one is created.
    regularizer: A (Tensor -> Tensor or None) function. The result of applying
      it to a newly created variable is added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True`, also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable is added.
      If None, it defaults to `[tf.GraphKeys.GLOBAL_VARIABLES]`. Duplicate
      names are dropped.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device on which to place the variable. It can be a
      string or a function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that overrides the internal get_variable method
      and must have the same signature. It is also passed the current
      variable scope's `reuse` setting as a keyword argument.
    use_resource: If `True`, use a ResourceVariable instead of a Variable.

  Returns:
    The created or existing variable.
  """
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  # De-duplicate the requested collection names.
  collections = list(set(collections))

  if custom_getter is None:
    getter = variable_scope.get_variable
  else:
    # Bind the enclosing scope's reuse flag so the custom getter honours it.
    scope_reuse = variable_scope.get_variable_scope().reuse
    getter = functools.partial(custom_getter, reuse=scope_reuse)

  with ops.device(device or ''):
    return getter(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        partitioner=partitioner,
        use_resource=use_resource)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def model_variable(name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   collections=None,
                   caching_device=None,
                   device=None,
                   partitioner=None,
                   custom_getter=None,
                   use_resource=None):
  """Get an existing model variable with these parameters or create a new one.

  Args:
    name: Name of the new or existing variable.
    shape: Shape of the new or existing variable.
    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: Initializer for the variable if one is created.
    regularizer: A (Tensor -> Tensor or None) function. The result of applying
      it to a newly created variable is added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True`, also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: A list of collection names to which the Variable is added.
      The variable is always also added to the
      `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device.
    device: Optional device on which to place the variable. It can be a
      string or a function that is called to get the device for the variable.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and dtype of the `Variable` to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    custom_getter: Callable that overrides the internal get_variable method
      and must have the same signature.
    use_resource: If `True`, use a ResourceVariable instead of a Variable.

  Returns:
    The created or existing variable.
  """
  # Caller-supplied collections first, then the two mandatory ones.
  model_collections = list(collections or []) + [
      ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
  ]
  return variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=model_collections,
      caching_device=caching_device,
      device=device,
      partitioner=partitioner,
      custom_getter=custom_getter,
      use_resource=use_resource)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from tensorflow.python.framework import ops
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GraphKeys(ops.GraphKeys):
  """Graph collection keys, extending TensorFlow's standard `GraphKeys`.

  Every key from `ops.GraphKeys` is inherited unchanged; this subclass adds
  the `RANK_SERVICE_*` collection names defined below.
  """

  # Collection keys used by the rank service.
  RANK_SERVICE_FG_CONF = '__rank_service_fg_conf'
  RANK_SERVICE_INPUT = '__rank_service_input'
  RANK_SERVICE_OUTPUT = '__rank_service_output'
  RANK_SERVICE_EMBEDDING = '__rank_service_embedding'
  RANK_SERVICE_INPUT_SRC = '__rank_service_input_src'
  RANK_SERVICE_REPLACE_OP = '__rank_service_replace'
  RANK_SERVICE_SHAPE_OPT_FLAG = '__rank_service_shape_opt_flag'

  # Used for compatibility between RTP and EasyRec.
  RANK_SERVICE_FEATURE_NODE = '__rank_service_feature_node'
|