easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) 2022, NVIDIA CORPORATION.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
#
|
|
16
|
+
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
from tensorflow.python.eager import context
|
|
19
|
+
# from tensorflow.python.framework import dtypes
|
|
20
|
+
from tensorflow.python.framework import ops
|
|
21
|
+
# from tensorflow.python.ops import control_flow_ops
|
|
22
|
+
from tensorflow.python.ops import array_ops
|
|
23
|
+
from tensorflow.python.ops import gradients
|
|
24
|
+
from tensorflow.python.ops import resource_variable_ops
|
|
25
|
+
from tensorflow.python.ops import state_ops
|
|
26
|
+
|
|
27
|
+
from easy_rec.python.compat.dynamic_variable import DynamicVariable
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def OptimizerWrapper(optimizer):
  """Abbreviated as ``sok.experiment.OptimizerWrapper``.

  This is a wrapper for tensorflow optimizer so that it can update
  dynamic_variable.DynamicVariable.

  Parameters
  ----------
  optimizer: tensorflow optimizer
      The original tensorflow optimizer.

  Returns
  -------
  OptimizerWrapperV2 or OptimizerWrapperV1
      ``OptimizerWrapperV2`` when ``optimizer`` is an instance of
      ``tf.keras.optimizers.legacy.Optimizer`` (if that can be resolved) or
      of ``tf.keras.optimizers.Optimizer``; ``OptimizerWrapperV1`` otherwise.

  Example
  -------
  .. code-block:: python

      import numpy as np
      import tensorflow as tf
      import horovod.tensorflow as hvd
      from sparse_operation_kit import experiment as sok

      v = dynamic_variable.DynamicVariable(dimension=3, initializer="13")

      indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

      with tf.GradientTape() as tape:
          embedding = tf.nn.embedding_lookup(v, indices)
          print("embedding:", embedding)
          loss = tf.reduce_sum(embedding)

      grads = tape.gradient(loss, [v])

      optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
      optimizer = sok.OptimizerWrapper(optimizer)
      optimizer.apply_gradients(zip(grads, [v]))

      embedding = tf.nn.embedding_lookup(v, indices)
      print("embedding:", embedding)
  """
  # a specific code path for dl framework tf2.11.0
  # Only the ``legacy`` probe is guarded (presumably that namespace cannot be
  # resolved on every TensorFlow version).  The wrapper is constructed outside
  # the ``try`` so that an error raised by OptimizerWrapperV2.__init__ is
  # reported to the caller instead of being silently swallowed.
  try:
    is_legacy_keras = isinstance(optimizer,
                                 tf.keras.optimizers.legacy.Optimizer)
  except Exception:
    is_legacy_keras = False

  if is_legacy_keras or isinstance(optimizer, tf.keras.optimizers.Optimizer):
    return OptimizerWrapperV2(optimizer)
  return OptimizerWrapperV1(optimizer)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class OptimizerWrapperV1(object):
  """Wrap a non-Keras TensorFlow optimizer so it can update DynamicVariable.

  Gradients whose variable type name contains ``DynamicVariable`` are routed
  through :meth:`apply_sparse_gradients`; all other pairs are handed to the
  wrapped optimizer unchanged.  Slots of the wrapped optimizer are stored as
  ``optimizer._slots[slot_name][var_key]``.

  The wrapper relies on private members of the wrapped optimizer
  (``_slots``, ``_create_slots``, ``_finish``, ``_name``).
  """

  def __init__(self, optimizer):
    """Record the initial value of every slot the wrapped optimizer creates.

    Parameters
    ----------
    optimizer: tensorflow optimizer
        The optimizer to wrap.
    """
    self._optimizer = optimizer
    # slots
    # A throw-away dense variable makes the wrapped optimizer create its
    # slots once, so that they can be reused as initializers for the
    # DynamicVariable slots created later.
    unused = tf.Variable([0.0],
                         dtype=tf.float32,
                         name='unused',
                         trainable=False)
    self._optimizer._create_slots([unused])
    names = list(self._optimizer.get_slot_names())
    slots = [self._optimizer.get_slot(unused, name) for name in names]
    # Drop the bookkeeping entries of the throw-away variable again.
    unused_key = self._var_key(unused)
    for name in names:
      assert unused_key in self._optimizer._slots[name]
      self._optimizer._slots[name].pop(unused_key)
    # slot name -> slot variable created for ``unused``
    self._initial_vals = dict(zip(names, slots))
    # self._optimizer._prepare()

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of ``loss`` w.r.t. ``var_list``.

    ``loss`` is remembered on the instance because :meth:`apply_gradients`
    reads it.  ``aggregation_method``, ``colocate_gradients_with_ops`` and
    ``grad_loss`` are accepted but not used.  ``var_list`` is zipped with the
    gradients, so it has to be an iterable despite its ``None`` default.

    Returns
    -------
    list of (gradient, variable) pairs
    """
    self._loss = loss
    tmp_grads = gradients.gradients(loss, var_list)
    return list(zip(tmp_grads, var_list))
    # TODO: the wrapped optimizer's own compute_gradients routine does not
    # work with DynamicVariable, which is why tf gradients is called directly.

  def _var_key(self, var):
    """Return the ``(graph, op name)`` key used to index ``_slots``."""
    if isinstance(var, DynamicVariable):
      return (var._tf_handle.op.graph, var._tf_handle.op.name)
    else:
      return (var.op.graph, var.op.name)

  def _new_dynamic_slot(self, var, slot_name, name):
    """Build a non-trainable DynamicVariable slot colocated with ``var``.

    The slot copies ``var.dimension`` and is initialized from the recorded
    initial value of ``slot_name``.  For backends other than ``'hbm'`` the
    backend type and ``var.config_dict`` are forwarded as well.
    """
    with ops.colocate_with(var):
      if var.backend_type == 'hbm':
        return DynamicVariable(
            dimension=var.dimension,
            initializer=self._initial_vals[slot_name],
            name=name,
            trainable=False)
      return DynamicVariable(
          dimension=var.dimension,
          initializer=self._initial_vals[slot_name],
          var_type=var.backend_type,
          name=name,
          trainable=False,
          **var.config_dict)

  def _create_slots(self, vars):
    """Create optimizer slots for every variable in ``vars``."""
    for var in vars:
      if isinstance(var, DynamicVariable):
        self._create_slots_dynamic(var)
      else:
        # The wrapped ``_create_slots`` is given a list everywhere else in
        # this class (see ``__init__``), so wrap the single variable.
        self._optimizer._create_slots([var])

  def _create_slots_dynamic(self, var):
    """Create the missing DynamicVariable slots for ``var``."""
    key = self._var_key(var)
    for slot_name in self._initial_vals:
      if key not in self._optimizer._slots[slot_name]:
        slot = self._new_dynamic_slot(var, slot_name, 'DynamicSlot')
        self._optimizer._slots[slot_name][key] = slot

  def get_slot_names(self):
    """Return the slot names of the wrapped optimizer."""
    return self._optimizer.get_slot_names()

  def get_slot(self, var, slot_name):
    """Return the slot ``slot_name`` stored for ``var``."""
    key = self._var_key(var)
    return self._optimizer._slots[slot_name][key]

  @property
  def _slots(self):
    """Slot dictionary of the wrapped optimizer."""
    return self._optimizer._slots

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply sparse and dense gradients, then increment ``global_step``.

    :meth:`compute_gradients` has to be called first because the sparse
    update is placed under a control dependency on the stored loss.

    Parameters
    ----------
    grads_and_vars: iterable of (gradient, variable) pairs
    global_step: variable
        Required (asserted); incremented by one after all updates.
    name: str, optional
        Name passed to the created operations.

    Returns
    -------
    The op (or tensor, when executing eagerly) that increments
    ``global_step``.

    Raises
    ------
    ValueError
        If neither the sparse nor the dense path produced any update.
    """
    # Materialize once: the pairs are filtered twice below and a zip or
    # generator would be exhausted by the first pass.
    grads_and_vars = list(grads_and_vars)
    sparse_vars = [
        x for x in grads_and_vars if 'DynamicVariable' in str(type(x[1]))
    ]
    dense_vars = [
        x for x in grads_and_vars if 'DynamicVariable' not in str(type(x[1]))
    ]

    def _dummy_finish(update_ops, name_scope):
      return update_ops

    # ``_finish`` is replaced while the two partial updates are built so the
    # wrapped optimizer hands back its raw update ops; the real ``_finish``
    # is invoked once further down on the combined list.  It is restored in
    # ``finally`` so the optimizer is left intact even if building fails.
    finish_func = self._optimizer._finish
    self._optimizer._finish = _dummy_finish
    try:
      with ops.control_dependencies([array_ops.identity(self._loss)]):
        sparse_grad_updates = self.apply_sparse_gradients(
            sparse_vars, name=name)

      dense_grad_updates = None
      if dense_vars:
        dense_grad_updates = self._optimizer.apply_gradients(
            dense_vars, global_step=None, name=name)
    finally:
      self._optimizer._finish = finish_func

    if sparse_grad_updates is not None and dense_grad_updates is not None:
      grad_updates = sparse_grad_updates + dense_grad_updates
    elif sparse_grad_updates is not None:
      grad_updates = sparse_grad_updates
    elif dense_grad_updates is not None:
      grad_updates = dense_grad_updates
    else:
      raise ValueError('No variables provided.')

    assert global_step is not None
    with ops.control_dependencies([finish_func(grad_updates, 'update')]):
      with ops.colocate_with(global_step):
        if isinstance(global_step, resource_variable_ops.BaseResourceVariable):
          # TODO(apassos): the implicit read in assign_add is slow; consider
          # making it less so.
          apply_updates = resource_variable_ops.assign_add_variable_op(
              global_step.handle,
              ops.convert_to_tensor(1, dtype=global_step.dtype),
              name=name)
        else:
          apply_updates = state_ops.assign_add(global_step, 1, name=name)

    if not context.executing_eagerly():
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates

  def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply IndexedSlices gradients to DynamicVariable instances.

    Pairs whose gradient is ``None`` are skipped.

    Returns
    -------
    list of ops or None
        The ``to_dynamic`` ops of the updated variables and their slots, or
        ``None`` when no pair carried a gradient.
    """
    # 1. Create slots and do sparse_read
    to_static_ops = []
    grad_list, var_list = [], []
    for g, v in grads_and_vars:
      if g is None:
        continue
      # Re-index the gradient against the de-duplicated indices.
      unique, indices = tf.unique(g.indices)
      grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
      # TODO: Check multi-thread safety of DET
      with tf.control_dependencies([g.values]):
        to_static_ops.append(v.to_static(unique, False))
      var_list.append(v)
      key = self._var_key(v)
      for slot_name in self._initial_vals:
        if key not in self._optimizer._slots[slot_name]:
          tmp_slot_var_name = (
              v._dummy_handle.op.name + '/' + self._optimizer._name)
          slot = self._new_dynamic_slot(v, slot_name, tmp_slot_var_name)
          self._optimizer._slots[slot_name][key] = slot
        else:
          slot = self._optimizer._slots[slot_name][key]
        to_static_ops.append(slot.to_static(unique))

    if not grad_list:
      return

    # 2. Call tf-optimizer
    with ops.control_dependencies(to_static_ops):
      train_op = self._optimizer.apply_gradients(
          zip(grad_list, var_list), global_step=global_step, name=name)

    # 3. Write buffer back to dynamic variables
    to_dynamic_ops = []
    if not isinstance(train_op, list):
      train_op = [train_op]
    with ops.control_dependencies(train_op):
      for v in var_list:
        key = self._var_key(v)
        to_dynamic_ops.append(v.to_dynamic())
        for slot_name in self._initial_vals:
          slot = self._optimizer._slots[slot_name][key]
          to_dynamic_ops.append(slot.to_dynamic())

    return to_dynamic_ops
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class OptimizerWrapperV2(object):
  """Wrap a Keras optimizer so that it can update DynamicVariable.

  Slots of the wrapped optimizer are stored as
  ``optimizer._slots[var_key][slot_name]``.  The wrapper relies on private
  members of the wrapped optimizer (``_slots``, ``_create_slots``,
  ``_iterations``).
  """

  def __init__(self, optimizer):
    """Record the initial value of every slot the wrapped optimizer creates.

    Parameters
    ----------
    optimizer: keras optimizer
        The optimizer to wrap.
    """
    self._optimizer = optimizer
    # slots
    # A throw-away variable makes the wrapped optimizer create its slots
    # once, so that they can be reused as initializers for the
    # DynamicVariable slots created later.
    if tf.__version__[0] == '1':
      unused = tf.Variable([0.0],
                           name='unused',
                           trainable=False,
                           use_resource=True)
    else:
      unused = tf.Variable([0.0], name='unused', trainable=False)
    self._optimizer._create_slots([unused])
    names = list(self._optimizer.get_slot_names())
    slots = [self._optimizer.get_slot(unused, name) for name in names]
    # Drop the bookkeeping entry of the throw-away variable again.
    unused_key = self._var_key(unused)
    if unused_key in self._optimizer._slots:
      self._optimizer._slots.pop(unused_key)
    # slot name -> slot variable created for ``unused``
    self._initial_vals = dict(zip(names, slots))
    # Private step counter that is swapped into the wrapped optimizer while
    # ``apply_gradients`` runs.
    self._iterations = tf.Variable(0)

  @property
  def lr(self):
    """Learning rate of the wrapped optimizer."""
    return self._optimizer.lr

  def _new_dynamic_slot(self, var, slot_name):
    """Build a non-trainable DynamicVariable slot for ``var``.

    The slot copies ``var.dimension`` and is initialized from the recorded
    initial value of ``slot_name``.  For backends other than ``'hbm'`` the
    backend type and ``var.config_dict`` are forwarded as well.
    """
    if var.backend_type == 'hbm':
      return DynamicVariable(
          dimension=var.dimension,
          initializer=self._initial_vals[slot_name],
          name='DynamicSlot',
          trainable=False,
      )
    return DynamicVariable(
        dimension=var.dimension,
        initializer=self._initial_vals[slot_name],
        var_type=var.backend_type,
        name='DynamicSlot',
        trainable=False,
        **var.config_dict)

  def _create_slots(self, vars):
    """Create optimizer slots for every variable in ``vars``."""
    for tmp_var in vars:
      if isinstance(tmp_var, DynamicVariable):
        self._create_slots_dynamic(tmp_var)
      else:
        # The wrapped ``_create_slots`` is given a list in ``__init__``;
        # wrap the single variable so both call sites agree.
        self._optimizer._create_slots([tmp_var])

  def _create_slots_dynamic(self, var):
    """Create the missing DynamicVariable slots for ``var``."""
    key = self._var_key(var)
    if key not in self._optimizer._slots:
      self._optimizer._slots[key] = {}
    for slot_name in self._initial_vals:
      if slot_name not in self._optimizer._slots[key]:
        slot = self._new_dynamic_slot(var, slot_name)
        self._optimizer._slots[key][slot_name] = slot

  def _var_key(self, var):
    """Return the key under which ``var`` is stored in ``_slots``.

    The shared name is used in graph mode, the unique id otherwise; a
    distributed variable is first resolved to its container.
    """
    if hasattr(var, '_distributed_container'):
      var = var._distributed_container()
    if var._in_graph_mode:
      return var._shared_name
    return var._unique_id

  def get_slot_names(self):
    """Return the slot names of the wrapped optimizer."""
    return self._optimizer.get_slot_names()

  def get_slot(self, var, name):
    """Return the slot ``name`` the wrapped optimizer holds for ``var``."""
    return self._optimizer.get_slot(var, name)

  @property
  def _slots(self):
    """Slot dictionary of the wrapped optimizer."""
    return self._optimizer._slots

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply IndexedSlices gradients to DynamicVariable instances.

    Pairs whose gradient is ``None`` are skipped.  ``global_step`` is
    accepted but not used.

    Returns
    -------
    A grouped op covering the ``to_dynamic`` ops of all updated variables
    and their slots, or ``None`` when no pair carried a gradient.
    """
    # 1. Create slots and do sparse_read
    to_static_ops = []
    grad_list, var_list = [], []
    for g, v in grads_and_vars:
      if g is None:
        continue
      # Re-index the gradient against the de-duplicated indices.
      unique, indices = tf.unique(g.indices)
      grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
      # TODO: Check multi-thread safety of DET
      # with tf.control_dependencies([g.values]):
      to_static_ops.append(v.to_static(unique))
      var_list.append(v)
      key = self._var_key(v)
      if key not in self._optimizer._slots:
        self._optimizer._slots[key] = {}
      for slot_name in self._initial_vals:
        if slot_name not in self._optimizer._slots[key]:
          slot = self._new_dynamic_slot(v, slot_name)
          self._optimizer._slots[key][slot_name] = slot
        else:
          slot = self._optimizer._slots[key][slot_name]
        to_static_ops.append(slot.to_static(unique))

    if not grad_list:
      return

    # 2. Switch iterations
    iterations = self._optimizer._iterations
    self._optimizer._iterations = self._iterations
    try:
      # 3. Call tf-optimizer
      with tf.control_dependencies(to_static_ops):
        train_op = self._optimizer.apply_gradients(
            zip(grad_list, var_list), name=name)
    finally:
      # 4. Switch iterations back, also when the wrapped optimizer raised,
      # so its own counter is never left replaced.
      self._optimizer._iterations = iterations

    # 5. Write buffer back to dynamic variables
    to_dynamic_ops = []
    with tf.control_dependencies([train_op]):
      for v in var_list:
        key = self._var_key(v)
        to_dynamic_ops.append(v.to_dynamic())
        for slot_name in self._initial_vals:
          slot = self._optimizer._slots[key][slot_name]
          to_dynamic_ops.append(slot.to_dynamic())
    return tf.group(to_dynamic_ops)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class SGD(object):
  """Minimal gradient-descent optimizer operating on IndexedSlices."""

  def __init__(self, lr):
    """Store the learning rate ``lr`` in a ``tf.Variable``."""
    self._lr = tf.Variable(lr)

  @property
  def lr(self):
    """The learning-rate variable."""
    return self._lr

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Subtract ``lr * gradient`` from every variable that has a gradient.

    Pairs whose gradient is ``None`` are skipped.  ``global_step`` and
    ``name`` are accepted but not used.

    Returns
    -------
    A grouped op covering all ``scatter_sub`` updates.
    """
    updates = [
        var.scatter_sub(
            ops.IndexedSlices(grad.values * self._lr, grad.indices,
                              grad.dense_shape))
        for grad, var in grads_and_vars
        if grad is not None
    ]
    return tf.group(updates)
|