easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,542 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2022, NVIDIA CORPORATION.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
#
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
|
|
19
|
+
import tensorflow as tf
|
|
20
|
+
from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
|
|
21
|
+
from sparse_operation_kit.experiment.communication import num_gpus
|
|
22
|
+
from tensorflow.python.eager import context
|
|
23
|
+
from tensorflow.python.framework import ops
|
|
24
|
+
# from tensorflow.python.ops import array_ops
|
|
25
|
+
from tensorflow.python.ops import resource_variable_ops
|
|
26
|
+
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
|
|
27
|
+
from tensorflow.python.ops.resource_variable_ops import variable_accessed
|
|
28
|
+
|
|
29
|
+
# from tensorflow.python.util import object_identity
|
|
30
|
+
|
|
31
|
+
dynamic_variable_count = 0
|
|
32
|
+
|
|
33
|
+
_resource_var_from_proto = ResourceVariable.from_proto
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DynamicVariable(ResourceVariable):
|
|
37
|
+
"""Abbreviated as ``sok.experiment.DynamicVariable``.
|
|
38
|
+
|
|
39
|
+
A variable that allocates memory dynamically.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
dimension: int
|
|
44
|
+
The last dimension of this variable(that is, the embedding vector
|
|
45
|
+
size of embedding table).
|
|
46
|
+
|
|
47
|
+
initializer: string
|
|
48
|
+
a string to specify how to initialize this variable.
|
|
49
|
+
Currently, only support "random" or string of a float
|
|
50
|
+
value(meaning const initializer). Default value is "random".
|
|
51
|
+
|
|
52
|
+
var_type: string
|
|
53
|
+
A string that selects DET or HKV as the backend.
|
|
54
|
+
If HKV is used as the backend, only tf.int64 is supported as key_type.
|
|
55
|
+
If HKV is used as the backend, set init_capacity and max_capacity to powers of 2.
|
|
56
|
+
|
|
57
|
+
key_type: dtype
|
|
58
|
+
specify the data type of indices. Unlike the static variable of
|
|
59
|
+
tensorflow, this variable is dynamically allocated and contains
|
|
60
|
+
a hash table inside it. So the data type of indices must be
|
|
61
|
+
specified to construct the hash table. Default value is tf.int64.
|
|
62
|
+
|
|
63
|
+
dtype: dtype
|
|
64
|
+
specify the data type of values. Default value is tf.float32.
|
|
65
|
+
|
|
66
|
+
Example
|
|
67
|
+
-------
|
|
68
|
+
.. code-block:: python
|
|
69
|
+
|
|
70
|
+
import numpy as np
|
|
71
|
+
import tensorflow as tf
|
|
72
|
+
import horovod.tensorflow as hvd
|
|
73
|
+
from sparse_operation_kit import experiment as sok
|
|
74
|
+
|
|
75
|
+
v = sok.DynamicVariable(dimension=3, initializer="13")
|
|
76
|
+
print("v.shape:", v.shape)
|
|
77
|
+
print("v.size:", v.size)
|
|
78
|
+
|
|
79
|
+
indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
|
|
80
|
+
|
|
81
|
+
embedding = tf.nn.embedding_lookup(v, indices)
|
|
82
|
+
print("embedding:", embedding)
|
|
83
|
+
print("v.shape:", v.shape)
|
|
84
|
+
print("v.size:", v.size)
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
def __init__(self,
             dimension,
             initializer=None,
             var_type=None,
             name=None,
             constraint=None,
             trainable=True,
             key_type=None,
             dtype=None,
             mode=None,
             variable_def=None,
             import_scope=None,
             **kwargs):
  """Create the variable, or rebuild it from ``variable_def`` when given.

  NOTE(review): this block was recovered from a listing that had lost all
  indentation; the nesting of the ``with`` blocks below was reconstructed
  and should be confirmed against the packaged file.

  Parameters
  ----------
  dimension: int
      Size of the last axis; the proxy variable is created with shape
      ``[None, dimension]``. Ignored when ``variable_def`` is given (the
      value is read from the handle op's ``shape`` attribute instead).
  initializer: str or variable-like, optional
      ``None`` is replaced by ``''``. A ``str`` is forwarded verbatim to
      ``dummy_var_initialize``; any other object must expose
      ``_initializer_op`` and ``read_value()``.
  var_type: str, optional
      Backend selector forwarded to ``dummy_var_initialize``; ``None``
      becomes ``'hbm'``. ``'hybrid'`` requires ``key_type`` tf.int64.
  name: str, optional
      Variable name. When omitted, ``'sok_dynamic_Variable_<n>'`` is
      generated from the module-level counter.
  constraint, trainable:
      Forwarded to ``ResourceVariable.__init__``.
  key_type: dtype, optional
      Key dtype; defaults to tf.int64.
  dtype: dtype, optional
      Value dtype; defaults to tf.float32.
  mode:
      Stored as-is and exposed through the ``mode`` property.
  variable_def: VariableDef, optional
      When given, the instance is restored from this proto and every
      other construction argument except ``import_scope`` is ignored.
  import_scope: str, optional
      Name scope prepended when resolving names from ``variable_def``.
  **kwargs:
      Extra backend options; JSON-encoded and passed as ``config`` to
      ``dummy_var_initialize``.

  Raises
  ------
  NotImplementedError
      If ``var_type == 'hybrid'`` and ``key_type`` is not tf.int64.
  """
  self._indices = None
  # Proxy to the ResourceVariable implementation. It is set before the
  # branch so that the static-mode delegations (``self._base.xxx``) also
  # work on instances restored from a proto.
  self._base = super(DynamicVariable, self)
  if variable_def is not None:
    self._base._init_from_proto(
        variable_def, import_scope=import_scope, validate_shape=False)
    g = ops.get_default_graph()
    handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope),
        allow_operation=False)
    self._dimension = handle.op.get_attr('shape').dim[-1].size
    self._key_type = handle.op.get_attr('key_type')
    # Fix: the rest of the class reads ``_handle_dtype``; the original
    # stored this under ``_handle_type`` only, so ``size``,
    # ``sparse_read`` and ``handle_dtype`` raised AttributeError on
    # restored variables.
    self._handle_dtype = handle.op.get_attr('dtype')
    # Kept for any external code that relied on the old attribute name.
    self._handle_type = self._handle_dtype
    self._mode = None
    self._config = {}
    # No kwargs are available on this path; keep ``config_dict`` usable.
    self._config_dict = {}
    self._name = variable_def.variable_name.split(':')[0]
    self._trainable = variable_def.trainable
    self._dummy_handle = handle
    self._handle = handle

    # init op
    init_op = g.as_graph_element(variable_def.initializer_name)
    self._initializer_op = init_op

    # The first control input of the grouped initializer is the proxy
    # variable's own init op; its first input is the TF resource handle.
    init_tf = init_op.control_inputs[0]
    # init_dummy = init_op.control_inputs[1]

    self._tf_handle = init_tf.inputs[0]
    return

  self._key_type = key_type if key_type is not None else tf.int64
  self._handle_dtype = dtype if dtype is not None else tf.float32
  self._dimension = dimension
  self._mode = mode
  self._config = json.dumps(kwargs)
  self._config_dict = kwargs
  if var_type == 'hybrid' and self._key_type != tf.int64:
    raise NotImplementedError(
        'only key_type tf.int64 is supported in HKV backend')
  if name is None:
    global dynamic_variable_count
    name = 'sok_dynamic_Variable_' + str(dynamic_variable_count)
    dynamic_variable_count += 1
  var_type = 'hbm' if var_type is None else var_type
  self._var_type = var_type
  # A one-row proxy ResourceVariable; it provides the TF-side handle that
  # is used while the variable is in static mode.
  self._base.__init__(
      initial_value=[[0.0] * dimension],
      trainable=trainable,
      name=name + '/proxy',
      dtype=self._handle_dtype,
      constraint=constraint,
      distribute_strategy=None,
      synchronization=None,
      aggregation=None,
      shape=[None, dimension],
  )

  with ops.init_scope():
    # name = "DynamicVariable" if name is None else name
    with ops.name_scope(name) as name_scope:
      self._dummy_name = ops.name_from_scope_name(name_scope)
      if context.executing_eagerly():
        self._dummy_name = '%s_%d' % (name, ops.uid())
      with ops.NullContextmanager():
        shape = [None, dimension]
        initializer = '' if initializer is None else initializer
        self._initializer = initializer
        handle = dynamic_variable_ops.dummy_var_handle(
            container='DummyVariableContainer',
            shared_name=self._dummy_name,
            key_type=self._key_type,
            dtype=self._handle_dtype,
            shape=shape,
        )
        if type(initializer) is str:
          init_op = dynamic_variable_ops.dummy_var_initialize(
              handle,
              initializer=initializer,
              var_type=var_type,
              unique_name=self._dummy_name,
              key_type=self._key_type,
              dtype=self._handle_dtype,
              config=self._config,
          )
        else:
          # Non-string initializer: read its value only after its own
          # initializer op has run.
          with tf.control_dependencies([initializer._initializer_op]):
            initial_val = initializer.read_value()
          init_op = dynamic_variable_ops.dummy_var_initialize(
              handle,
              initializer=initial_val,
              var_type=var_type,
              unique_name=self._dummy_name,
              key_type=self._key_type,
              dtype=self._handle_dtype,
              config=self._config,
          )
        # TODO: Add is_initialized_op
        # is_initialized_op = ops.convert_to_tensor(True)

    self._tf_handle = self._handle
    self._dummy_handle = handle
    # Note that the default handle will be sok's handle
    self._handle = self._dummy_handle
    self._initializer_op = tf.group([self._initializer_op, init_op])
    # self._is_initialized_op = tf.group([self._is_initialized_op, is_initialized_op])

    handle_data = (
        resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
        .HandleData())
    handle_data.is_set = True
    handle_data.shape_and_type.append(
        resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
        .HandleShapeAndType(
            shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum))
    resource_variable_ops._set_handle_shapes_and_types(
        self._handle,
        handle_data,
        graph_mode=False if context.executing_eagerly() else True)
|
|
218
|
+
|
|
219
|
+
def is_static(self):
  """Return True when the active handle is the TF proxy handle."""
  active_handle = self._handle
  return active_handle is self._tf_handle
|
|
221
|
+
|
|
222
|
+
def to_static(self, indices, lookup_only=False):
  """Switch to static mode, copying the rows for ``indices`` into the proxy.

  The rows are read from the dynamic backend with ``sparse_read``, the
  active handle is swapped to the TF proxy handle, and the rows are
  assigned to the proxy variable.

  Parameters
  ----------
  indices:
      Keys to read; remembered in ``self._indices`` for ``to_dynamic``.
  lookup_only: bool
      Forwarded to ``sparse_read`` as its ``lookup_only`` keyword.

  Returns
  -------
  Whatever ``self.assign`` returns for the looked-up rows.

  Raises
  ------
  RuntimeError
      If the variable is already static or indices are already recorded.
  """
  if self.is_static() or self._indices is not None:
    raise RuntimeError('to_static() must be called in dynamic mode.')
  # Fix: ``sparse_read``'s second positional parameter is ``name``, so
  # ``lookup_only`` must be passed by keyword or it is silently dropped.
  buffer = self.sparse_read(indices, lookup_only=lookup_only)
  self._indices = indices
  self._handle = self._tf_handle
  return self.assign(buffer)
|
|
230
|
+
|
|
231
|
+
def to_dynamic(self):
  """Switch back to dynamic mode, writing the proxy rows to the backend.

  The proxy value is read, paired with the indices recorded by
  ``to_static`` as an ``IndexedSlices``, the active handle is swapped back
  to the dummy handle, and the rows are written with ``scatter_update``.

  Raises
  ------
  RuntimeError
      If the variable is not currently in static mode.
  """
  if not self.is_static():
    raise RuntimeError('to_dynamic() must be called in static mode.')
  dense_rows = self.read_value()
  row_updates = ops.IndexedSlices(dense_rows, self._indices, self.shape)
  # Reset bookkeeping before scattering so scatter_update takes the
  # dynamic path.
  self._indices = None
  self._handle = self._dummy_handle
  return self.scatter_update(row_updates)
|
|
240
|
+
|
|
241
|
+
@property
def name(self):
  """Name of the dummy (SOK) handle, regardless of the current mode."""
  dummy_handle = self._dummy_handle
  return dummy_handle.name
|
|
244
|
+
|
|
245
|
+
def __repr__(self):
|
|
246
|
+
if self.is_static():
|
|
247
|
+
return self._base.__repr__()
|
|
248
|
+
return "<sok.DynamicVariable '%s' shape=%s dtype=%s>" % (
|
|
249
|
+
self._dummy_name,
|
|
250
|
+
self.shape,
|
|
251
|
+
self.dtype.name,
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
@property
def size(self):
  """Result of the ``dummy_var_shape`` op for the dummy handle."""
  shape_kwargs = dict(key_type=self._key_type, dtype=self._handle_dtype)
  return dynamic_variable_ops.dummy_var_shape(self._dummy_handle,
                                              **shape_kwargs)
|
|
258
|
+
|
|
259
|
+
@property
def indices(self):
  """Indices recorded by ``to_static``; None while in dynamic mode."""
  recorded = self._indices
  return recorded
|
|
262
|
+
|
|
263
|
+
@property
def dimension(self):
  """Size of the last axis, as stored at construction time."""
  dim = self._dimension
  return dim
|
|
266
|
+
|
|
267
|
+
def get_shape(self):
  """Return a one-element list holding the stored dimension."""
  dims = [self._dimension]
  return dims
|
|
269
|
+
|
|
270
|
+
@property
def key_type(self):
  """Data type of the keys, as stored at construction time."""
  stored_key_type = self._key_type
  return stored_key_type
|
|
273
|
+
|
|
274
|
+
@property
def handle_dtype(self):
  """Data type of the values held behind the dummy handle."""
  value_dtype = self._handle_dtype
  return value_dtype
|
|
277
|
+
|
|
278
|
+
@property
def backend_type(self):
  """The ``var_type`` this variable was created with."""
  backend = self._var_type
  return backend
|
|
281
|
+
|
|
282
|
+
@property
def config_dict(self):
  """The extra keyword arguments given at construction, as a dict."""
  extra_options = self._config_dict
  return extra_options
|
|
285
|
+
|
|
286
|
+
@property
def mode(self):
  """The ``mode`` value stored at construction time."""
  stored_mode = self._mode
  return stored_mode
|
|
289
|
+
|
|
290
|
+
@property
def num_gpus(self):
  """Value returned by the module-level ``num_gpus()`` helper."""
  # Inside the method body the bare name resolves to the imported
  # function, not to this property.
  gpu_count = num_gpus()
  return gpu_count
|
|
293
|
+
|
|
294
|
+
@property
def initializer_str(self):
  """The initializer stored at construction time."""
  stored_initializer = self._initializer
  return stored_initializer
|
|
297
|
+
|
|
298
|
+
def key_map(self, indices):
  """Identity mapping: return ``indices`` unchanged."""
  mapped = indices
  return mapped
|
|
300
|
+
|
|
301
|
+
# -------------------------------------------------------------------------
|
|
302
|
+
# Methods supported both in static mode and dynamic mode
|
|
303
|
+
# -------------------------------------------------------------------------
|
|
304
|
+
|
|
305
|
+
def sparse_read(self, indices, name=None, lookup_only=False):
  """Gather the rows for ``indices``.

  In static mode this delegates to the base class with ``indices`` and
  ``name``. In dynamic mode the access is recorded via
  ``variable_accessed``, int32 indices are cast to int64, and the rows are
  read through ``dummy_var_sparse_read``; ``name`` is unused on that path.
  """
  if not self.is_static():
    variable_accessed(self)
    needs_cast = indices.dtype == tf.int32
    keys = tf.cast(indices, tf.int64) if needs_cast else indices
    return dynamic_variable_ops.dummy_var_sparse_read(
        self._dummy_handle,
        keys,
        dtype=self._handle_dtype,
        lookup_only=lookup_only)
  return self._base.sparse_read(indices, name)
|
|
317
|
+
|
|
318
|
+
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
  """Subtract ``sparse_delta`` from the addressed rows.

  Static mode delegates to the base class. Dynamic mode negates the
  values and issues ``dummy_var_scatter_add``; ``use_locking`` and
  ``name`` are unused on that path.

  Raises
  ------
  TypeError
      In dynamic mode, if ``sparse_delta`` is not an ``IndexedSlices``.
  """
  if self.is_static():
    return self._base.scatter_sub(sparse_delta, use_locking, name)
  if isinstance(sparse_delta, ops.IndexedSlices):
    negated = ops.convert_to_tensor(-sparse_delta.values, self.dtype)
    return dynamic_variable_ops.dummy_var_scatter_add(
        self._dummy_handle, sparse_delta.indices, negated)
  raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
|
|
328
|
+
|
|
329
|
+
def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Add ``sparse_delta`` to the addressed rows.

  Static mode delegates to the base class. Dynamic mode issues
  ``dummy_var_scatter_add``; ``use_locking`` and ``name`` are unused on
  that path.

  Raises
  ------
  TypeError
      In dynamic mode, if ``sparse_delta`` is not an ``IndexedSlices``.
  """
  if self.is_static():
    return self._base.scatter_add(sparse_delta, use_locking, name)
  if isinstance(sparse_delta, ops.IndexedSlices):
    delta_values = ops.convert_to_tensor(sparse_delta.values, self.dtype)
    return dynamic_variable_ops.dummy_var_scatter_add(
        self._dummy_handle, sparse_delta.indices, delta_values)
  raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
|
|
339
|
+
|
|
340
|
+
def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Overwrite the addressed rows with ``sparse_delta``.

  Static mode delegates to the base class. Dynamic mode issues
  ``dummy_var_scatter_update``; ``use_locking`` and ``name`` are unused on
  that path.

  Raises
  ------
  TypeError
      In dynamic mode, if ``sparse_delta`` is not an ``IndexedSlices``.
  """
  if self.is_static():
    return self._base.scatter_update(sparse_delta, use_locking, name)
  if isinstance(sparse_delta, ops.IndexedSlices):
    new_values = ops.convert_to_tensor(sparse_delta.values, self.dtype)
    return dynamic_variable_ops.dummy_var_scatter_update(
        self._dummy_handle, sparse_delta.indices, new_values)
  raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
|
|
350
|
+
|
|
351
|
+
# -------------------------------------------------------------------------
|
|
352
|
+
# Methods not supported both in static mode and dynamic mode
|
|
353
|
+
# -------------------------------------------------------------------------
|
|
354
|
+
|
|
355
|
+
def __deepcopy__(self, *args, **kwargs):
|
|
356
|
+
raise NotImplementedError('__deepcopy__() is not supported.')
|
|
357
|
+
|
|
358
|
+
def __reduce__(self, *args, **kwargs):
|
|
359
|
+
raise NotImplementedError('__reduce__() is not supported.')
|
|
360
|
+
|
|
361
|
+
def to_proto(self, *args, **kwargs):
  """Serialize via the base-class ``to_proto``, forwarding all arguments."""
  base_to_proto = super(DynamicVariable, self).to_proto
  return base_to_proto(*args, **kwargs)
|
|
363
|
+
# raise NotImplementedError("to_proto() is not supported.")
|
|
364
|
+
|
|
365
|
+
@staticmethod
def from_proto(variable_def, import_scope=None):
  """Rebuild a variable from ``variable_def``.

  Protos whose ``variable_name`` contains ``'/DummyVarHandle'`` become a
  ``DynamicVariable``; everything else goes to the saved
  ``ResourceVariable.from_proto``.
  """
  is_dynamic = '/DummyVarHandle' in variable_def.variable_name
  if not is_dynamic:
    return _resource_var_from_proto(variable_def, import_scope)
  # ``dimension`` is a placeholder here; __init__ reads the real value
  # from the handle op when variable_def is given.
  return DynamicVariable(
      dimension=0, variable_def=variable_def, import_scope=import_scope)
|
|
372
|
+
# raise NotImplementedError("from_proto() is not supported.")
|
|
373
|
+
|
|
374
|
+
def set_shape(self, *args, **kwargs):
  """Refuse shape overrides: unconditionally raises ``NotImplementedError``."""
  message = 'set_shape() is not supported.'
  raise NotImplementedError(message)
|
|
376
|
+
|
|
377
|
+
# -------------------------------------------------------------------------
|
|
378
|
+
# Methods only supported in static mode
|
|
379
|
+
# -------------------------------------------------------------------------
|
|
380
|
+
|
|
381
|
+
def is_initialized(self, name=None):
  """Report this variable as initialized.

  Always returns the Python constant ``True``, in static and dynamic mode
  alike. ``name`` is ignored.

  Args:
    name: Unused. It is optional so that callers using the no-argument form
      ``var.is_initialized()`` do not fail with a ``TypeError``.

  Returns:
    ``True``.
  """
  # The unconditional return used to be followed by a static-mode delegation
  # to ``self._base.is_initialized(name)`` and a dynamic-mode
  # NotImplementedError. Both were unreachable, so only the effective
  # behaviour is kept.
  return True
|
|
387
|
+
|
|
388
|
+
def _read_variable_op(self):
  """Forward ``_read_variable_op`` to ``self._base`` when the variable is static.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        '_read_variable_op() is not supported in dynamic mode.')
  return self._base._read_variable_op()
