easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
from easy_rec.python.layers import dnn
|
|
6
|
+
from easy_rec.python.model.multi_task_model import MultiTaskModel
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.protos.simple_multi_task_pb2 import SimpleMultiTask as SimpleMultiTaskConfig # NOQA
|
|
9
|
+
|
|
10
|
+
# Switch to the TF1 compatibility API when running under TensorFlow 2.x or
# newer. The major version is compared numerically because a plain string
# comparison misorders versions (e.g. '10.0' < '2.0' lexicographically).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SimpleMultiTask(MultiTaskModel):
  """Multi-task model with one DNN tower per task over a shared input.

  Every task tower receives the same feature tensor (either the backbone
  output or the 'all' feature group) and ends in its own dense output layer.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config, pick the shared features and set up the towers.

    Args:
      model_config: model config whose 'model' oneof must be
        'simple_multi_task'.
      feature_configs: feature configs, forwarded to the parent class.
      features: input features, forwarded to the parent class.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(SimpleMultiTask, self).__init__(model_config, feature_configs,
                                          features, labels, is_training)

    selected_model = self._model_config.WhichOneof('model')
    assert selected_model == 'simple_multi_task', \
        'invalid model config: %s' % selected_model
    # From here on self._model_config refers to the model specific sub-config.
    self._model_config = self._model_config.simple_multi_task
    assert isinstance(self._model_config, SimpleMultiTaskConfig)

    if not self.has_backbone:
      self._features, _ = self._input_layer(self._feature_dict, 'all')
    else:
      self._features = self.backbone
    self._init_towers(self._model_config.task_towers)

  def build_predict_graph(self):
    """Build one DNN plus dense output per task tower.

    Returns:
      The prediction dict, after the per-tower outputs (keyed by tower name)
      have been passed to `_add_to_prediction_dict`.
    """
    outputs_by_tower = {}
    for tower_idx, tower_cfg in enumerate(self._task_towers):
      tower_name = tower_cfg.tower_name
      tower_dnn = dnn.DNN(
          tower_cfg.dnn,
          self._l2_reg,
          name=tower_name,
          is_training=self._is_training)
      tower_hidden = tower_dnn(self._features)
      # The layer name is indexed by tower position, not by tower name.
      outputs_by_tower[tower_name] = tf.layers.dense(
          inputs=tower_hidden,
          units=tower_cfg.num_class,
          kernel_regularizer=self._l2_reg,
          name='dnn_output_%d' % tower_idx)

    self._add_to_prediction_dict(outputs_by_tower)
    return self._prediction_dict
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
from easy_rec.python.layers import dnn
|
|
6
|
+
from easy_rec.python.layers import uniter
|
|
7
|
+
from easy_rec.python.model.rank_model import RankModel
|
|
8
|
+
|
|
9
|
+
from easy_rec.python.protos.uniter_pb2 import Uniter as UNITERConfig # NOQA
|
|
10
|
+
|
|
11
|
+
# Switch to the TF1 compatibility API when running under TensorFlow 2.x or
# newer. The major version is compared numerically because a plain string
# comparison misorders versions (e.g. '10.0' < '2.0' lexicographically).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Uniter(RankModel):
  """UNITER: UNiversal Image-TExt Representation Learning.

  See the original paper:
  https://arxiv.org/abs/1909.11740
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and create the UNITER layer.

    Args:
      model_config: model config whose 'model' oneof must be 'uniter'.
      feature_configs: feature configs, forwarded to the parent class and to
        the UNITER layer.
      features: input features, forwarded to the parent class and to the
        UNITER layer.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(Uniter, self).__init__(model_config, feature_configs, features,
                                 labels, is_training)
    selected_model = self._model_config.WhichOneof('model')
    assert selected_model == 'uniter', (
        'invalid model config: %s' % selected_model)

    # The layer must be created before self._model_config is narrowed to the
    # uniter sub-config, because it reads `.uniter.config` from the full one.
    layer_config = self._model_config.uniter.config
    self._uniter_layer = uniter.Uniter(model_config, feature_configs, features,
                                       layer_config, self._input_layer)
    self._model_config = self._model_config.uniter

  def build_predict_graph(self):
    """Run the UNITER layer, the final DNN and the dense output layer.

    Returns:
      The prediction dict, after the output tensor has been passed to
      `_add_to_prediction_dict`.
    """
    uniter_hidden = self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
    head_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg, 'final_dnn',
                       self._is_training)
    head_features = head_dnn(uniter_hidden)

    logits = tf.layers.dense(head_features, self._num_class, name='output')
    self._add_to_prediction_dict(logits)
    return self._prediction_dict
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
|
|
7
|
+
from easy_rec.python.layers import dnn
|
|
8
|
+
from easy_rec.python.model.rank_model import RankModel
|
|
9
|
+
|
|
10
|
+
from easy_rec.python.protos.wide_and_deep_pb2 import WideAndDeep as WideAndDeepConfig # NOQA
|
|
11
|
+
|
|
12
|
+
# Switch to the TF1 compatibility API when running under TensorFlow 2.x or
# newer. The major version is compared numerically because a plain string
# comparison misorders versions (e.g. '10.0' < '2.0' lexicographically).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class WideAndDeep(RankModel):
  """Wide & Deep ranking model.

  The 'wide' feature group is summed into a single tensor and the 'deep'
  feature group is fed through a DNN. The two parts are either concatenated
  and passed through a final DNN (when `final_dnn` has hidden units) or the
  deep logits are added directly to the wide tensor.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and fetch the wide and deep feature groups.

    Args:
      model_config: model config whose 'model' oneof must be 'wide_and_deep'.
      feature_configs: feature configs, forwarded to the parent class.
      features: input features, forwarded to the parent class.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(WideAndDeep, self).__init__(model_config, feature_configs, features,
                                      labels, is_training)
    assert model_config.WhichOneof('model') == 'wide_and_deep', \
        'invalid model config: %s' % model_config.WhichOneof('model')
    self._model_config = model_config.wide_and_deep
    assert isinstance(self._model_config, WideAndDeepConfig)
    assert self._input_layer.has_group('wide')
    _, self._wide_features = self._input_layer(self._feature_dict, 'wide')
    assert self._input_layer.has_group('deep')
    _, self._deep_features = self._input_layer(self._feature_dict, 'deep')

  def build_input_layer(self, model_config, feature_configs):
    """Create the input layer, forcing the wide output dim when needed.

    Without a final DNN the wide output is added directly to the deep logits
    (see build_predict_graph), so `wide_output_dim` is overwritten with
    `num_class` in the config before the parent builds the input layer.

    Args:
      model_config: full model config; `wide_and_deep.wide_output_dim` may be
        modified in place.
      feature_configs: feature configs, forwarded to the parent class.
    """
    has_final = len(model_config.wide_and_deep.final_dnn.hidden_units) > 0
    self._wide_output_dim = model_config.wide_and_deep.wide_output_dim
    if not has_final:
      model_config.wide_and_deep.wide_output_dim = model_config.num_class
      self._wide_output_dim = model_config.num_class
    super(WideAndDeep, self).build_input_layer(model_config, feature_configs)

  def build_predict_graph(self):
    """Build the wide part, the deep DNN and the combined output.

    Returns:
      The prediction dict, after the output tensor has been passed to
      `_add_to_prediction_dict`.
    """
    wide_fea = tf.add_n(self._wide_features)
    logging.info('wide features dimension: %d' % wide_fea.get_shape()[-1])

    self._deep_features = tf.concat(self._deep_features, axis=1)
    logging.info('input deep features dimension: %d' %
                 self._deep_features.get_shape()[-1])

    deep_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'deep_feature',
                         self._is_training)
    deep_fea = deep_layer(self._deep_features)
    logging.info('output deep features dimension: %d' %
                 deep_fea.get_shape()[-1])

    has_final = len(self._model_config.final_dnn.hidden_units) > 0
    # Reported through logging (previously a bare print) so it follows the
    # same channel and verbosity settings as the other messages here.
    logging.info('wide_deep has_final_dnn layers = %d' % has_final)
    if has_final:
      all_fea = tf.concat([wide_fea, deep_fea], axis=1)
      final_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                            'final_dnn', self._is_training)
      all_fea = final_layer(all_fea)
      output = tf.layers.dense(
          all_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='output')
    else:
      deep_out = tf.layers.dense(
          deep_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='deep_out')
      output = deep_out + wide_fea

    self._add_to_prediction_dict(output)

    return self._prediction_dict

  @staticmethod
  def _is_wide_var(var_name):
    """Return True if var_name starts with 'input_layer' but not 'input_layer_1'."""
    return var_name.startswith('input_layer') and \
        not var_name.startswith('input_layer_1')

  def get_grouped_vars(self, opt_num):
    """Group the vars into different optimization groups.

    Each group will be optimized by a separate optimizer.

    Args:
      opt_num: number of optimizers from easyrec config.

    Return:
      list of list of variables: [wide, others] for opt_num == 2 and
      [wide, embeddings, others] for opt_num == 3. For any smaller opt_num
      None is returned.
    """
    assert opt_num <= 3, 'could only support 2 or 3 optimizers, ' + \
        'if opt_num = 2, one for the wide , and one for the others, ' + \
        'if opt_num = 3, one for the wide, second for the deep embeddings, ' + \
        'and third for the other layers.'

    if opt_num not in (2, 3):
      # NOTE(review): the assertion message says only 2 or 3 optimizers are
      # supported, yet smaller values pass the check. Returning None keeps the
      # previous behavior; confirm with callers whether this should raise.
      return None

    wide_vars = []
    embedding_vars = []
    deep_vars = []
    for tmp_var in tf.trainable_variables():
      var_name = tmp_var.name
      if self._is_wide_var(var_name):
        wide_vars.append(tmp_var)
      elif opt_num == 3 and (var_name.startswith('input_layer') or
                             '/embedding_weights' in var_name):
        # Embeddings only get their own group when there are 3 optimizers.
        embedding_vars.append(tmp_var)
      else:
        deep_vars.append(tmp_var)

    if opt_num == 2:
      return [wide_vars, deep_vars]
    return [wide_vars, embedding_vars, deep_vars]
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
File without changes
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Python wrappers around TensorFlow ops.
|
|
2
|
+
|
|
3
|
+
This file is MACHINE GENERATED! Do not edit.
|
|
4
|
+
Original C++ source file: kafka_ops_deprecated.cc
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import traceback
|
|
10
|
+
|
|
11
|
+
import six as _six
|
|
12
|
+
import tensorflow as tf
|
|
13
|
+
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
|
14
|
+
from tensorflow.python.eager import context as _context
|
|
15
|
+
from tensorflow.python.eager import core as _core
|
|
16
|
+
from tensorflow.python.eager import execute as _execute
|
|
17
|
+
# Needed to trigger the call to _set_call_cpp_shape_fn.
|
|
18
|
+
from tensorflow.python.framework import dtypes as _dtypes
|
|
19
|
+
from tensorflow.python.framework import ops as _ops
|
|
20
|
+
from tensorflow.python.util.tf_export import tf_export
|
|
21
|
+
|
|
22
|
+
import easy_rec
|
|
23
|
+
|
|
24
|
+
# Handle to the dynamically loaded Kafka op library. It stays None when
# easy_rec.ops_dir is unset, kafka.so does not exist there, or loading fails.
kafka_module = None
if easy_rec.ops_dir is not None:
  kafka_ops_path = os.path.join(easy_rec.ops_dir, 'kafka.so')
  if os.path.exists(kafka_ops_path):
    try:
      kafka_module = tf.load_op_library(kafka_ops_path)
    except Exception:
      # Best effort: a load failure is only logged, so importing this module
      # still succeeds and kafka_module remains None.
      logging.warning('load %s failed: %s' %
                      (kafka_ops_path, traceback.format_exc()))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@tf_export('io_kafka_dataset_v2')
def io_kafka_dataset_v2(topics,
                        servers,
                        group,
                        eof,
                        timeout,
                        config_global,
                        config_topic,
                        message_key,
                        message_offset,
                        name=None):
  """Create a dataset emitting the messages of one or more Kafka topics.

  Thin wrapper that forwards every argument by keyword to the
  `io_kafka_dataset_v2` op of the loaded `kafka_module`.

  Args:
    topics: A `Tensor` of type `string`.
      One or more subscriptions, each formatted as [topic:partition:offset].
    servers: A `Tensor` of type `string`. A list of bootstrap servers.
    group: A `Tensor` of type `string`. The consumer group id.
    eof: A `Tensor` of type `bool`.
      If True, the kafka reader will stop on EOF.
    timeout: A `Tensor` of type `int64`.
      How long the Kafka consumer waits, in milliseconds.
    config_global: A `Tensor` of type `string`.
      Global configuration properties in [Key=Value] format,
      eg. ["enable.auto.commit=false", "heartbeat.interval.ms=2000"];
      see 'Global configuration properties' in the librdkafka doc.
    config_topic: A `Tensor` of type `string`.
      Topic configuration properties in [Key=Value] format,
      eg. ["auto.offset.reset=earliest"];
      see 'Topic configuration properties' in the librdkafka doc.
    message_key: A `Tensor` of type `bool`.
    message_offset: A `Tensor` of type `bool`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `variant`.
  """
  dataset_op = kafka_module.io_kafka_dataset_v2
  op_kwargs = dict(
      topics=topics,
      servers=servers,
      group=group,
      eof=eof,
      timeout=timeout,
      config_global=config_global,
      config_topic=config_topic,
      message_key=message_key,
      message_offset=message_offset,
      name=name)
  return dataset_op(**op_kwargs)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def io_kafka_dataset_eager_fallback(topics,
                                    servers,
                                    group,
                                    eof,
                                    timeout,
                                    config_global,
                                    config_topic,
                                    message_key,
                                    message_offset,
                                    name=None,
                                    ctx=None):
  """This is the slowpath function for Eager mode.

  This is for function io_kafka_dataset. Each argument is converted to a
  tensor of its expected dtype, then the 'IOKafkaDataset' op is executed
  with a single requested output.

  Args:
    topics: converted to a `string` tensor.
    servers: converted to a `string` tensor.
    group: converted to a `string` tensor.
    eof: converted to a `bool` tensor.
    timeout: converted to an `int64` tensor.
    config_global: converted to a `string` tensor.
    config_topic: converted to a `string` tensor.
    message_key: converted to a `bool` tensor.
    message_offset: converted to a `bool` tensor.
    name: A name for the operation (optional).
    ctx: eager context to execute in; when falsy, `_context.context()` is
      used.

  Returns:
    The single output of the executed op.
  """
  eager_ctx = ctx if ctx else _context.context()
  as_tensor = _ops.convert_to_tensor
  # The conversions run in the same order as the op's input list.
  flat_inputs = [
      as_tensor(topics, _dtypes.string),
      as_tensor(servers, _dtypes.string),
      as_tensor(group, _dtypes.string),
      as_tensor(eof, _dtypes.bool),
      as_tensor(timeout, _dtypes.int64),
      as_tensor(config_global, _dtypes.string),
      as_tensor(config_topic, _dtypes.string),
      as_tensor(message_key, _dtypes.bool),
      as_tensor(message_offset, _dtypes.bool),
  ]
  op_attrs = None
  outputs = _execute.execute(
      b'IOKafkaDataset',
      1,
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=eager_ctx,
      name=name)
  _execute.record_gradient('IOKafkaDataset', flat_inputs, op_attrs, outputs,
                           name)
  (dataset_variant,) = outputs
  return dataset_variant
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@tf_export('io_write_kafka_v2')
def io_write_kafka_v2(message, topic, servers, name=None):
  r"""Invoke the kafka write op, in graph mode or through the eager fast path.

  With no eager context active, the op is created through
  `kafka_module.io_write_kafka_v2` and its single output is returned.
  Otherwise the 'IOWriteKafka' op goes through the eager fast path, falling
  back to `io_write_kafka_eager_fallback` when the fast path is unavailable.
  NOTE(review): presumably this publishes `message` to `topic` on `servers`;
  confirm against the C++ op definition.

  Args:
    message: A `Tensor` of type `string`.
    topic: A `Tensor` of type `string`.
    servers: A `Tensor` of type `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`.
  """
  eager_ctx = _context._context
  in_graph_mode = eager_ctx is None or not eager_ctx._eager_context.is_eager
  if in_graph_mode:
    write_op = kafka_module.io_write_kafka_v2(
        message=message, topic=topic, servers=servers, name=name)
    outputs = write_op.outputs[:]
    _execute.record_gradient('IOWriteKafka', write_op.inputs, None, outputs,
                             name)
    (written,) = outputs
    return written

  try:
    return _pywrap_tensorflow.TFE_Py_FastPathExecute(
        eager_ctx._context_handle, eager_ctx._eager_context.device_name,
        'IOWriteKafka', name, eager_ctx._post_execution_callbacks, message,
        topic, servers)
  except _core._FallbackException:
    return io_write_kafka_eager_fallback(
        message, topic, servers, name=name, ctx=eager_ctx)
  except _core._NotOkStatusException as e:
    # Append the op name to the status message when one was given.
    detail = e.message if name is None else e.message + ' name: ' + name
    _six.raise_from(_core._status_to_exception(e.code, detail), None)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def io_write_kafka_eager_fallback(message, topic, servers, name=None, ctx=None):
  """Slow-path eager execution of the `IOWriteKafka` op.

  Companion of `io_write_kafka`. The three inputs are converted to `string`
  tensors and the op is run through `_execute.execute`.

  Args:
    message: converted to a `string` tensor.
    topic: converted to a `string` tensor.
    servers: converted to a `string` tensor.
    name: optional name forwarded to the op execution.
    ctx: eager context to run in; when falsy, `_context.context()` is used.

  Returns:
    The single output produced by the `IOWriteKafka` op.
  """
  eager_ctx = ctx or _context.context()

  # Convert in positional order: message, topic, servers.
  flat_inputs = [
      _ops.convert_to_tensor(value, _dtypes.string)
      for value in (message, topic, servers)
  ]
  op_attrs = None  # no attrs are passed for this op

  # The literal 1 is the number of outputs requested from the op.
  outputs = _execute.execute(
      b'IOWriteKafka',
      1,
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=eager_ctx,
      name=name)
  _execute.record_gradient('IOWriteKafka', flat_inputs, op_attrs, outputs, name)
  (output,) = outputs
  return output
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
from tensorflow.python.ops import string_ops
|
|
8
|
+
|
|
9
|
+
import easy_rec
|
|
10
|
+
from easy_rec.python.utils import constant
|
|
11
|
+
|
|
12
|
+
# Best-effort load of the AVX string-split custom op. Any failure is logged
# and leaves `str_avx_op` as None so callers can fall back to the stock op.
try:
  str_avx_op_path = os.path.join(easy_rec.ops_dir, 'libstr_avx_op.so')
  str_avx_op = tf.load_op_library(str_avx_op_path)
except Exception as ex:
  str_avx_op = None
  logging.warning('load avx string_split op failed: %s' % str(ex))
else:
  # Only reached when the library loaded without raising.
  logging.info('load avx string_split op from %s succeed' % str_avx_op_path)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def str_split_by_chr(input_str, sep, skip_empty):
  """Split `input_str` by `sep`, using the AVX op when it is usable.

  The AVX op is used only when `constant.has_avx_str_split()` is truthy and
  the custom op library was loaded; otherwise the call is forwarded to
  `string_ops.string_split`.

  Args:
    input_str: value forwarded as the first argument of the split op.
    sep: separator; must have length 1 on the AVX path.
    skip_empty: forwarded as the `skip_empty` keyword of the split op.

  Returns:
    Whatever the selected split op returns.

  Raises:
    AssertionError: on the AVX path when `len(sep) != 1`.
  """
  avx_usable = constant.has_avx_str_split() and str_avx_op is not None
  if not avx_usable:
    return string_ops.string_split(input_str, sep, skip_empty=skip_empty)

  # NOTE(review): this check is an `assert`, so it is skipped when Python
  # runs with -O.
  sep_len = len(sep)
  assert sep_len == 1, \
      'invalid data_config.separator(%s) len(%d) != 1' % (sep, sep_len)
  return str_avx_op.avx512_string_split(input_str, sep, skip_empty=skip_empty)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
import easy_rec
|
|
9
|
+
|
|
10
|
+
# Best-effort load of the incr_record custom-op library. On any failure all
# three exported names are set to None and a warning is logged.
try:
  op_path = os.path.join(easy_rec.ops_dir, 'incr_record.so')
  op = tf.load_op_library(op_path)
  get_sparse_indices = op.get_sparse_indices
  set_sparse_indices = op.set_sparse_indices
  # The library may not define this op; default to None when it is absent.
  kv_resource_incr_gather = getattr(op, 'kv_resource_incr_gather', None)
except Exception as ex:
  # A single handler suffices: ImportError is a subclass of Exception and the
  # former dedicated ImportError branch had an identical body.
  get_sparse_indices = None
  set_sparse_indices = None
  kv_resource_incr_gather = None
  # NOTE(review): the message names gen_io_ops.collect_sparse_indices although
  # this block loads incr_record.so -- confirm before rewording it.
  logging.warning('failed to import gen_io_ops.collect_sparse_indices: %s' %
                  str(ex))
|