easy-cs-rec-custommodel 0.8.6 (easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/layers/uniter.py
@@ -0,0 +1,301 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.layers import multihead_cross_attention
from easy_rec.python.utils.activation import get_activation
from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class Uniter(object):
  """UNITER: UNiversal Image-TExt Representation Learning.

  See the original paper:
  https://arxiv.org/abs/1909.11740
  """

  def __init__(self, model_config, feature_configs, features, uniter_config,
               input_layer):
    self._model_config = uniter_config
    tower_num = 0
    self._img_features = None
    if input_layer.has_group('image'):
      self._img_features, _ = input_layer(features, 'image')
      tower_num += 1
    self._general_features = None
    if input_layer.has_group('general'):
      self._general_features, _ = input_layer(features, 'general')
      tower_num += 1
    self._txt_seq_features = None
    if input_layer.has_group('text'):
      self._txt_seq_features, _, _ = input_layer(
          features, 'text', is_combine=False)
      tower_num += 1
    self._use_token_type = True if tower_num > 1 else False
    self._other_features = None
    if input_layer.has_group('other'):  # e.g. statistical feature
      self._other_features, _ = input_layer(features, 'other')
      tower_num += 1
    assert tower_num > 0, 'there must be one of the feature groups: [image, text, general, other]'

    self._general_feature_num = 0
    self._txt_feature_num, self._img_feature_num = 0, 0
    general_feature_names = set()
    img_feature_names, txt_feature_names = set(), set()
    for fea_group in model_config.feature_groups:
      if fea_group.group_name == 'general':
        self._general_feature_num = len(fea_group.feature_names)
        general_feature_names = set(fea_group.feature_names)
        assert self._general_feature_num == len(general_feature_names), (
            'there are duplicate features in `general` feature group')
      elif fea_group.group_name == 'image':
        self._img_feature_num = len(fea_group.feature_names)
        img_feature_names = set(fea_group.feature_names)
        assert self._img_feature_num == len(img_feature_names), (
            'there are duplicate features in `image` feature group')
      elif fea_group.group_name == 'text':
        self._txt_feature_num = len(fea_group.feature_names)
        txt_feature_names = set(fea_group.feature_names)
        assert self._txt_feature_num == len(txt_feature_names), (
            'there are duplicate features in `text` feature group')

    if self._txt_feature_num > 1 or self._img_feature_num > 1:
      self._use_token_type = True
    self._token_type_vocab_size = self._txt_feature_num
    if self._img_feature_num > 0:
      self._token_type_vocab_size += 1
    if self._general_feature_num > 0:
      self._token_type_vocab_size += 1

    max_seq_len = 0
    txt_fea_emb_dim_list = []
    general_emb_dim_list = []
    img_fea_emb_dim_list = []
    for feature_config in feature_configs:
      fea_name = feature_config.input_names[0]
      if feature_config.HasField('feature_name'):
        fea_name = feature_config.feature_name
      if fea_name in img_feature_names:
        img_fea_emb_dim_list.append(feature_config.raw_input_dim)
      if fea_name in general_feature_names:
        general_emb_dim_list.append(feature_config.embedding_dim)
      if fea_name in txt_feature_names:
        txt_fea_emb_dim_list.append(feature_config.embedding_dim)
        if feature_config.HasField('max_seq_len'):
          assert feature_config.max_seq_len > 0, (
              'feature config `max_seq_len` must be greater than 0 for feature: '
              + fea_name)
          if feature_config.max_seq_len > max_seq_len:
            max_seq_len = feature_config.max_seq_len

    unique_dim_num = len(set(txt_fea_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        txt_fea_emb_dim_list
    ) == self._txt_feature_num, (
        'Uniter requires that all `text` feature dimensions must be consistent.'
    )
    unique_dim_num = len(set(img_fea_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        img_fea_emb_dim_list
    ) == self._img_feature_num, (
        'Uniter requires that all `image` feature dimensions must be consistent.'
    )
    unique_dim_num = len(set(general_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        general_emb_dim_list
    ) == self._general_feature_num, (
        'Uniter requires that all `general` feature dimensions must be consistent.'
    )

    if self._txt_feature_num > 0 and uniter_config.use_position_embeddings:
      assert uniter_config.max_position_embeddings > 0, (
          'model config `max_position_embeddings` must be greater than 0. ')
      assert uniter_config.max_position_embeddings >= max_seq_len, (
          'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
          '`max_seq_len`, which is %d' % max_seq_len)

    self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
    self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
    self._general_emb_size = general_emb_dim_list[
        0] if general_emb_dim_list else 0
    if self._img_features is not None:
      assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'

  def text_embeddings(self, token_type_id):
    all_txt_features = []
    input_masks = []
    hidden_size = self._model_config.hidden_size
    if self._general_features is not None:
      general_features = self._general_features
      if self._general_emb_size != hidden_size:
        # Run a linear projection of `hidden_size`
        general_features = tf.reshape(
            general_features, shape=[-1, self._general_emb_size])
        general_features = tf.layers.dense(
            general_features, hidden_size, name='txt_projection')
      general_features = tf.reshape(
          general_features, shape=[-1, self._general_feature_num, hidden_size])

      batch_size = tf.shape(general_features)[0]
      general_features = multihead_cross_attention.embedding_postprocessor(
          general_features,
          use_token_type=self._use_token_type,
          token_type_ids=tf.ones(
              shape=tf.stack([batch_size, self._general_feature_num]),
              dtype=tf.int32) * token_type_id,
          token_type_vocab_size=self._token_type_vocab_size,
          reuse_token_type=tf.AUTO_REUSE,
          use_position_embeddings=False,
          dropout_prob=self._model_config.hidden_dropout_prob)

      all_txt_features.append(general_features)
      mask = tf.ones(
          shape=tf.stack([batch_size, self._general_feature_num]),
          dtype=tf.int32)
      input_masks.append(mask)

    if self._txt_seq_features is not None:

      def dynamic_mask(x, max_len):
        ones = tf.ones(shape=tf.stack([x]), dtype=tf.int32)
        zeros = tf.zeros(shape=tf.stack([max_len - x]), dtype=tf.int32)
        return tf.concat([ones, zeros], axis=0)

      token_type_id += len(all_txt_features)
      for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features):
        batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3)
        if emb_size != hidden_size:
          seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size])
          seq_fea = tf.layers.dense(
              seq_fea, hidden_size, name='txt_seq_projection_%d' % i)
          seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size])

        seq_fea = multihead_cross_attention.embedding_postprocessor(
            seq_fea,
            use_token_type=self._use_token_type,
            token_type_ids=tf.ones(
                shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) *
            (i + token_type_id),
            token_type_vocab_size=self._token_type_vocab_size,
            reuse_token_type=tf.AUTO_REUSE,
            use_position_embeddings=self._model_config.use_position_embeddings,
            max_position_embeddings=self._model_config.max_position_embeddings,
            position_embedding_name='txt_position_embeddings_%d' % i,
            dropout_prob=self._model_config.hidden_dropout_prob)
        all_txt_features.append(seq_fea)

        input_mask = tf.map_fn(
            fn=lambda t: dynamic_mask(t, max_seq_len),
            elems=tf.to_int32(seq_len))
        input_masks.append(input_mask)

    return all_txt_features, input_masks

  def image_embeddings(self):
    if self._img_features is None:
      return None
    hidden_size = self._model_config.hidden_size
    image_features = self._img_features
    if self._img_emb_size != hidden_size:
      # Run a linear projection of `hidden_size`
      image_features = tf.reshape(
          image_features, shape=[-1, self._img_emb_size])
      image_features = tf.layers.dense(
          image_features, hidden_size, name='img_projection')
    image_features = tf.reshape(
        image_features, shape=[-1, self._img_feature_num, hidden_size])

    batch_size = tf.shape(image_features)[0]
    img_fea = multihead_cross_attention.embedding_postprocessor(
        image_features,
        use_token_type=self._use_token_type,
        token_type_ids=tf.zeros(
            shape=tf.stack([batch_size, self._img_feature_num]),
            dtype=tf.int32),
        token_type_vocab_size=self._token_type_vocab_size,
        reuse_token_type=tf.AUTO_REUSE,
        use_position_embeddings=self._model_config.use_position_embeddings,
        max_position_embeddings=self._model_config.max_position_embeddings,
        position_embedding_name='img_position_embeddings',
        dropout_prob=self._model_config.hidden_dropout_prob)
    return img_fea

  def __call__(self, is_training, *args, **kwargs):
    if not is_training:
      self._model_config.hidden_dropout_prob = 0.0
      self._model_config.attention_probs_dropout_prob = 0.0

    sub_modules = []

    img_fea = self.image_embeddings()
    start_token_id = 1 if self._img_feature_num > 0 else 0
    txt_features, txt_masks = self.text_embeddings(start_token_id)

    if img_fea is not None:
      batch_size = tf.shape(img_fea)[0]
    elif txt_features:
      batch_size = tf.shape(txt_features[0])[0]
    else:
      batch_size = None

    hidden_size = self._model_config.hidden_size
    if batch_size is not None:
      all_features = []
      masks = []
      cls_emb = tf.get_variable(name='cls_emb', shape=[1, 1, hidden_size])
      cls_emb = tf.tile(cls_emb, [batch_size, 1, 1])
      all_features.append(cls_emb)

      mask = tf.ones(shape=tf.stack([batch_size, 1]), dtype=tf.int32)
      masks.append(mask)

      if img_fea is not None:
        all_features.append(img_fea)
        mask = tf.ones(
            shape=tf.stack([batch_size, self._img_feature_num]), dtype=tf.int32)
        masks.append(mask)

      if txt_features:
        all_features.extend(txt_features)
        masks.extend(txt_masks)

      all_fea = tf.concat(all_features, axis=1)
      input_mask = tf.concat(masks, axis=1)
      attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
          from_tensor=all_fea, to_mask=input_mask)
      hidden_act = get_activation(self._model_config.hidden_act)
      attention_fea = multihead_cross_attention.transformer_encoder(
          all_fea,
          hidden_size=hidden_size,
          num_hidden_layers=self._model_config.num_hidden_layers,
          num_attention_heads=self._model_config.num_attention_heads,
          attention_mask=attention_mask,
          intermediate_size=self._model_config.intermediate_size,
          intermediate_act_fn=hidden_act,
          hidden_dropout_prob=self._model_config.hidden_dropout_prob,
          attention_probs_dropout_prob=self._model_config
          .attention_probs_dropout_prob,
          initializer_range=self._model_config.initializer_range,
          name='uniter')  # shape: [batch_size, seq_length, hidden_size]
      print('attention_fea:', attention_fea.shape)
      mm_fea = attention_fea[:, 0, :]  # [CLS] feature
      sub_modules.append(mm_fea)

    if self._other_features is not None:
      if self._model_config.HasField('other_feature_dnn'):
        l2_reg = kwargs['l2_reg'] if 'l2_reg' in kwargs else 0
        other_dnn_layer = dnn.DNN(self._model_config.other_feature_dnn, l2_reg,
                                  'other_dnn', is_training)
        other_fea = other_dnn_layer(self._other_features)
      else:
        other_fea = self._other_features
      sub_modules.append(other_fea)

    if len(sub_modules) == 1:
      return sub_modules[0]
    output = tf.concat(sub_modules, axis=-1)
    return output
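Note: the attention mask above comes from multihead_cross_attention.create_attention_mask_from_input_mask, whose source is not part of this diff. As a rough illustration of what such a mask is usually expected to look like (assuming the common BERT-style broadcast of a [batch, to_seq] 0/1 padding mask into a [batch, from_seq, to_seq] mask), here is a minimal NumPy sketch; it is not code from this package:

# Hypothetical illustration only: broadcast a per-token 0/1 input mask into
# the square attention mask a transformer encoder consumes.
import numpy as np

def make_attention_mask(input_mask):
  # input_mask: [batch, to_seq], 1 for real tokens ([CLS], image, text),
  # 0 for positions padded out by dynamic_mask().
  batch, to_seq = input_mask.shape
  to_mask = input_mask.reshape(batch, 1, to_seq).astype(np.float32)
  # every query position may attend to every non-padded key position
  return np.broadcast_to(to_mask, (batch, to_seq, to_seq))

mask = np.array([[1, 1, 1, 0, 0]])   # e.g. [CLS] + 2 real tokens + 2 pads
print(make_attention_mask(mask)[0])  # pad columns are zeroed for every query

Under that reading, padded positions contribute zero columns, so they are ignored for every query token, including the prepended cls_emb token whose encoder output becomes mm_fea.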
easy_rec/python/layers/utils.py
@@ -0,0 +1,248 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions used by layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from google.protobuf import struct_pb2
from google.protobuf.descriptor import FieldDescriptor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import variables

try:
  from tensorflow.python.ops import kv_variable_ops
except ImportError:
  kv_variable_ops = None

ColumnNameInCollection = {}


def _tensor_to_map(tensor):
  return {
      'node_path': tensor.name,
      'shape': tensor.shape.as_list() if tensor.shape else None,
      'dtype': tensor.dtype.name
  }


def _tensor_to_tensorinfo(tensor):
  tensor_info = {}
  if isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info['is_dense'] = False
    tensor_info['values'] = _tensor_to_map(tensor.values)
    tensor_info['indices'] = _tensor_to_map(tensor.indices)
    tensor_info['dense_shape'] = _tensor_to_map(tensor.dense_shape)
  else:
    tensor_info['is_dense'] = True
    tensor_info.update(_tensor_to_map(tensor))
  return tensor_info


def add_tensor_to_collection(collection_name, name, tensor):
  tensor_info = _tensor_to_tensorinfo(tensor)
  tensor_info['name'] = name
  update_attr_to_collection(collection_name, tensor_info)


def append_tensor_to_collection(collection_name, name, key, tensor):
  tensor_info = _tensor_to_tensorinfo(tensor)
  append_attr_to_collection(collection_name, name, key, tensor_info)


def _collection_item_key(col, name):
  return '%d#%s' % (id(col), name)


def _process_item(collection_name, name, func):
  col = ops.get_collection_ref(collection_name)
  item_found = {}
  idx_found = -1

  # add id(col) because col may re-new sometimes
  key = _collection_item_key(col, name)
  if key in ColumnNameInCollection:
    idx_found = ColumnNameInCollection[key]
    if idx_found >= len(col):
      raise Exception(
          'Find column name in collection failed: index out of range')

    item_found = json.loads(col[idx_found])
    if item_found['name'] != name:
      raise Exception(
          'Find column name in collection failed: item name not match')
    func(item_found)
    col[idx_found] = json.dumps(item_found)
  else:
    func(item_found)
    col.append(json.dumps(item_found))
    ColumnNameInCollection[key] = len(col) - 1


def append_attr_to_collection(collection_name, name, key, value):

  def append(item_found):
    if key not in item_found:
      item_found[key] = []
    item_found[key].append(value)

  _process_item(collection_name, name, append)


def update_attr_to_collection(collection_name, attrs):

  def update(item_found):
    item_found.update(attrs)

  _process_item(collection_name, attrs['name'], update)


def unique_name_in_collection(collection_name, name):
  col = ops.get_collection_ref(collection_name)
  unique_name = name
  index = 0
  while True:
    key = _collection_item_key(col, unique_name)
    if key not in ColumnNameInCollection:
      break
    index += 1
    unique_name = '%s_%d' % (name, index)
  return unique_name


def gen_embedding_attrs(column=None,
                        variable=None,
                        bucket_size=None,
                        combiner=None,
                        is_embedding_var=None):
  attrs = dict()
  attrs['name'] = column.name
  attrs['bucket_size'] = bucket_size
  attrs['combiner'] = combiner
  attrs['is_embedding_var'] = is_embedding_var
  attrs['weights_op_path'] = variable.name
  if kv_variable_ops:
    if isinstance(variable, kv_variable_ops.EmbeddingVariable):
      attrs['is_embedding_var'] = True
      attrs['embedding_var_keys'] = variable._shared_name + '-keys'
      attrs['embedding_var_values'] = variable._shared_name + '-values'
    elif (isinstance(variable, variables.PartitionedVariable)) and \
        (isinstance(variable._get_variable_list()[0], kv_variable_ops.EmbeddingVariable)):
      attrs['embedding_var_keys'] = [v._shared_name + '-keys' for v in variable]
      attrs['embedding_var_values'] = [
          v._shared_name + '-values' for v in variable
      ]
    else:
      attrs['is_embedding_var'] = False
  else:
    attrs['is_embedding_var'] = False
  return attrs


def mark_input_src(name, src_desc):
  ops.add_to_collection(ops.GraphKeys.RANK_SERVICE_INPUT_SRC,
                        json.dumps({
                            'name': name,
                            'src': src_desc
                        }))


def is_proto_message(pb_obj, field):
  if not hasattr(pb_obj, 'DESCRIPTOR'):
    return False
  if field not in pb_obj.DESCRIPTOR.fields_by_name:
    return False
  field_type = pb_obj.DESCRIPTOR.fields_by_name[field].type
  return field_type == FieldDescriptor.TYPE_MESSAGE


class Parameter(object):

  def __init__(self, params, is_struct, l2_reg=None):
    self.params = params
    self.is_struct = is_struct
    self._l2_reg = l2_reg

  @staticmethod
  def make_from_pb(config):
    return Parameter(config, False)

  def get_pb_config(self):
    assert not self.is_struct, 'Struct parameter can not convert to pb config'
    return self.params

  @property
  def l2_regularizer(self):
    return self._l2_reg

  @l2_regularizer.setter
  def l2_regularizer(self, value):
    self._l2_reg = value

  def __getattr__(self, key):
    if self.is_struct:
      if key not in self.params:
        return None
      value = self.params[key]
      if type(value) == struct_pb2.Struct:
        return Parameter(value, True, self._l2_reg)
      else:
        return value
    value = getattr(self.params, key)
    if is_proto_message(self.params, key):
      return Parameter(value, False, self._l2_reg)
    return value

  def __getitem__(self, key):
    return self.__getattr__(key)

  def get_or_default(self, key, def_val):
    if self.is_struct:
      if key in self.params:
        if def_val is None:
          return self.params[key]
        value = self.params[key]
        if type(value) == float:
          return type(def_val)(value)
        return value
      return def_val
    else:  # pb message
      value = getattr(self.params, key, def_val)
      if hasattr(value, '__len__'):  # repeated
        return value if len(value) > 0 else def_val
      try:
        if self.params.HasField(key):
          return value
      except ValueError:
        pass
      return def_val  # maybe not equal to the default value of msg field

  def check_required(self, keys):
    if not self.is_struct:
      return
    if not isinstance(keys, (list, tuple)):
      keys = [keys]
    for key in keys:
      if key not in self.params:
        raise KeyError('%s must be set in params' % key)

  def has_field(self, key):
    if self.is_struct:
      return key in self.params
    else:
      return self.params.HasField(key)
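The Parameter wrapper above lets layer code read hyperparameters the same way whether they come from a typed protobuf message or from a free-form google.protobuf.Struct. A small hypothetical usage sketch (the config keys are made up; only the Parameter class and the struct_pb2 import come from this file, and the import assumes the wheel is installed):

# Hypothetical usage of Parameter with a Struct-based config.
from google.protobuf import struct_pb2

from easy_rec.python.layers.utils import Parameter

st = struct_pb2.Struct()
st.update({'hidden_units': [64.0, 32.0], 'dropout_rate': 0.1})

params = Parameter(st, is_struct=True)
print(params.get_or_default('dropout_rate', 0.0))   # 0.1, cast to float
print(params.get_or_default('activation', 'relu'))  # key missing -> 'relu'
print(params.has_field('hidden_units'))             # True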
easy_rec/python/layers/variational_dropout_layer.py
@@ -0,0 +1,130 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json

import numpy as np
import tensorflow as tf

from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn  # NOQA
from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class VariationalDropoutLayer(object):
  """Rank features by variational dropout.

  Use the Dropout concept on the input feature layer and optimize the corresponding feature-wise dropout rate.
  paper: Dropout Feature Ranking for Deep Learning Models
  arXiv: 1712.08645
  """

  def __init__(self,
               variational_dropout_config,
               features_dimension,
               is_training=False,
               name=''):
    self._config = variational_dropout_config
    self.features_dimension = features_dimension
    self.features_total_dimension = sum(self.features_dimension.values())

    if self.variational_dropout_wise():
      self._dropout_param_size = self.features_total_dimension
      self.drop_param_shape = [self._dropout_param_size]
    else:
      self._dropout_param_size = len(self.features_dimension)
      self.drop_param_shape = [self._dropout_param_size]
    self.evaluate = not is_training

    logit_p_name = 'logit_p' if name == 'all' else 'logit_p_%s' % name
    self.logit_p = tf.get_variable(
        name=logit_p_name,
        shape=self.drop_param_shape,
        dtype=tf.float32,
        initializer=None)
    tf.add_to_collection(
        'variational_dropout',
        json.dumps([name, list(self.features_dimension.items())]))

  def get_lambda(self):
    return self._config.regularization_lambda

  def variational_dropout_wise(self):
    return self._config.embedding_wise_variational_dropout

  def build_expand_index(self, batch_size):
    # Build index_list--->[[0,0],[0,0],[0,0],[0,0],[0,1]......]
    expanded_index = []
    for i, index_loop_count in enumerate(self.features_dimension.values()):
      for m in range(index_loop_count):
        expanded_index.append([i])
    expanded_index = tf.tile(expanded_index, [batch_size, 1])
    batch_size_range = tf.range(batch_size)
    expand_range_axis = tf.expand_dims(batch_size_range, 1)
    batch_size_range_expand_dim_len = tf.tile(
        expand_range_axis, [1, self.features_total_dimension])
    index_i = tf.reshape(batch_size_range_expand_dim_len, [-1, 1])
    expanded_index = tf.concat([index_i, expanded_index], 1)
    return expanded_index

  def sample_noisy_input(self, input):
    batch_size = tf.shape(input)[0]
    if self.evaluate:
      expanded_dims_logit_p = tf.expand_dims(self.logit_p, 0)
      expanded_logit_p = tf.tile(expanded_dims_logit_p, [batch_size, 1])
      p = tf.sigmoid(expanded_logit_p)
      if self.variational_dropout_wise():
        scaled_input = input * (1 - p)
      else:
        # expand dropout layer
        expanded_index = self.build_expand_index(batch_size)
        expanded_p = tf.gather_nd(p, expanded_index)
        expanded_p = tf.reshape(expanded_p, [-1, self.features_total_dimension])
        scaled_input = input * (1 - expanded_p)

      return scaled_input
    else:
      bern_val = self.sampled_from_logit_p(batch_size)
      bern_val = tf.reshape(bern_val, [-1, self.features_total_dimension])
      noisy_input = input * bern_val
      return noisy_input

  def sampled_from_logit_p(self, num_samples):
    expand_dims_logit_p = tf.expand_dims(self.logit_p, 0)
    expand_logit_p = tf.tile(expand_dims_logit_p, [num_samples, 1])
    dropout_p = tf.sigmoid(expand_logit_p)
    bern_val = self.concrete_dropout_neuron(dropout_p)

    if self.variational_dropout_wise():
      return bern_val
    else:
      # from feature_num to embedding_dim_num
      expanded_index = self.build_expand_index(num_samples)
      bern_val_gather_nd = tf.gather_nd(bern_val, expanded_index)
      return bern_val_gather_nd

  def concrete_dropout_neuron(self, dropout_p, temp=1.0 / 10.0):
    EPSILON = np.finfo(float).eps
    unif_noise = tf.random_uniform(
        tf.shape(dropout_p), dtype=tf.float32, seed=None, name='unif_noise')

    approx = (
        tf.log(dropout_p + EPSILON) - tf.log(1. - dropout_p + EPSILON) +
        tf.log(unif_noise + EPSILON) - tf.log(1. - unif_noise + EPSILON))

    approx_output = tf.sigmoid(approx / temp)
    return 1 - approx_output

  def __call__(self, output_features):
    batch_size = tf.shape(output_features)[0]
    noisy_input = self.sample_noisy_input(output_features)
    dropout_p = tf.sigmoid(self.logit_p)
    variational_dropout_penalty = 1. - dropout_p
    variational_dropout_penalty_lambda = self.get_lambda() / tf.cast(
        batch_size, dtype=tf.float32)
    variational_dropout_loss_sum = variational_dropout_penalty_lambda * tf.reduce_sum(
        variational_dropout_penalty, axis=0)
    tf.add_to_collection('variational_dropout_loss',
                         variational_dropout_loss_sum)
    return noisy_input
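concrete_dropout_neuron() above is a concrete (Gumbel-softmax style) relaxation of a Bernoulli drop mask: with p = sigmoid(logit_p), it returns a value near 1 (keep) or near 0 (drop) while staying differentiable in logit_p, which is what lets the layer learn per-feature dropout rates. A standalone NumPy restatement of the same formula, for illustration only and not part of the package:

# Hypothetical NumPy restatement of concrete_dropout_neuron().
import numpy as np

def concrete_keep_mask(p, temp=0.1, eps=np.finfo(float).eps):
  # p: per-feature drop probabilities, typically sigmoid(logit_p)
  u = np.random.uniform(size=p.shape)
  approx = (np.log(p + eps) - np.log(1. - p + eps) +
            np.log(u + eps) - np.log(1. - u + eps))
  drop = 1.0 / (1.0 + np.exp(-approx / temp))  # sigmoid(approx / temp)
  return 1.0 - drop  # ~1 keeps the feature, ~0 drops it

p = np.array([0.05, 0.5, 0.95])  # learned per-feature drop probabilities
print(concrete_keep_mask(p))     # mostly ~1, roughly a coin flip, mostly ~0

As the temperature shrinks toward 0 the mask hardens toward an exact Bernoulli(1 - p) sample; the 1/10 default used in the layer keeps gradients usable during training, while at evaluation time the layer simply scales inputs by (1 - p).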