|
|
393
|
+
|
|
394
|
+
def value(self):
  """Forward ``value`` to ``self._base`` when the variable is static.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('value() is not supported in dynamic mode.')
  return self._base.value()
|
|
398
|
+
|
|
399
|
+
def _dense_var_to_tensor(self, *args, **kwargs):
  """Forward ``_dense_var_to_tensor`` to ``self._base`` when static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        '_dense_var_to_tensor() is not supported in dynamic mode.')
  return self._base._dense_var_to_tensor(*args, **kwargs)
|
|
404
|
+
|
|
405
|
+
def _gather_saveables_for_checkpoint(self):
  """Forward ``_gather_saveables_for_checkpoint`` to ``self._base`` when static.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        '_gather_saveables_for_checkpoint() is not supported in dynamic mode.')
  return self._base._gather_saveables_for_checkpoint()
|
|
410
|
+
|
|
411
|
+
def gather_nd(self, *args, **kwargs):
  """Forward ``gather_nd`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('gather_nd() is not supported in dynamic mode.')
  return self._base.gather_nd(*args, **kwargs)
|
|
415
|
+
|
|
416
|
+
def assign_add(self, *args, **kwargs):
  """Forward ``assign_add`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('assign_add() is not supported in dynamic mode.')
  return self._base.assign_add(*args, **kwargs)
