easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
# -*- encoding: utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from tensorflow.python.framework import ops
|
|
9
|
+
from tensorflow.python.ops import array_ops
|
|
10
|
+
from tensorflow.python.ops import variable_scope
|
|
11
|
+
|
|
12
|
+
from easy_rec.python.compat import regularizers
|
|
13
|
+
from easy_rec.python.compat.feature_column import feature_column
|
|
14
|
+
from easy_rec.python.feature_column.feature_column import FeatureColumnParser
|
|
15
|
+
from easy_rec.python.feature_column.feature_group import FeatureGroup
|
|
16
|
+
from easy_rec.python.layers import sequence_feature_layer
|
|
17
|
+
from easy_rec.python.layers import variational_dropout_layer
|
|
18
|
+
from easy_rec.python.layers.keras import TextCNN
|
|
19
|
+
from easy_rec.python.layers.utils import Parameter
|
|
20
|
+
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
|
|
21
|
+
from easy_rec.python.utils import conditional
|
|
22
|
+
from easy_rec.python.utils import shape_utils
|
|
23
|
+
|
|
24
|
+
from easy_rec.python.compat.feature_column.feature_column_v2 import is_embedding_column # NOQA
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InputLayer(object):
|
|
28
|
+
"""Input Layer for generate input features.
|
|
29
|
+
|
|
30
|
+
This class apply feature_columns to input tensors to generate wide features and deep features.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(self,
             feature_configs,
             feature_groups_config,
             variational_dropout_config=None,
             wide_output_dim=-1,
             ev_params=None,
             embedding_regularizer=None,
             kernel_regularizer=None,
             is_training=False,
             is_predicting=False):
  """Index the feature groups and build the feature column parser.

  Args:
    feature_configs: feature configs, handed to SequenceFeatureLayer and
      FeatureColumnParser.
    feature_groups_config: feature group configs; each entry is read for
      group_name and sequence_features. It is iterated several times, so it
      must be re-iterable (e.g. a list or a repeated proto field).
    variational_dropout_config: stored; get_sequence_feature raises
      ValueError when it is not None.
    wide_output_dim: handed to FeatureColumnParser.
    ev_params: handed to SequenceFeatureLayer and FeatureColumnParser.
    embedding_regularizer: handed to SequenceFeatureLayer and stored; applied
      to embedding outputs in the feature getters.
    kernel_regularizer: handed to SequenceFeatureLayer and stored.
    is_training: handed to SequenceFeatureLayer and stored (later passed to
      feature_column.input_layer).
    is_predicting: handed to SequenceFeatureLayer and stored (gates CPU
      placement in the feature getters).
  """
  self._feature_groups = dict(
      (group_cfg.group_name, FeatureGroup(group_cfg))
      for group_cfg in feature_groups_config)
  self.sequence_feature_layer = sequence_feature_layer.SequenceFeatureLayer(
      feature_configs, feature_groups_config, ev_params,
      embedding_regularizer, kernel_regularizer, is_training, is_predicting)
  # Every sequence feature config of every group, flattened in group order.
  self._seq_feature_groups_config = [
      seq_cfg for group_cfg in feature_groups_config
      for seq_cfg in group_cfg.sequence_features
  ]
  # Only groups that declare at least one sequence feature get an entry.
  self._group_name_to_seq_features = dict(
      (group_cfg.group_name, group_cfg.sequence_features)
      for group_cfg in feature_groups_config
      if len(group_cfg.sequence_features) > 0)
  # get_wide_deep_dict() must run after the group tables above exist.
  self._fc_parser = FeatureColumnParser(
      feature_configs,
      self.get_wide_deep_dict(),
      wide_output_dim,
      ev_params=ev_params)

  self._embedding_regularizer = embedding_regularizer
  self._kernel_regularizer = kernel_regularizer
  self._is_training = is_training
  self._is_predicting = is_predicting
  self._variational_dropout_config = variational_dropout_config
|
|
70
|
+
|
|
71
|
+
def has_group(self, group_name):
  """Return whether group_name is one of the configured feature groups."""
  configured_groups = self._feature_groups
  return group_name in configured_groups
|
|
73
|
+
|
|
74
|
+
def get_combined_feature(self, features, group_name, is_dict=False):
  """Get combined features by group_name.

  The group's columns are built by single_call_input_layer. If the group
  also declares sequence_features, those are encoded by
  self.sequence_feature_layer and concatenated onto the result along the
  last axis.

  Args:
    features: input tensor dict
    group_name: feature_group name
    is_dict: whether to also return feature_name_to_output_tensors

  Return:
    features: all features concatenate together
    group_features: list of features
    feature_name_to_output_tensors: dict, feature_name to feature_value,
      only present when is_dict is True

  Raises:
    AssertionError: if group_name is not a configured feature group.
  """
  # Same validation as the other accessors, so an unknown group fails with a
  # readable message instead of a bare KeyError on the lookup below.
  assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join([x for x in self._feature_groups]))
  feature_name_to_output_tensors = {}
  negative_sampler = self._feature_groups[group_name]._config.negative_sampler

  # NOTE(security): the environment variable's text is passed to eval(), so
  # whoever controls place_embedding_on_cpu can execute arbitrary code here.
  # The result is only used as a truth value (presumably 'True'/'False').
  place_on_cpu = os.getenv('place_embedding_on_cpu')
  place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
  # Device-pin the plain-column lookup to /CPU:0 only when predicting and the
  # flag is truthy.
  with conditional(self._is_predicting and place_on_cpu,
                   ops.device('/CPU:0')):
    concat_features, group_features = self.single_call_input_layer(
        features, group_name, feature_name_to_output_tensors)
  if group_name in self._group_name_to_seq_features:
    # for target attention
    group_seq_arr = self._group_name_to_seq_features[group_name]
    concat_features, all_seq_fea = self.sequence_feature_layer(
        features,
        concat_features,
        group_seq_arr,
        feature_name_to_output_tensors,
        negative_sampler=negative_sampler,
        scope_name=group_name)
    group_features.extend(all_seq_fea)
    for col, fea in zip(group_seq_arr, all_seq_fea):
      feature_name_to_output_tensors['seq_fea/' + col.group_name] = fea
    all_seq_fea = array_ops.concat(all_seq_fea, axis=-1)
    concat_features = array_ops.concat([concat_features, all_seq_fea],
                                       axis=-1)
  if is_dict:
    return concat_features, group_features, feature_name_to_output_tensors
  return concat_features, group_features
|
|
116
|
+
|
|
117
|
+
def get_plain_feature(self, features, group_name):
  """Get plain features by group_name. Exclude sequence features.

  Args:
    features: input tensor dict
    group_name: feature_group name

  Return:
    features: all plain features concatenated together, or None when the
      group has no plain columns
    group_features: list of per-column outputs in column order (empty when
      the group has no plain columns)
  """
  known_groups = self._feature_groups
  assert group_name in known_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join(known_groups))

  plain_columns, _ = known_groups[group_name].select_columns(self._fc_parser)
  if not plain_columns:
    # Nothing to look up: no concatenated tensor and no per-column outputs.
    return None, []

  column_outputs = OrderedDict()
  concatenated = feature_column.input_layer(
      features,
      plain_columns,
      cols_to_output_tensors=column_outputs,
      is_training=self._is_training)
  per_column = [column_outputs[column] for column in plain_columns]

  # Only outputs of embedding columns are passed to the regularizer.
  embedded_outputs = [
      output for column, output in column_outputs.items()
      if is_embedding_column(column)
  ]
  if embedded_outputs and self._embedding_regularizer is not None:
    regularizers.apply_regularization(
        self._embedding_regularizer, weights_list=embedded_outputs)
  return concatenated, per_column
|
|
153
|
+
|
|
154
|
+
def get_sequence_feature(self, features, group_name):
  """Get sequence features by group_name. Exclude plain features.

  Args:
    features: input tensor dict
    group_name: feature_group name

  Return:
    seq_features: list of sequence features, each element is a tuple:
      3d embedding tensor (batch_size, max_seq_len, embedding_dimension),
      1d sequence length tensor.

  Raises:
    ValueError: if a variational dropout config was supplied.
  """
  known_groups = self._feature_groups
  assert group_name in known_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join(known_groups))

  if self._variational_dropout_config is not None:
    raise ValueError(
        'variational dropout is not supported in not combined mode now.')

  _, sequence_columns = known_groups[group_name].select_columns(
      self._fc_parser)

  lazy_builder = feature_column._LazyBuilder(features)
  seq_features = []
  for seq_column in sequence_columns:
    scope_name = 'input_layer/' + seq_column.categorical_column.name
    with variable_scope.variable_scope(scope_name):
      embedding, seq_len = seq_column._get_sequence_dense_tensor(lazy_builder)
      if seq_column.max_seq_length > 0:
        # A positive max_seq_length routes both tensors through
        # shape_utils.truncate_sequence.
        embedding, seq_len = shape_utils.truncate_sequence(
            embedding, seq_len, seq_column.max_seq_length)
      seq_features.append((embedding, seq_len))

  # The (possibly truncated) embedding tensors are what gets regularized.
  embeddings = [embedding for embedding, _ in seq_features]
  if embeddings and self._embedding_regularizer is not None:
    regularizers.apply_regularization(
        self._embedding_regularizer, weights_list=embeddings)
  return seq_features
|
|
193
|
+
|
|
194
|
+
def get_raw_features(self, features, group_name):
  """Get features by group_name.

  Args:
    features: input tensor dict
    group_name: feature_group name

  Return:
    features: the raw input tensors of the group, in the order of the
      group's feature_names
  """
  known_groups = self._feature_groups
  assert group_name in known_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join(known_groups))
  raw_tensors = []
  for feature_name in known_groups[group_name].feature_names:
    raw_tensors.append(features[feature_name])
  return raw_tensors
|
|
208
|
+
|
|
209
|
+
def get_bucketized_features(self, features, group_name):
  """Get bucketized ids of every feature in a group, in one shared id space.

  Each feature's ids are shifted by the sum of the vocab sizes of the
  features before it. Assuming each feature's ids lie in [0, vocab), ids of
  different features therefore do not overlap.

  Args:
    features: input tensor dict
    group_name: feature_group name

  Return:
    values: list of int id tensors, one per feature, with offset added
    offset: sum of the vocab sizes of all features in the group
    weights: list aligned with values; features[name + '_w'] for sparse
      (tag) features that have one, otherwise None
  """
  assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join([x for x in self._feature_groups]))
  feature_group = self._feature_groups[group_name]
  offset = 0
  values = []
  weights = []
  for feature in feature_group.feature_names:
    vocab = self._fc_parser.get_feature_vocab_size(feature)
    # Lazy %-args (as in the weight log below) instead of eager formatting.
    logging.info('vocab size of feature %s is %d', feature, vocab)
    tensor = features[feature]
    weight = None
    if tf.is_numeric_tensor(tensor):
      # suppose feature already have been bucketized
      value = tf.to_int64(tensor)
    elif isinstance(tensor, tf.SparseTensor):
      # TagFeature
      dense = tf.sparse.to_dense(tensor, default_value='')
      value = tf.string_to_hash_bucket_fast(dense, vocab)
      weight_name = feature + '_w'
      if weight_name in features:
        weight = features[weight_name]  # SparseTensor
        logging.info('feature %s has weight %s', feature, weight_name)
    else:  # IdFeature
      value = tf.string_to_hash_bucket_fast(tensor, vocab)
    weights.append(weight)
    values.append(value + offset)
    offset += vocab
  return values, offset, weights
|
|
244
|
+
|
|
245
|
+
def __call__(self, features, group_name, is_combine=True, is_dict=False):
  """Get features by group_name.

  Args:
    features: input tensor dict
    group_name: feature_group name
    is_combine: whether to combine sequence features over the
      time dimension.
    is_dict: whether to return group_features in dict; only honored when
      is_combine is True.

  Return:
    is_combine: True
      features: all features concatenate together
      group_features: list of features
      feature_name_to_output_tensors: dict, feature_name to feature_value,
        only present when is_dict is True
    is_combine: False
      seq_features: list of sequence features, each element is a tuple:
        3 dimension embedding tensor (batch_size, max_seq_len, embedding_dimension),
        1 dimension sequence length tensor.
      plain_features: concatenated plain features (None if the group has none)
      feature_list: list of plain features
  """
  known_groups = self._feature_groups
  assert group_name in known_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join(known_groups))
  if not is_combine:
    # Uncombined mode: sequence features stay in raw (embedding, length)
    # form next to the plain features.
    # NOTE(security): the environment variable's text is eval()'d; whoever
    # controls place_embedding_on_cpu can execute arbitrary code here.
    place_on_cpu = os.getenv('place_embedding_on_cpu')
    place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
    with conditional(self._is_predicting and place_on_cpu,
                     ops.device('/CPU:0')):
      seq_features = self.get_sequence_feature(features, group_name)
      plain_features, feature_list = self.get_plain_feature(
          features, group_name)
    return seq_features, plain_features, feature_list
  return self.get_combined_feature(features, group_name, is_dict)
|
|
279
|
+
|
|
280
|
+
def single_call_input_layer(self,
                            features,
                            group_name,
                            feature_name_to_output_tensors=None):
  """Get features by group_name.

  Non-sequence columns of the group are built with
  `feature_column.input_layer`. Each sequence column is then combined over
  the time dimension with its configured `sequence_combiner`: either
  attention pooling or TextCNN.

  Args:
    features: input tensor dict
    group_name: feature_group name
    feature_name_to_output_tensors: passed through unchanged to
      `feature_column.input_layer`; per the original note, when sequence
      features are set it supplies key tensors to reuse.

  Return:
    features: all features concatenate together
    group_features: list of features

  Raises:
    AssertionError: if group_name is not one of the known feature groups.
    ValueError: if a sequence column has no sequence_combiner.
    NotImplementedError: if a sequence column's combiner is neither
      'attention' nor 'text_cnn'.
  """
  assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % (
      group_name, ','.join([x for x in self._feature_groups]))
  feature_group = self._feature_groups[group_name]
  group_columns, group_seq_columns = feature_group.select_columns(
      self._fc_parser)
  # Column -> output tensor. input_layer fills it; the sequence loop below
  # adds the combined sequence tensors.
  cols_to_output_tensors = OrderedDict()
  # If the group has no plain columns, the sequence columns are handed to
  # input_layer instead.
  # NOTE(review): in that case the same columns are also combined by the loop
  # below and appended to seq_features; confirm they are not double-counted
  # in concat_features.
  output_features = feature_column.input_layer(
      features,
      group_columns if len(group_columns) > 0 else group_seq_columns,
      cols_to_output_tensors=cols_to_output_tensors,
      feature_name_to_output_tensors=feature_name_to_output_tensors,
      is_training=self._is_training)

  # Tensors the embedding regularizer (if configured) is applied to.
  embedding_reg_lst = []
  builder = feature_column._LazyBuilder(features)
  seq_features = []
  # Process sequence columns in name order, independent of config order.
  for column in sorted(group_seq_columns, key=lambda x: x.name):
    with variable_scope.variable_scope(
        None, default_name=column._var_scope_name):
      seq_feature, seq_len = column._get_sequence_dense_tensor(builder)
      embedding_reg_lst.append(seq_feature)

      sequence_combiner = column.sequence_combiner
      if sequence_combiner is None:
        raise ValueError(
            'sequence_combiner is none, please set sequence_combiner or use TagFeature'
        )
      if sequence_combiner.WhichOneof('combiner') == 'attention':
        # Attention pooling: one bias-free logit per time step.
        attn_logits = tf.layers.dense(
            inputs=seq_feature,
            units=1,
            kernel_regularizer=self._kernel_regularizer,
            use_bias=False,
            activation=None,
            name='attention')
        attn_logits = tf.squeeze(attn_logits, axis=-1)
        # Steps beyond seq_len get a huge negative logit, so their softmax
        # weight is effectively zero.
        attn_logits_padding = tf.ones_like(attn_logits) * (-2**32 + 1)
        seq_mask = tf.sequence_mask(seq_len)
        attn_score = tf.nn.softmax(
            tf.where(seq_mask, attn_logits, attn_logits_padding))
        # Weighted sum over the time axis (axis=1).
        seq_feature = tf.reduce_sum(
            attn_score[:, :, tf.newaxis] * seq_feature, axis=1)
        seq_features.append(seq_feature)
        cols_to_output_tensors[column] = seq_feature
      elif sequence_combiner.WhichOneof('combiner') == 'text_cnn':
        # TextCNN combiner built from the column's text_cnn proto config.
        params = Parameter.make_from_pb(sequence_combiner.text_cnn)
        text_cnn_layer = TextCNN(params, name=column.name + '_text_cnn')
        cnn_feature = text_cnn_layer((seq_feature, seq_len))
        seq_features.append(cnn_feature)
        cols_to_output_tensors[column] = cnn_feature
      else:
        raise NotImplementedError
  if self._variational_dropout_config is not None:
    # Width of every column's output, keyed by raw feature name, so the
    # dropout layer can work per feature and the result can be split back.
    features_dimension = OrderedDict([
        (k.raw_name, int(v.shape[-1]))
        for k, v in cols_to_output_tensors.items()
    ])
    concat_features = array_ops.concat(
        [output_features] + seq_features, axis=-1)
    variational_dropout = variational_dropout_layer.VariationalDropoutLayer(
        self._variational_dropout_config,
        features_dimension,
        self._is_training,
        name=group_name)
    concat_features = variational_dropout(concat_features)
    # Split back into one tensor per column.
    # NOTE(review): assumes the columns inside output_features appear in
    # cols_to_output_tensors insertion order -- TODO confirm.
    group_features = tf.split(
        concat_features, list(features_dimension.values()), axis=-1)
  else:
    concat_features = array_ops.concat(
        [output_features] + seq_features, axis=-1)
    # Plain columns first (config order), then sequence columns.
    group_features = [cols_to_output_tensors[x] for x in group_columns] + \
                     [cols_to_output_tensors[x] for x in group_seq_columns]

  if self._embedding_regularizer is not None:
    # Regularize embedding-column outputs together with the sequence tensors
    # collected above.
    for fc, val in cols_to_output_tensors.items():
      if is_embedding_column(fc):
        embedding_reg_lst.append(val)
    if embedding_reg_lst:
      regularizers.apply_regularization(
          self._embedding_regularizer, weights_list=embedding_reg_lst)
  return concat_features, group_features
|
|
377
|
+
|
|
378
|
+
def get_wide_deep_dict(self):
  """Get wide or deep indicator for feature columns.

  Merges the `wide_and_deep_dict` of every feature group:
  - A feature seen in only one group, or with the same indicator in every
    group, keeps that indicator.
  - A feature that appears with differing indicators is marked
    `WideOrDeep.WIDE_AND_DEEP`.

  Returns:
    dict of { feature_name : WideOrDeep }
  """
  wide_and_deep_dict = {}
  for fg in self._feature_groups.values():
    for k, v in fg.wide_and_deep_dict.items():
      if k not in wide_and_deep_dict:
        wide_and_deep_dict[k] = v
      elif wide_and_deep_dict[k] != v:
        # Conflicting indicators across groups: the feature is used both ways.
        wide_and_deep_dict[k] = WideOrDeep.WIDE_AND_DEEP
  return wide_and_deep_dict
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from .attention import Attention
|
|
2
|
+
from .auxiliary_loss import AuxiliaryLoss
|
|
3
|
+
from .blocks import MLP
|
|
4
|
+
from .blocks import Gate
|
|
5
|
+
from .blocks import Highway
|
|
6
|
+
from .blocks import TextCNN
|
|
7
|
+
from .bst import BST
|
|
8
|
+
from .custom_ops import EditDistance
|
|
9
|
+
from .custom_ops import MappedDotProduct
|
|
10
|
+
from .custom_ops import OverlapFeature
|
|
11
|
+
from .custom_ops import SeqAugmentOps
|
|
12
|
+
from .custom_ops import TextNormalize
|
|
13
|
+
from .data_augment import SeqAugment
|
|
14
|
+
from .din import DIN
|
|
15
|
+
from .embedding import EmbeddingLayer
|
|
16
|
+
from .fibinet import BiLinear
|
|
17
|
+
from .fibinet import FiBiNet
|
|
18
|
+
from .fibinet import SENet
|
|
19
|
+
from .interaction import CIN
|
|
20
|
+
from .interaction import FM
|
|
21
|
+
from .interaction import Cross
|
|
22
|
+
from .interaction import DotInteraction
|
|
23
|
+
from .mask_net import MaskBlock
|
|
24
|
+
from .mask_net import MaskNet
|
|
25
|
+
from .multi_head_attention import MultiHeadAttention
|
|
26
|
+
from .multi_task import AITMTower
|
|
27
|
+
from .multi_task import MMoE
|
|
28
|
+
from .numerical_embedding import AutoDisEmbedding
|
|
29
|
+
from .numerical_embedding import NaryDisEmbedding
|
|
30
|
+
from .numerical_embedding import PeriodicEmbedding
|
|
31
|
+
from .ppnet import PPNet
|
|
32
|
+
from .transformer import TextEncoder
|
|
33
|
+
from .transformer import TransformerBlock
|
|
34
|
+
from .transformer import TransformerEncoder
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
from tensorflow.python.keras.layers import Activation
|
|
5
|
+
from tensorflow.python.keras.layers import Layer
|
|
6
|
+
|
|
7
|
+
import easy_rec.python.utils.activation
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from tensorflow.python.ops.init_ops import Zeros
|
|
11
|
+
except ImportError:
|
|
12
|
+
from tensorflow.python.ops.init_ops_v2 import Zeros
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from tensorflow.python.keras.layers import BatchNormalization
|
|
16
|
+
except ImportError:
|
|
17
|
+
BatchNormalization = tf.keras.layers.BatchNormalization
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
unicode
|
|
21
|
+
except NameError:
|
|
22
|
+
unicode = str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Dice(Layer):
  """The Data Adaptive Activation Function in DIN.

  Dice can be viewed as a generalization of PReLU that adaptively adjusts
  the rectified point according to the distribution of the input data.
  The gate is p = sigmoid(BN(x)), where BN is a batch normalization without
  learnable center/scale. The output is p * x + alpha * (1 - p) * x, where
  alpha is a learnable vector initialised to zeros.

  Input shape
    - Arbitrary. Use the keyword argument `input_shape` (tuple of integers,
      does not include the samples axis) when using this layer as the first
      layer in a model.

  Output shape
    - Same shape as the input.

  Arguments
    - **axis** : Integer, the axis that should be used to compute data
      distribution (typically the features axis).
    - **epsilon** : Small float added to variance to avoid dividing by zero.

  References
    - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through
      rate prediction[C] https://arxiv.org/pdf/1706.06978.pdf
  """

  def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
    self.axis = axis
    self.epsilon = epsilon
    super(Dice, self).__init__(**kwargs)

  def build(self, input_shape):
    # BN without beta/gamma: it only standardizes x ahead of the sigmoid gate.
    self.bn = BatchNormalization(
        axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
    # NOTE(review): alpha is sized from the LAST dimension even when
    # self.axis != -1; confirm this is intended for non-default axes.
    last_dim = input_shape[-1]
    self.alphas = self.add_weight(
        name='dice_alpha',
        shape=(last_dim,),
        initializer=Zeros(),
        dtype=tf.float32)
    super(Dice, self).build(input_shape)
    self.uses_learning_phase = True

  def call(self, inputs, training=None, **kwargs):
    gate = tf.sigmoid(self.bn(inputs, training=training))
    return self.alphas * (1.0 - gate) * inputs + gate * inputs

  def compute_output_shape(self, input_shape):
    # Element-wise activation: the shape is unchanged.
    return input_shape

  @property
  def updates(self):
    # Expose the moving-statistics updates of the inner BN layer.
    return self.bn.updates

  def get_config(self):
    merged = dict(super(Dice, self).get_config())
    merged.update({'axis': self.axis, 'epsilon': self.epsilon})
    return merged
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class MaskedSoftmax(Layer):
  """Softmax that suppresses masked-out positions.

  Where `mask` is 0/False, -1e9 is added to the logits before the softmax,
  so those positions receive (near) zero probability.

  Arguments
    - **axis** : Integer, or a one-element tuple/list holding an integer;
      the axis to normalize over. Defaults to the last axis.
  """

  def __init__(self, axis=-1, **kwargs):
    super(MaskedSoftmax, self).__init__(**kwargs)
    self.axis = axis

  def call(self, inputs, mask=None):
    if mask is not None:
      # 1 - mask is 1 at masked positions -> large negative logit there.
      adder = (1.0 - tf.cast(mask, inputs.dtype)) * -1e9
      inputs += adder
    # Calculate softmax
    axis = self.axis
    if isinstance(axis, (tuple, list)):
      if len(axis) > 1:
        raise ValueError('MaskedSoftmax not support multiple axis')
      if not axis:
        # Previously an IndexError from axis[0]; report it explicitly.
        raise ValueError('MaskedSoftmax requires an axis, got an empty %s' %
                         type(axis).__name__)
      axis = axis[0]
    return tf.nn.softmax(inputs, axis=axis)

  def get_config(self):
    # `axis` is a constructor argument; include it so the layer can be
    # re-created from its config (the base Keras get_config does not know it).
    config = dict(super(MaskedSoftmax, self).get_config())
    config['axis'] = self.axis
    return config
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def activation_layer(activation, name=None):
  """Build a Keras layer that applies `activation`.

  Args:
    activation: one of the following.
      - 'dice' or 'Dice': builds a `Dice` layer.
      - Any other string: resolved with
        `easy_rec.python.utils.activation.get_activation` and wrapped in
        `Activation`.
      - A `Layer` subclass: instantiated with `name=name`.
    name: optional name for the created layer.

  Returns:
    The created layer instance.

  Raises:
    ValueError: if `activation` is none of the above. This includes
      non-class objects such as functions, None, or Layer instances, which
      previously surfaced as a TypeError from issubclass().
  """
  if activation in ('dice', 'Dice'):
    act_layer = Dice(name=name)
  elif isinstance(activation, (str, unicode)):
    act_fn = easy_rec.python.utils.activation.get_activation(activation)
    act_layer = Activation(act_fn, name=name)
  elif isinstance(activation, type) and issubclass(activation, Layer):
    # issubclass() only accepts classes; guard so other objects fall through
    # to the ValueError below.
    act_layer = activation(name=name)
  else:
    # Wrap in a 1-tuple so tuple-valued inputs format correctly.
    raise ValueError(
        'Invalid activation,found %s.You should use a str or a Activation Layer Class.'
        % (activation,))
  return act_layer
|