|
|
420
|
+
|
|
421
|
+
def assign(self, *args, **kwargs):
  """Forward ``assign`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('assign() is not supported in dynamic mode.')
  return self._base.assign(*args, **kwargs)
|
|
425
|
+
|
|
426
|
+
def scatter_max(self, *args, **kwargs):
  """Forward ``scatter_max`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('scatter_max() is not supported in dynamic mode.')
  return self._base.scatter_max(*args, **kwargs)
|
|
430
|
+
|
|
431
|
+
def scatter_min(self, *args, **kwargs):
  """Forward ``scatter_min`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('scatter_min() is not supported in dynamic mode.')
  return self._base.scatter_min(*args, **kwargs)
|
|
435
|
+
|
|
436
|
+
def scatter_mul(self, *args, **kwargs):
  """Forward ``scatter_mul`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('scatter_mul() is not supported in dynamic mode.')
  return self._base.scatter_mul(*args, **kwargs)
|
|
440
|
+
|
|
441
|
+
def scatter_dim(self, *args, **kwargs):
  """Forward ``scatter_dim`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  # NOTE(review): TensorFlow variables expose ``scatter_div`` rather than
  # ``scatter_dim``; if ``self._base`` is such a variable this attribute
  # lookup would fail in static mode. Confirm the intended name.
  if not self.is_static():
    raise NotImplementedError('scatter_dim() is not supported in dynamic mode.')
  return self._base.scatter_dim(*args, **kwargs)
|
|
445
|
+
|
|
446
|
+
def batch_scatter_update(self, *args, **kwargs):
  """Forward ``batch_scatter_update`` to ``self._base`` when static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        'batch_scatter_update() is not supported in dynamic mode.')
  return self._base.batch_scatter_update(*args, **kwargs)
|
|
451
|
+
|
|
452
|
+
def scatter_nd_sub(self, *args, **kwargs):
  """Forward ``scatter_nd_sub`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        'scatter_nd_sub() is not supported in dynamic mode.')
  return self._base.scatter_nd_sub(*args, **kwargs)
|
|
457
|
+
|
|
458
|
+
def scatter_nd_update(self, *args, **kwargs):
  """Forward ``scatter_nd_update`` to ``self._base`` when static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        'scatter_nd_update() is not supported in dynamic mode.')
  return self._base.scatter_nd_update(*args, **kwargs)
|
|
463
|
+
|
|
464
|
+
def _strided_slice_assign(self, *args, **kwargs):
  """Forward ``_strided_slice_assign`` to ``self._base`` when static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError(
        '_strided_slice_assign() is not supported in dynamic mode.')
  return self._base._strided_slice_assign(*args, **kwargs)
|
|
469
|
+
|
|
470
|
+
def __int__(self, *args, **kwargs):
  """Forward ``__int__`` to ``self._base`` when the variable is static.

  All arguments are passed through unchanged.

  Raises:
    NotImplementedError: if ``self.is_static()`` is false (dynamic mode).
  """
  if not self.is_static():
    raise NotImplementedError('__int__() is not supported in dynamic mode.')
  return self._base.__int__(*args, **kwargs)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
# Module-level patch: every ResourceVariable.from_proto call now goes through
# DynamicVariable.from_proto, which rebuilds definitions whose variable_name
# contains '/DummyVarHandle' as DynamicVariable and passes all other
# definitions to _resource_var_from_proto.
ResourceVariable.from_proto = DynamicVariable.from_proto
|
|
477
|
+
|
|
478
|
+
# @tf.RegisterGradient("DummyVarSparseRead")
|
|
479
|
+
# def _SparseReadGrad(op, grad):
|
|
480
|
+
# """Gradient for sparse_read."""
|
|
481
|
+
# handle = op.inputs[0]
|
|
482
|
+
# indices = op.inputs[1]
|
|
483
|
+
# key_type = op.get_attr("key_type")
|
|
484
|
+
# dtype = op.get_attr("dtype")
|
|
485
|
+
# variable_shape = dynamic_variable_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype)
|
|
486
|
+
# size = array_ops.expand_dims(array_ops.size(indices), 0)
|
|
487
|
+
# values_shape = array_ops.concat([size, variable_shape[1:]], 0)
|
|
488
|
+
# grad = array_ops.reshape(grad, values_shape)
|
|
489
|
+
# indices = array_ops.reshape(indices, size)
|
|
490
|
+
# return (ops.IndexedSlices(grad, indices, variable_shape), None)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def export(var):
  """Abbreviated as ``sok.experiment.export``.

  Export the indices and value tensors held by the given variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The variable to extract indices and values from.

  Returns
  -------
  indices: tf.Tensor
      The indices of the given variable.

  values: tf.Tensor
      The values of the given variable.

  Nothing (``None``) is returned when ``var`` is not a ``DynamicVariable``.
  """
  if not isinstance(var, DynamicVariable):
    return None

  exported = dynamic_variable_ops.dummy_var_export(
      var.handle, key_type=var.key_type, dtype=var.handle_dtype)
  raw_indices, raw_values = exported
  # Re-emit both outputs through identity ops created under the CPU device
  # scope before handing them back.
  with tf.device('CPU'):
    cpu_indices = tf.identity(raw_indices)
    cpu_values = tf.identity(raw_values)
  return cpu_indices, cpu_values
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def assign(var, indices, values):
  """Abbreviated as ``sok.experiment.assign``.

  Assign the indices and value tensors to the target variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The target variable of the assignment.

  indices: tf.Tensor
      Indices to be assigned to the variable. They are cast to the
      variable's key type before being passed to the assign op.

  values: tf.Tensor
      Values to be assigned to the variable.

  Returns
  -------
  The result of the ``dummy_var_assign`` op when ``var`` is a
  ``DynamicVariable``; ``None`` otherwise.
  """
  if isinstance(var, DynamicVariable):
    # tf.cast returns a new tensor rather than converting in place, so the
    # result has to be rebound; discarding it would leave ``indices`` in its
    # original dtype.
    indices = tf.cast(indices, var._key_type)
    return dynamic_variable_ops.dummy_var_assign(var.handle, indices, values)
|