easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,1064 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from abc import abstractmethod
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
|
|
8
|
+
import six
|
|
9
|
+
import tensorflow as tf
|
|
10
|
+
from tensorflow.python.framework import ops
|
|
11
|
+
from tensorflow.python.ops import array_ops
|
|
12
|
+
from tensorflow.python.ops import sparse_ops
|
|
13
|
+
from tensorflow.python.ops import string_ops
|
|
14
|
+
from tensorflow.python.platform import gfile
|
|
15
|
+
|
|
16
|
+
from easy_rec.python.core import sampler as sampler_lib
|
|
17
|
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
|
|
18
|
+
from easy_rec.python.utils import conditional
|
|
19
|
+
from easy_rec.python.utils import config_util
|
|
20
|
+
from easy_rec.python.utils import constant
|
|
21
|
+
from easy_rec.python.utils.check_utils import check_split
|
|
22
|
+
from easy_rec.python.utils.check_utils import check_string_to_number
|
|
23
|
+
from easy_rec.python.utils.expr_util import get_expression
|
|
24
|
+
from easy_rec.python.utils.input_utils import get_type_defaults
|
|
25
|
+
from easy_rec.python.utils.load_class import get_register_class_meta
|
|
26
|
+
from easy_rec.python.utils.load_class import load_by_path
|
|
27
|
+
from easy_rec.python.utils.tf_utils import get_tf_type
|
|
28
|
+
|
|
29
|
+
# Switch to the TF1-style API when TensorFlow 2.x or later is installed.
# The major version is compared numerically: comparing version strings
# lexicographically would misorder e.g. '10.0' and '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1

# Dict handed to get_register_class_meta together with the metaclass it
# returns; presumably it is filled with Input subclasses as they are
# defined -- confirm in easy_rec.python.utils.load_class.
_INPUT_CLASS_MAP = {}
_meta_type = get_register_class_meta(_INPUT_CLASS_MAP, have_abstract_class=True)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Input(six.with_metaclass(_meta_type, object)):
|
|
37
|
+
|
|
38
|
+
DATA_OFFSET = 'DATA_OFFSET'
|
|
39
|
+
|
|
40
|
+
def __init__(self,
             data_config,
             feature_configs,
             input_path,
             task_index=0,
             task_num=1,
             check_mode=False,
             pipeline_config=None,
             **kwargs):
    """Store the configs and work out which input fields are actually used.

    Args:
      data_config: dataset config; its input_fields, label_fields,
        feature_fields, label_sep, label_dim, batch_size, prefetch_size and
        optional sample_weight are read here. When
        auto_expand_input_fields is set, its input_fields are rewritten
        in place with the expanded names.
      feature_configs: iterable of feature configs; every input name they
        reference must be declared in data_config.input_fields.
      input_path: stored as-is; a sampler is built only when it is not None.
      task_index: stored on the instance.
      task_num: stored on the instance.
      check_mode: stored on the instance and logged.
      pipeline_config: optional pipeline config; when given, its eval and
        multi-task tower metrics are scanned for gauc / session_auc fields.
      **kwargs: accepted but not used by this constructor.

    Raises:
      AssertionError: if a feature config references an undeclared input.
      ValueError: if an effective field or a label field is missing from
        data_config.input_fields (raised by list.index).
    """
    self._pipeline_config = pipeline_config
    self._data_config = data_config
    self._check_mode = check_mode
    logging.info('check_mode: %s ' % self._check_mode)
    # tf.estimator.ModeKeys.*, only available before
    # calling self._build
    self._mode = None
    self._has_ev = bool(
        pipeline_config is not None and
        pipeline_config.model_config.HasField('ev_params'))

    if self._data_config.auto_expand_input_fields:
        # Snapshot the declared fields, empty the repeated field, then
        # re-add one copy per expanded name.
        input_fields = [x for x in self._data_config.input_fields]
        while len(self._data_config.input_fields) > 0:
            self._data_config.input_fields.pop()
        for field in input_fields:
            for tmp_name in config_util.auto_expand_names(field.input_name):
                one_field = DatasetConfig.Field()
                one_field.CopyFrom(field)
                one_field.input_name = tmp_name
                self._data_config.input_fields.append(one_field)

    self._input_fields = [x.input_name for x in data_config.input_fields]
    self._input_dims = [x.input_dim for x in data_config.input_fields]
    self._input_field_types = [x.input_type for x in data_config.input_fields]
    self._input_field_defaults = [
        x.default_val for x in data_config.input_fields
    ]
    self._label_fields = list(data_config.label_fields)
    self._feature_fields = list(data_config.feature_fields)
    self._label_sep = list(data_config.label_sep)
    self._label_dim = list(data_config.label_dim)
    # Labels without an explicit dimension default to 1; a non-positive
    # count extends by nothing.
    self._label_dim.extend(
        [1] * (len(self._label_fields) - len(self._label_dim)))

    self._label_udf_map = {}
    for config in self._data_config.input_fields:
        if config.HasField('user_define_fn'):
            self._label_udf_map[config.input_name] = self._load_label_fn(config)

    self._batch_size = data_config.batch_size
    self._prefetch_size = data_config.prefetch_size
    self._feature_configs = list(feature_configs)
    self._task_index = task_index
    self._task_num = task_num

    self._input_path = input_path

    # find out effective fields
    self._effective_fields = []

    # for multi value inputs, the types maybe different
    # from the types defined in input_fields
    # it is used in create_multi_placeholders
    self._multi_value_types = {}
    self._multi_value_fields = set()

    self._normalizer_fn = {}
    for fc in self._feature_configs:
        for input_name in fc.input_names:
            assert input_name in self._input_fields, 'invalid input_name in %s' % str(
                fc)
            if input_name not in self._effective_fields:
                self._effective_fields.append(input_name)

        if fc.feature_type in [fc.TagFeature, fc.SequenceFeature]:
            primary = fc.input_names[0]
            uses_string_ids = (
                fc.hash_bucket_size > 0 or len(fc.vocab_list) > 0 or
                fc.HasField('vocab_file'))
            self._multi_value_types[primary] = (
                tf.string if uses_string_ids else tf.int64)
            self._multi_value_fields.add(primary)
            if len(fc.input_names) > 1:
                self._multi_value_types[fc.input_names[1]] = tf.float32
                self._multi_value_fields.add(fc.input_names[1])

        if fc.feature_type == fc.RawFeature and fc.raw_input_dim > 1:
            self._multi_value_types[fc.input_names[0]] = tf.float32
            self._multi_value_fields.add(fc.input_names[0])

        if fc.HasField('normalizer_fn'):
            feature_name = fc.feature_name if fc.HasField(
                'feature_name') else fc.input_names[0]
            self._normalizer_fn[feature_name] = load_by_path(fc.normalizer_fn)

    # add sample weight to effective fields
    if self._data_config.HasField('sample_weight'):
        self._effective_fields.append(self._data_config.sample_weight)

    def _add_metric_fields(metrics_set):
        # Register the uid field of gauc and the session id field of
        # session_auc metrics as effective fields.
        for metric in metrics_set:
            metric_name = metric.WhichOneof('metric')
            if metric_name == 'gauc':
                field_name = metric.gauc.uid_field
            elif metric_name == 'session_auc':
                field_name = metric.session_auc.session_id_field
            else:
                continue
            if field_name not in self._effective_fields:
                self._effective_fields.append(field_name)

    # add uid_field of GAUC and session_fields of SessionAUC
    if self._pipeline_config is not None:
        _add_metric_fields(self._pipeline_config.eval_config.metrics_set)

        # check multi task model's metrics
        model_config = self._pipeline_config.model_config
        model_name = model_config.WhichOneof('model')
        if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
            model = getattr(model_config, model_name)
            towers = [model.ctr_tower, model.cvr_tower
                     ] if model_name == 'esmm' else model.task_towers
            for tower in towers:
                _add_metric_fields(tower.metrics_set)

    # Deduplicate and sort fids from small to large. sorted() is required:
    # set iteration order is unspecified, and create_placeholders uses the
    # position in this list as the column index of each input.
    self._effective_fids = sorted(
        set(self._input_fields.index(x) for x in self._effective_fields))
    self._effective_fields = [
        self._input_fields[x] for x in self._effective_fids
    ]

    self._label_fids = [self._input_fields.index(x) for x in self._label_fields]

    # virtual fields generated by self._preprocess
    # which will be inputs to feature columns
    self._appended_fields = []

    # sampler
    self._sampler = None
    if input_path is not None:
        # build sampler only when train and eval
        self._sampler = sampler_lib.build(data_config)

    self.get_type_defaults = get_type_defaults
|
|
199
|
+
|
|
200
|
+
def _load_label_fn(self, config):
    """Load the user-defined function configured on an input field.

    Args:
      config: input field config with user_define_fn set; the optional
        user_define_fn_path and user_define_fn_res_type are honoured.

    Returns:
      A tuple (object returned by load_by_path, the dotted path that was
      loaded, user_define_fn_res_type or None when it is not set).
    """
    udf_class = config.user_define_fn
    udf_path = config.user_define_fn_path if config.HasField(
        'user_define_fn_path') else None
    dtype = config.user_define_fn_res_type if config.HasField(
        'user_define_fn_res_type') else None

    if udf_path:
        if udf_path.startswith(('oss://', 'hdfs://')):
            # Remote file: copy it under /udf/ so it can be loaded locally.
            # NOTE(review): this loads code fetched from a remote path;
            # confirm the path always comes from a trusted config.
            with gfile.GFile(udf_path, 'r') as fin:
                udf_content = fin.read()
            final_udf_tmp_path = '/udf/'
            final_udf_path = final_udf_tmp_path + udf_path.split('/')[-1]
            logging.info('final udf path %s' % final_udf_path)
            logging.info('udf content: %s' % udf_content)
            if not gfile.Exists(final_udf_tmp_path):
                gfile.MkDir(final_udf_tmp_path)
            with gfile.GFile(final_udf_path, 'w') as fout:
                fout.write(udf_content)
        else:
            final_udf_path = udf_path
        # Strip the '.py' suffix only when it is present; slicing three
        # characters unconditionally corrupts paths given without it.
        if final_udf_path.endswith('.py'):
            final_udf_path = final_udf_path[:-3]
        final_udf_path = final_udf_path.replace('/', '.')
        udf_class = final_udf_path + '.' + udf_class
    logging.info('apply udf %s' % udf_class)
    return load_by_path(udf_class), udf_class, dtype
|
|
225
|
+
|
|
226
|
+
@property
def num_epochs(self):
    """Configured number of epochs, or None when it is not positive."""
    epochs = self._data_config.num_epochs
    return epochs if epochs > 0 else None
|
|
232
|
+
|
|
233
|
+
def get_feature_input_fields(self):
    """Return input field names that are neither labels nor the sample weight.

    The order of self._input_fields is preserved.
    """
    feature_inputs = []
    for name in self._input_fields:
        if name in self._label_fields:
            continue
        if name == self._data_config.sample_weight:
            continue
        feature_inputs.append(name)
    return feature_inputs
|
|
238
|
+
|
|
239
|
+
def should_stop(self, curr_epoch):
    """Tell whether curr_epoch has reached the epoch budget.

    In TRAIN mode the budget is self.num_epochs (None means never stop);
    in any other mode the budget is a single epoch.
    """
    if self._mode == tf.estimator.ModeKeys.TRAIN:
        total_epoch = self.num_epochs
    else:
        total_epoch = 1
    if total_epoch is None:
        return False
    return curr_epoch >= total_epoch
|
|
245
|
+
|
|
246
|
+
def create_multi_placeholders(self, export_config):
    """Create multiple placeholders on export, one for each feature.

    Sets self._mode to PREDICT as a side effect.

    Args:
      export_config: ExportConfig instance; auto_multi_value,
        multi_value_fields, placeholder_named_by_input and filter_inputs
        are read.

    Returns:
      A tuple (inputs, features): inputs maps each input name to its
      placeholder, features is the 'feature' entry returned by
      self._preprocess.
    """
    self._mode = tf.estimator.ModeKeys.PREDICT

    if export_config.auto_multi_value:
        export_fields_name = self._multi_value_fields
    elif export_config.multi_value_fields:
        export_fields_name = export_config.multi_value_fields.input_name
    else:
        # Nothing is exported as multi-value. An empty tuple keeps the
        # membership test below valid; None would raise TypeError there.
        export_fields_name = ()
    placeholder_named_by_input = export_config.placeholder_named_by_input

    sample_weight_field = ''
    if self._data_config.HasField('sample_weight'):
        sample_weight_field = self._data_config.sample_weight

    if export_config.filter_inputs:
        effective_fids = list(self._effective_fids)
    else:
        # Keep every declared input except labels and the sample weight.
        effective_fids = [
            fid for fid, field_name in enumerate(self._input_fields)
            if field_name not in self._label_fields and
            field_name != sample_weight_field
        ]

    inputs = {}
    for fid in effective_fids:
        input_name = self._input_fields[fid]
        # The sample weight can still appear here when filter_inputs is set.
        if input_name == sample_weight_field:
            continue
        if placeholder_named_by_input:
            placeholder_name = input_name
        else:
            placeholder_name = 'input_%d' % fid
        if input_name in export_fields_name:
            # Multi-value inputs get a rank-2 placeholder; the dtype set up
            # in __init__ overrides the declared input type when present.
            tf_type = self._multi_value_types[input_name] if input_name in self._multi_value_types \
                else get_tf_type(self._input_field_types[fid])
            logging.info('multi value input_name: %s, dtype: %s' %
                         (input_name, tf_type))
            finput = array_ops.placeholder(
                tf_type, [None, None], name=placeholder_name)
        else:
            ftype = self._input_field_types[fid]
            tf_type = get_tf_type(ftype)
            logging.info('input_name: %s, dtype: %s' % (input_name, tf_type))
            finput = array_ops.placeholder(tf_type, [None], name=placeholder_name)
        inputs[input_name] = finput
    # Pass a shallow copy to _preprocess and return the original dict.
    features = dict(inputs)
    features = self._preprocess(features)
    return inputs, features['feature']
|
|
300
|
+
|
|
301
|
+
def create_placeholders(self, export_config):
    """Create one string placeholder and split it into per-field features.

    Each element of the 'features' placeholder is split on the configured
    separator and the values are reshaped to one column per effective
    input. Sets self._mode to PREDICT as a side effect.

    Args:
      export_config: export config; only filter_inputs is read.

    Returns:
      A tuple ({'features': placeholder}, features) where features is the
      'feature' entry returned by self._preprocess.
    """
    self._mode = tf.estimator.ModeKeys.PREDICT
    inputs_placeholder = array_ops.placeholder(
        tf.string, [None], name='features')
    split_result = tf.string_split(
        inputs_placeholder, self._data_config.separator, skip_empty=False)
    input_vals = split_result.values

    if self._data_config.HasField('sample_weight'):
        sample_weight_field = self._data_config.sample_weight
    else:
        sample_weight_field = ''

    if not export_config.filter_inputs:
        effective_fids = []
        for fid in range(len(self._input_fields)):
            field_name = self._input_fields[fid]
            if field_name in self._label_fields:
                continue
            if field_name == sample_weight_field:
                continue
            effective_fids.append(fid)
        logging.info(
            'will not filter any input[except labels], total number inputs:%d' %
            len(effective_fids))
    else:
        effective_fids = list(self._effective_fids)
        logging.info('number of effective inputs:%d, total number inputs: %d' %
                     (len(effective_fids), len(self._input_fields)))

    input_vals = tf.reshape(
        input_vals, [-1, len(effective_fids)], name='input_reshape')

    numeric_types = [tf.float32, tf.double, tf.int32, tf.int64]
    features = {}
    for column_id, fid in enumerate(effective_fids):
        ftype = self._input_field_types[fid]
        tf_type = get_tf_type(ftype)
        input_name = self._input_fields[fid]
        if tf_type in numeric_types:
            # Numeric fields arrive as strings and are parsed to their type.
            features[input_name] = tf.string_to_number(
                input_vals[:, column_id],
                tf_type,
                name='input_str_to_%s' % tf_type.name)
            continue
        if ftype != DatasetConfig.STRING:
            logging.warning('unexpected field type: ftype=%s tf_type=%s' %
                            (ftype, tf_type))
        features[input_name] = input_vals[:, column_id]

    features = self._preprocess(features)
    return {'features': inputs_placeholder}, features['feature']
|
|
345
|
+
|
|
346
|
+
def _get_features(self, fields):
|
|
347
|
+
return fields['feature']
|
|
348
|
+
|
|
349
|
+
def _get_labels(self, fields):
|
|
350
|
+
labels = fields['label']
|
|
351
|
+
return OrderedDict([
|
|
352
|
+
(x, tf.squeeze(labels[x], axis=1) if len(labels[x].get_shape()) == 2 and
|
|
353
|
+
labels[x].get_shape()[1] == 1 else labels[x]) for x in labels
|
|
354
|
+
])
|
|
355
|
+
|
|
356
|
+
def _as_string(self, field, fc):
  """Convert ``field`` to a string tensor.

  String tensors are returned as they are. For float32/double tensors the
  feature config must carry a positive ``precision`` (checked by an assert),
  which is then used as the number of decimal digits; for every other dtype
  the precision is left as ``None``.

  Args:
    field: tensor to convert.
    fc: feature config; ``precision``, ``feature_name`` and ``input_names``
      are read.

  Returns:
    ``field`` itself when it is already a string tensor, otherwise the
    result of ``as_string``.
  """
  if field.dtype == tf.string:
    return field

  precision = None
  if field.dtype in [tf.float32, tf.double]:
    if fc.HasField('feature_name'):
      feature_name = fc.feature_name
    else:
      feature_name = fc.input_names[0]
    assert fc.precision > 0, 'fc.precision not set for feature[%s], it is dangerous to convert ' \
        'float or double to string due to precision problem, it is suggested ' \
        ' to convert them into string format before using EasyRec; ' \
        'if you really need to do so, please set precision (the number of ' \
        'decimal digits) carefully.' % feature_name
    # kept as a separate test so that the behavior is the same when asserts
    # are stripped (python -O)
    if fc.precision > 0:
      precision = fc.precision

  # convert to string, using tf.strings.as_string when this tf build has it
  if 'as_string' in dir(tf.strings):
    to_string = tf.strings.as_string
  else:
    to_string = tf.as_string
  return to_string(field, precision=precision)
|
|
377
|
+
|
|
378
|
+
def _parse_combo_feature(self, fc, parsed_dict, field_dict):
  """Parse the inputs of a ComboFeature into ``parsed_dict``.

  Without ``combo_join_sep`` every input is stored under its own key
  (``feature_name``, ``feature_name_1``, ...), split by its entry of
  ``combo_input_seps`` when that entry is non-empty and converted to string
  otherwise. With ``combo_join_sep`` the inputs are merged into one entry:
  by ``sparse_cross`` when input separators are configured, else by
  ``string_join``.

  Args:
    fc: feature config of the combo feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  # for compatibility with existing implementations
  if fc.HasField('feature_name'):
    feature_name = fc.feature_name
  else:
    feature_name = fc.input_names[0]

  num_seps = len(fc.combo_input_seps)
  if num_seps > 0:
    assert num_seps == len(fc.input_names), \
        'len(combo_separator)[%d] != len(fc.input_names)[%d]' % (
            num_seps, len(fc.input_names))

  def _split_string_input(name, sep):
    # string_split is only defined for string tensors
    assert field_dict[
        name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
            name, field_dict[name].dtype)
    return tf.string_split(field_dict[name], sep)

  if len(fc.combo_join_sep) == 0:
    # one output entry per input
    for idx, name in enumerate(fc.input_names):
      key = feature_name + '_' + str(idx) if idx > 0 else feature_name
      sep = fc.combo_input_seps[idx] if idx < num_seps else ''
      if sep != '':
        parsed_dict[key] = _split_string_input(name, sep)
      else:
        parsed_dict[key] = self._as_string(field_dict[name], fc)
    return

  if num_seps > 0:
    crossed_inputs = []
    for idx, name in enumerate(fc.input_names):
      sep = fc.combo_input_seps[idx]
      if len(sep) > 0:
        crossed_inputs.append(_split_string_input(name, sep))
      else:
        crossed_inputs.append(tf.reshape(field_dict[name], [-1, 1]))
    parsed_dict[feature_name] = sparse_ops.sparse_cross(
        crossed_inputs, fc.combo_join_sep)
  else:
    joined_inputs = []
    for name in fc.input_names:
      joined_inputs.append(self._as_string(field_dict[name], fc))
    parsed_dict[feature_name] = string_ops.string_join(
        joined_inputs, fc.combo_join_sep)
|
|
431
|
+
|
|
432
|
+
def _parse_tag_feature(self, fc, parsed_dict, field_dict):
  """Parse a TagFeature into sparse tags and, optionally, sparse weights.

  Inputs of rank below 2, or whose last dimension is 1, are split by
  ``fc.separator``. Weights are stored under ``feature_name + '_w'``; they
  come either from ``key<kv_separator>weight`` pairs inside the tags or
  from a second input name. Inputs of any other shape are copied to
  ``parsed_dict`` as they are.

  Args:
    fc: feature config of the tag feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  input_0 = fc.input_names[0]
  feature_name = input_0
  if fc.HasField('feature_name'):
    feature_name = fc.feature_name
  weight_key = feature_name + '_w'
  field = field_dict[input_0]

  def _number_checks(values, input_name):
    # the validation op is only added to the graph in check mode
    if self._check_mode:
      return [
          tf.py_func(
              check_string_to_number, [values, input_name], Tout=tf.bool)
      ]
    return []

  # Construct the output of TagFeature according to the dimension of the
  # input field: inputs that already have 2 or more dimensions (last one
  # not equal to 1) are passed through unchanged.
  rank = len(field.get_shape())
  if not (rank < 2 or field.get_shape()[-1] == 1):
    parsed_dict[feature_name] = field_dict[input_0]
    if len(fc.input_names) > 1:
      input_1 = fc.input_names[1]
      parsed_dict[weight_key] = field_dict[input_1]
    return

  if rank == 0:
    field = tf.expand_dims(field, axis=0)
  elif rank == 2:
    field = tf.squeeze(field, axis=-1)

  has_kv = fc.HasField('kv_separator')
  if has_kv and len(fc.input_names) > 1:
    assert False, 'Tag Feature Error, ' \
        'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' % input_0

  tags = tf.string_split(field, fc.separator)
  parsed_dict[feature_name] = tags

  if has_kv:
    # every tag is "key<kv_separator>weight"
    kv_pairs = tf.string_split(tags.values, fc.kv_separator, skip_empty=False)
    kv_pairs = tf.reshape(kv_pairs.values, [-1, 2])
    tag_keys, tag_wgts = kv_pairs[:, 0], kv_pairs[:, 1]

    with tf.control_dependencies(_number_checks(tag_wgts, input_0)):
      tag_wgts = tf.string_to_number(
          tag_wgts, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
    key_tags = tf.sparse.SparseTensor(tags.indices, tag_keys,
                                      tags.dense_shape)
    parsed_dict[feature_name] = key_tags
    parsed_dict[weight_key] = tf.sparse.SparseTensor(tags.indices, tag_wgts,
                                                     key_tags.dense_shape)
    tags = key_tags

  if not fc.HasField('hash_bucket_size') and fc.num_buckets > 0:
    # tags are used as integer ids directly
    with tf.control_dependencies(_number_checks(tags.values, input_0)):
      int_vals = tf.string_to_number(
          tags.values, tf.int32, name='tag_fea_%s' % input_0)
    tags = tf.sparse.SparseTensor(tags.indices, int_vals, tags.dense_shape)
    parsed_dict[feature_name] = tags

  if len(fc.input_names) > 1:
    # the second input holds the weights, one per tag
    input_1 = fc.input_names[1]
    weights = field_dict[input_1]
    if len(weights.get_shape()) == 0:
      weights = tf.expand_dims(weights, axis=0)
    weights = tf.string_split(weights, fc.separator)
    with tf.control_dependencies(_number_checks(weights.values, input_1)):
      weight_vals = tf.string_to_number(
          weights.values, tf.float32, name='tag_wgt_str_2_flt_%s' % input_1)
    assert_op = tf.assert_equal(
        tf.shape(weight_vals)[0],
        tf.shape(tags.values)[0],
        message='TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s.'
        % (input_0, input_1, input_0, input_1))
    with tf.control_dependencies([assert_op]):
      weights = tf.sparse.SparseTensor(weights.indices,
                                       tf.identity(weight_vals),
                                       weights.dense_shape)
    parsed_dict[weight_key] = weights
|
|
506
|
+
|
|
507
|
+
def _parse_expr_feature(self, fc, parsed_dict, field_dict):
  """Evaluate the expression of an ExprFeature.

  Every input is first stored in ``parsed_dict`` as a float64 tensor under
  ``'expr_' + input_name``; the expression built by ``get_expression`` is
  then evaluated and stored under ``fc.feature_name``, which is also
  appended to ``self._appended_fields``.

  Args:
    fc: feature config of the expr feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  fea_name = fc.feature_name
  prefix = 'expr_'
  numeric_dtypes = [tf.int32, tf.int64, tf.double, tf.float32]
  for input_name in fc.input_names:
    new_input_name = prefix + input_name
    input_dtype = field_dict[input_name].dtype
    if input_dtype == tf.string:
      check_list = []
      if self._check_mode:
        check_list.append(
            tf.py_func(
                check_string_to_number, [field_dict[input_name], input_name],
                Tout=tf.bool))
      with tf.control_dependencies(check_list):
        parsed_dict[new_input_name] = tf.string_to_number(
            field_dict[input_name],
            tf.float64,
            name='%s_str_2_int_for_expr' % new_input_name)
      continue
    if input_dtype in numeric_dtypes:
      parsed_dict[new_input_name] = tf.cast(field_dict[input_name],
                                            tf.float64)
      continue
    assert False, 'invalid input dtype[%s] for expr feature' % str(
        field_dict[input_name].dtype)

  expression = get_expression(fc.expression, fc.input_names, prefix=prefix)
  logging.info('expression: %s' % expression)
  # NOTE(review): eval runs the configured expression with this function's
  # locals in scope, so local names must stay stable and the feature config
  # must come from a trusted source.
  parsed_dict[fea_name] = eval(expression)
  self._appended_fields.append(fea_name)
|
|
536
|
+
|
|
537
|
+
def _parse_id_feature(self, fc, parsed_dict, field_dict):
  """Parse an IdFeature into ``parsed_dict``.

  With ``hash_bucket_size`` set, non-string inputs are converted to string.
  Otherwise, with ``num_buckets > 0``, string inputs are converted to int32.
  In every other case the input is stored unchanged.

  Args:
    fc: feature config of the id feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  input_0 = fc.input_names[0]
  if fc.HasField('feature_name'):
    feature_name = fc.feature_name
  else:
    feature_name = input_0

  raw_ids = field_dict[input_0]
  parsed_dict[feature_name] = raw_ids

  if fc.HasField('hash_bucket_size'):
    if raw_ids.dtype != tf.string:
      parsed_dict[feature_name] = self._as_string(raw_ids, fc)
    return

  if fc.num_buckets > 0 and raw_ids.dtype == tf.string:
    check_list = []
    if self._check_mode:
      check_list.append(
          tf.py_func(
              check_string_to_number, [raw_ids, input_0], Tout=tf.bool))
    with tf.control_dependencies(check_list):
      parsed_dict[feature_name] = tf.string_to_number(
          raw_ids, tf.int32, name='%s_str_2_int' % input_0)
|
|
556
|
+
|
|
557
|
+
def _parse_raw_feature(self, fc, parsed_dict, field_dict):
  """Parse a RawFeature into a float tensor.

  String inputs are converted to float32, after being split by
  ``fc.separator`` when ``raw_input_dim > 1`` and pooled with
  ``fc.combiner`` when both ``seq_multi_sep`` and ``combiner`` are set.
  Numeric inputs are cast with ``tf.to_float``. The result is then min-max
  scaled when ``max_val > min_val`` and passed through the configured
  normalizer function, if any. When the feature has no boundaries, at most
  one bucket and an embedding dimension, two extra sparse tensors
  (``<feature_name>_raw_proj_id`` / ``<feature_name>_raw_proj_val``) are
  added.

  Args:
    fc: feature config of the raw feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  input_0 = fc.input_names[0]
  feature_name = input_0
  if fc.HasField('feature_name'):
    feature_name = fc.feature_name
  raw_field = field_dict[input_0]

  def _checks(check_fn, check_args):
    # the validation op is only added to the graph in check mode
    if self._check_mode:
      return [tf.py_func(check_fn, check_args, Tout=tf.bool)]
    return []

  def _combine(emb, segment_ids):
    # pool the rows of emb that share a segment id
    if fc.combiner == 'max':
      return tf.segment_max(emb, segment_ids)
    if fc.combiner == 'sum':
      return tf.segment_sum(emb, segment_ids)
    if fc.combiner == 'min':
      return tf.segment_min(emb, segment_ids)
    if fc.combiner == 'mean':
      return tf.segment_mean(emb, segment_ids)
    assert False, 'unsupported combine operator: ' + fc.combiner
    return emb

  if raw_field.dtype == tf.string:
    pooled = fc.HasField('seq_multi_sep') and fc.HasField('combiner')
    if pooled:
      fea = tf.string_split(raw_field, fc.seq_multi_sep)
      segment_ids = fea.indices[:, 0]
      vals = fea.values
    else:
      vals = raw_field
      segment_ids = tf.range(0, tf.shape(vals)[0])

    if fc.raw_input_dim > 1:
      split_checks = _checks(
          check_split, [vals, fc.separator, fc.raw_input_dim, input_0])
      with tf.control_dependencies(split_checks):
        tmp_fea = tf.string_split(vals, fc.separator)
      number_checks = _checks(check_string_to_number,
                              [tmp_fea.values, input_0])
      with tf.control_dependencies(number_checks):
        tmp_vals = tf.string_to_number(
            tmp_fea.values,
            tf.float32,
            name='multi_raw_fea_to_flt_%s' % input_0)
      if pooled:
        emb = tf.reshape(tmp_vals, [-1, fc.raw_input_dim])
        parsed_dict[feature_name] = _combine(emb, segment_ids)
      else:
        parsed_dict[feature_name] = tf.sparse_to_dense(
            tmp_fea.indices,
            [tf.shape(raw_field)[0], fc.raw_input_dim],
            tmp_vals,
            default_value=0)
    elif pooled:
      with tf.control_dependencies(
          _checks(check_string_to_number, [vals, input_0])):
        emb = tf.string_to_number(
            vals, tf.float32, name='raw_fea_to_flt_%s' % input_0)
      parsed_dict[feature_name] = _combine(emb, segment_ids)
    else:
      with tf.control_dependencies(
          _checks(check_string_to_number, [raw_field, input_0])):
        parsed_dict[feature_name] = tf.string_to_number(raw_field, tf.float32)
  elif raw_field.dtype in [tf.int32, tf.int64, tf.double, tf.float32]:
    parsed_dict[feature_name] = tf.to_float(raw_field)
  else:
    assert False, 'invalid dtype[%s] for raw feature' % str(raw_field.dtype)

  if fc.max_val > fc.min_val:
    # min-max scaling
    value_range = fc.max_val - fc.min_val
    parsed_dict[feature_name] = (parsed_dict[feature_name] -
                                 fc.min_val) / value_range

  if fc.HasField('normalizer_fn'):
    logging.info('apply normalizer_fn %s to `%s`' %
                 (fc.normalizer_fn, feature_name))
    normalizer = self._normalizer_fn[feature_name]
    parsed_dict[feature_name] = normalizer(parsed_dict[feature_name])

  needs_projection = (not fc.boundaries and fc.num_buckets <= 1 and
                      fc.embedding_dim > 0 and
                      self._data_config.sample_weight != input_0)
  if not needs_projection:
    return

  # may need by wide model and deep model to project
  # raw values to a vector, it maybe better implemented
  # by a ProjectionColumn later
  sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
  row_ids = tf.range(sample_num, dtype=tf.int64)
  col_ids = tf.range(fc.raw_input_dim, dtype=tf.int64)
  row_ids = row_ids[:, None]
  col_ids = col_ids[None, :]
  row_ids = tf.tile(row_ids, [1, fc.raw_input_dim])
  col_ids = tf.tile(col_ids, [sample_num, 1])
  row_ids = tf.reshape(row_ids, [-1, 1])
  col_ids = tf.reshape(col_ids, [-1, 1])
  indices = tf.concat([row_ids, col_ids], axis=1)

  dense_values = parsed_dict[feature_name]
  parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
      indices=indices,
      values=col_ids[:, 0],
      dense_shape=[sample_num, fc.raw_input_dim])
  parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
      indices=indices,
      values=tf.reshape(dense_values, [-1]),
      dense_shape=[sample_num, fc.raw_input_dim])
  # NOTE: the two projection keys are not registered in
  # self._appended_fields (that registration is disabled).
|
|
676
|
+
|
|
677
|
+
def _parse_seq_feature(self, fc, parsed_dict, field_dict):
  """Parse a SequenceFeature into ``parsed_dict``.

  Inputs of rank below 2 are split by ``fc.separator`` (and further by
  ``fc.seq_multi_sep`` when set, giving a 3 dimensional sparse tensor),
  converted to int64 ids or float32 values depending on the config, and
  min-max scaled when ``num_buckets > 1`` and ``max_val > min_val``.
  Inputs of higher rank are stored unchanged. For raw sub features without
  boundaries and with at most one bucket the result is made dense and two
  extra sparse tensors (``<feature_name>_raw_proj_id`` /
  ``<feature_name>_raw_proj_val``) are added.

  Args:
    fc: feature config of the sequence feature.
    parsed_dict: output dict, updated in place.
    field_dict: dict from input name to tensor.
  """
  input_0 = fc.input_names[0]
  feature_name = input_0
  if fc.HasField('feature_name'):
    feature_name = fc.feature_name
  field = field_dict[input_0]
  sub_feature_type = fc.sub_feature_type

  def _values_to_number(out_type, op_name):
    # replace the string values of the current sparse tensor by numbers
    sparse = parsed_dict[feature_name]
    check_list = [
        tf.py_func(
            check_string_to_number, [sparse.values, input_0], Tout=tf.bool)
    ] if self._check_mode else []
    with tf.control_dependencies(check_list):
      parsed_dict[feature_name] = tf.sparse.SparseTensor(
          sparse.indices,
          tf.string_to_number(sparse.values, out_type, name=op_name),
          sparse.dense_shape)

  # Construct the output of SeqFeature according to the dimension of the
  # input field: only inputs of rank below 2 are split here.
  if len(field.get_shape()) < 2:
    seq = tf.strings.split(field, fc.separator)
    parsed_dict[feature_name] = seq
    if fc.HasField('seq_multi_sep'):
      multi_vals = tf.string_split(seq.values, fc.seq_multi_sep)
      inner_indices = multi_vals.indices
      outer_indices = tf.gather(seq.indices, inner_indices[:, 0])
      out_indices = tf.concat([outer_indices, inner_indices[:, 1:]], axis=1)
      # 3 dimensional sparse tensor
      out_shape = tf.concat([seq.dense_shape, multi_vals.dense_shape[1:]],
                            axis=0)
      parsed_dict[feature_name] = tf.sparse.SparseTensor(
          out_indices, multi_vals.values, out_shape)

    if fc.num_buckets > 1 and fc.max_val == fc.min_val:
      _values_to_number(tf.int64, 'sequence_str_2_int_%s' % input_0)
    elif sub_feature_type == fc.RawFeature:
      _values_to_number(tf.float32, 'sequence_str_2_float_%s' % input_0)

    if fc.num_buckets > 1 and fc.max_val > fc.min_val:
      sparse = parsed_dict[feature_name]
      normalized_values = (sparse.values - fc.min_val) / (
          fc.max_val - fc.min_val)
      parsed_dict[feature_name] = tf.sparse.SparseTensor(
          sparse.indices, normalized_values, sparse.dense_shape)
  else:
    parsed_dict[feature_name] = field

  is_plain_raw_seq = (not fc.boundaries and fc.num_buckets <= 1 and
                      self._data_config.sample_weight != input_0 and
                      sub_feature_type == fc.RawFeature)
  if not is_plain_raw_seq:
    return

  if fc.raw_input_dim == 1:
    logging.info(
        'Not set boundaries or num_buckets or hash_bucket_size, %s will process as two dimension sequence raw feature'
        % feature_name)
    sparse = parsed_dict[feature_name]
    parsed_dict[feature_name] = tf.sparse_to_dense(
        sparse.indices, [tf.shape(sparse)[0], fc.sequence_length],
        sparse.values)
    sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
    row_ids = tf.range(sample_num, dtype=tf.int64)
    step_ids = tf.range(fc.sequence_length, dtype=tf.int64)
    row_ids = row_ids[:, None]
    step_ids = step_ids[None, :]
    row_ids = tf.tile(row_ids, [1, fc.sequence_length])
    step_ids = tf.tile(step_ids, [sample_num, 1])
    row_ids = tf.reshape(row_ids, [-1, 1])
    step_ids = tf.reshape(step_ids, [-1, 1])
    indices = tf.concat([row_ids, step_ids], axis=1)
    dense_values = parsed_dict[feature_name]
    parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
        indices=indices,
        values=step_ids[:, 0],
        dense_shape=[sample_num, fc.sequence_length])
    parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
        indices=indices,
        values=tf.reshape(dense_values, [-1]),
        dense_shape=[sample_num, fc.sequence_length])
  elif fc.raw_input_dim > 1:
    # for 3 dimension sequence feature input.
    logging.info('Not set boundaries or num_buckets or hash_bucket_size,'
                 ' %s will process as three dimension sequence raw feature' %
                 feature_name)
    sparse = parsed_dict[feature_name]
    parsed_dict[feature_name] = tf.sparse_to_dense(
        sparse.indices,
        [tf.shape(sparse)[0], fc.sequence_length, fc.raw_input_dim],
        sparse.values)
    sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
    row_ids = tf.range(sample_num, dtype=tf.int64)
    step_ids = tf.range(fc.sequence_length, dtype=tf.int64)
    dim_ids = tf.range(fc.raw_input_dim, dtype=tf.int64)
    row_ids = row_ids[:, None, None]
    step_ids = step_ids[None, :, None]
    dim_ids = dim_ids[None, None, :]
    row_ids = tf.tile(row_ids, [1, fc.sequence_length, fc.raw_input_dim])
    step_ids = tf.tile(step_ids, [sample_num, 1, fc.raw_input_dim])
    dim_ids = tf.tile(dim_ids, [sample_num, fc.sequence_length, 1])
    row_ids = tf.reshape(row_ids, [-1, 1])
    step_ids = tf.reshape(step_ids, [-1, 1])
    dim_ids = tf.reshape(dim_ids, [-1, 1])
    indices = tf.concat([row_ids, step_ids, dim_ids], axis=1)

    parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
        indices=indices,
        values=step_ids[:, 0],
        dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
    parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
        indices=indices,
        values=tf.reshape(parsed_dict[feature_name], [-1]),
        dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
    # NOTE: the two projection keys are not registered in
    # self._appended_fields (that registration is disabled).
|
|
805
|
+
|
|
806
|
+
def _preprocess(self, field_dict):
  """Turn the raw input fields into parsed features and labels.

  When a sampler is configured and the mode is not PREDICT, sampled values
  are first concatenated to the matching fields (or added as new parsed
  fields). Every feature config is then handled by the parser matching its
  feature type; inputs of any other feature type are copied through. Label
  fields are optionally transformed by a user defined function, and string
  labels are converted to float32. It is expected to handle batch inputs
  and single input, and it could be customized in subclasses.

  Args:
    field_dict: string to tensor, tensors are dense,
        could be of shape [batch_size], [batch_size, None], or of shape [].
        It is updated in place by the sampler and by label functions.

  Returns:
    dict with the entries 'feature' and 'label'; some of the feature
    tensors are transformed into sparse tensors, such as input tensors of
    tag features and lookup features.
  """
  parsed_dict = {}

  if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT:
    if self._mode != tf.estimator.ModeKeys.TRAIN:
      self._sampler.set_eval_num_sample()
    sampler_type = self._data_config.WhichOneof('sampler')
    sampler_config = getattr(self._data_config, sampler_type)
    item_ids = field_dict[sampler_config.item_id_field]
    if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']:
      sampled = self._sampler.get(item_ids)
    elif (sampler_type == 'negative_sampler_v2' or
          sampler_type.startswith('hard_negative_sampler')):
      # these samplers also need the user ids
      user_ids = field_dict[sampler_config.user_id_field]
      sampled = self._sampler.get(user_ids, item_ids)
    else:
      raise ValueError('Unknown sampler %s' % sampler_type)
    for name, sampled_val in sampled.items():
      if name in field_dict:
        field_dict[name] = tf.concat([field_dict[name], sampled_val], axis=0)
        continue
      print('appended fields: %s' % name)
      parsed_dict[name] = sampled_val
      self._appended_fields.append(name)

  for fc in self._feature_configs:
    feature_name = fc.feature_name
    feature_type = fc.feature_type
    if feature_type == fc.TagFeature:
      self._parse_tag_feature(fc, parsed_dict, field_dict)
    elif feature_type == fc.LookupFeature:
      assert feature_name is not None and feature_name != ''
      assert len(fc.input_names) == 2
      parsed_dict[feature_name] = self._lookup_preprocess(fc, field_dict)
    elif feature_type == fc.SequenceFeature:
      self._parse_seq_feature(fc, parsed_dict, field_dict)
    elif feature_type == fc.RawFeature:
      self._parse_raw_feature(fc, parsed_dict, field_dict)
    elif feature_type == fc.IdFeature:
      self._parse_id_feature(fc, parsed_dict, field_dict)
    elif feature_type == fc.ExprFeature:
      self._parse_expr_feature(fc, parsed_dict, field_dict)
    elif feature_type == fc.ComboFeature:
      self._parse_combo_feature(fc, parsed_dict, field_dict)
    else:
      # any other feature type: copy the inputs through, one key per input
      if not fc.HasField('feature_name'):
        feature_name = fc.input_names[0]
      for input_id, input_name in enumerate(fc.input_names):
        key = feature_name + '_' + str(input_id) if input_id > 0 else feature_name
        parsed_dict[key] = field_dict[input_name]

  label_dict = {}
  for input_id, input_name in enumerate(self._label_fields):
    if input_name not in field_dict:
      continue

    if input_name in self._label_udf_map:
      udf, udf_class, dtype = self._label_udf_map[input_name]
      if dtype is None or dtype == '':
        logging.info('apply tensorflow function transform: %s' % udf_class)
        field_dict[input_name] = udf(field_dict[input_name])
      else:
        assert dtype is not None, 'must set user_define_fn_res_type'
        logging.info('apply py_func transform: %s' % udf_class)
        field_dict[input_name] = tf.py_func(
            udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
        field_dict[input_name].set_shape(tf.TensorShape([None]))

    label = field_dict[input_name]
    if label.dtype != tf.string:
      assert label.dtype in [
          tf.float32, tf.double, tf.int32, tf.int64
      ], 'invalid label dtype: %s' % str(label.dtype)
      label_dict[input_name] = label
      continue

    # string labels: split multi dimensional ones, then convert to float
    label_dim = self._label_dim[input_id]
    if label_dim > 1:
      logging.info('will split labels[%d]=%s' % (input_id, input_name))
      label_sep = self._label_sep[input_id]
      check_list = [
          tf.py_func(
              check_split, [label, label_sep, label_dim, input_name],
              Tout=tf.bool)
      ] if self._check_mode else []
      with tf.control_dependencies(check_list):
        label = tf.string_split(label, label_sep).values
        label = tf.reshape(label, [-1, label_dim])
    check_list = [
        tf.py_func(
            check_string_to_number, [label, input_name], Tout=tf.bool)
    ] if self._check_mode else []
    with tf.control_dependencies(check_list):
      label_dict[input_name] = tf.string_to_number(
          label, tf.float32, name=input_name)

  if self._mode != tf.estimator.ModeKeys.PREDICT:
    # extra labels are generated from the labels parsed so far
    for func_config in self._data_config.extra_label_func:
      lbl_name = func_config.label_name
      func_name = func_config.label_func
      logging.info('generating new label `%s` by transform: %s' %
                   (lbl_name, func_name))
      lbl_fn = load_by_path(func_name)
      label_dict[lbl_name] = lbl_fn(label_dict)

  if self._data_config.HasField('sample_weight'):
    weight_field = self._data_config.sample_weight
    parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[weight_field]

  if Input.DATA_OFFSET in field_dict:
    parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
  return {'feature': parsed_dict, 'label': label_dict}
|
|
940
|
+
|
|
941
|
+
def _lookup_preprocess(self, fc, field_dict):
  """Preprocess function for lookup features.

  The second input holds ``key<kv_separator>value`` pairs joined by
  ``fc.separator``; the values whose key equals the first input are
  selected.

  Args:
    fc: FeatureConfig, with exactly the key input and the map input read
      from ``fc.input_names[0]`` and ``fc.input_names[1]``.
    field_dict: input dict

  Returns:
    SparseTensor of the selected values: of dense shape ``[1, n]`` for a
    scalar key, otherwise one row per sample of the batch.
  """
  max_sel_num = fc.lookup_max_sel_elem_num

  def _lookup(args, pad=True):
    # select, for one sample, the values of the map whose key matches
    one_key, one_map = args[0], args[1]
    if len(one_map.get_shape()) == 0:
      one_map = tf.expand_dims(one_map, axis=0)
    pairs = tf.string_split(one_map, fc.separator).values
    split_pairs = tf.string_split(pairs, fc.kv_separator)
    kv_table = tf.reshape(split_pairs.values, [-1, 2], name='kv_split_reshape')
    keys, vals = kv_table[:, 0], kv_table[:, 1]
    matched = tf.where(tf.equal(keys, one_key))
    matched = tf.squeeze(matched, axis=1)
    sel_vals = tf.gather(vals, matched)
    if not pad:
      return sel_vals
    # pad to max_sel_num so that map_fn gets outputs of a fixed size
    n = tf.shape(sel_vals)[0]
    sel_vals = tf.pad(sel_vals, [[0, max_sel_num - n]])
    len_msk = tf.sequence_mask(n, max_sel_num)
    positions = tf.range(max_sel_num, dtype=tf.int64)
    positions = positions * tf.to_int64(positions < tf.to_int64(n))
    return sel_vals, len_msk, positions

  key_field, map_field = fc.input_names[0], fc.input_names[1]
  key_fields, map_fields = field_dict[key_field], field_dict[map_field]

  if len(key_fields.get_shape()) == 0:
    # single (scalar) input: one row holding all the selected values
    vals = _lookup((key_fields, map_fields), False)
    n = tf.to_int64(tf.shape(vals)[0])
    row_ids = tf.zeros([n], dtype=tf.int64)
    col_ids = tf.range(0, n, dtype=tf.int64)
    indices = tf.concat(
        [tf.expand_dims(row_ids, axis=1),
         tf.expand_dims(col_ids, axis=1)],
        axis=1)
    return tf.sparse.SparseTensor(indices, vals, [1, n])

  vals, masks, positions = tf.map_fn(
      _lookup, [key_fields, map_fields], dtype=(tf.string, tf.bool, tf.int64))
  batch_size = tf.to_int64(tf.shape(vals)[0])
  # drop the padding added by _lookup
  vals = tf.boolean_mask(vals, masks)
  col_ids = tf.boolean_mask(positions, masks)
  row_ids = tf.range(0, batch_size, dtype=tf.int64)
  row_ids = tf.expand_dims(row_ids, axis=1)
  row_ids = row_ids + tf.zeros([1, max_sel_num], dtype=tf.int64)
  row_ids = tf.boolean_mask(row_ids, masks)
  indices = tf.concat(
      [tf.expand_dims(row_ids, axis=1),
       tf.expand_dims(col_ids, axis=1)],
      axis=1)
  shapes = tf.stack([batch_size, tf.reduce_max(col_ids) + 1])
  return tf.sparse.SparseTensor(indices, vals, shapes)
|
|
1004
|
+
|
|
1005
|
+
@abstractmethod
def _build(self, mode, params):
    """Build the input for ``mode``; concrete subclasses must override this.

    ``create_input``'s ``_input_fn`` calls this with the ``mode`` and
    ``params`` it received when ``mode`` is one of
    ``tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)`` and returns the result
    unchanged.

    Args:
        mode: the mode forwarded by ``_input_fn``.
        params: the ``params`` argument forwarded by ``_input_fn``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
|
|
1008
|
+
|
|
1009
|
+
def _pre_build(self, mode, params):
    """Hook run before the input is built; does nothing by default.

    ``create_input``'s ``_input_fn`` calls this first on every invocation,
    passing along the ``mode`` and ``params`` it received.

    Args:
        mode: the mode forwarded by ``_input_fn`` (may be ``None``).
        params: the ``params`` argument forwarded by ``_input_fn``.
    """
    return None
|
|
1011
|
+
|
|
1012
|
+
def restore(self, checkpoint_path):
    """Do nothing; ``checkpoint_path`` is ignored by this implementation.

    NOTE(review): presumably an override point for inputs that keep
    restorable state -- confirm against the subclasses.

    Args:
        checkpoint_path: unused here.
    """
    return None
|
|
1014
|
+
|
|
1015
|
+
def stop(self):
    """Do nothing in this implementation.

    NOTE(review): presumably an override point for inputs that own
    background resources which need shutting down -- confirm against the
    subclasses.
    """
    return None
|
|
1017
|
+
|
|
1018
|
+
def _safe_shard(self, dataset):
    """Shard ``dataset`` for this task, honouring ``chief_redundant``.

    When ``self._data_config.chief_redundant`` is falsy the dataset is split
    into ``self._task_num`` shards and shard ``self._task_index`` is taken.
    Otherwise both the shard count and the shard index are reduced by one
    (clamped to at least 1 and 0 respectively), so task indices 0 and 1 both
    read shard 0.

    Args:
        dataset: object exposing ``shard(num_shards, index)``.

    Returns:
        Whatever ``dataset.shard`` returns for the computed arguments.
    """
    if not self._data_config.chief_redundant:
        return dataset.shard(self._task_num, self._task_index)

    # One task is excluded from the shard count; clamp so a single-task
    # job still gets one shard at index 0.
    shard_count = max(self._task_num - 1, 1)
    shard_index = max(self._task_index - 1, 0)
    return dataset.shard(shard_count, shard_index)
|
|
1024
|
+
|
|
1025
|
+
def create_input(self, export_config=None):
    """Create the ``input_fn`` used for training/evaluation/prediction or export.

    Args:
        export_config: export configuration. Only read when the returned
            function is invoked with ``mode=None`` (SavedModel export), where
            its ``multi_placeholder`` attribute selects between
            ``create_multi_placeholders`` and ``create_placeholders``.

    Returns:
        A callable ``_input_fn(mode=None, params=None, config=None)`` whose
        ``input_creator`` attribute is set to this object.
    """

    def _input_fn(mode=None, params=None, config=None):
        """Build input_fn for estimator.

        Args:
            mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT), or None when
                used as a serving_input_receiver_fn for export.
            params: `dict` of hyper parameters, from Estimator
            config: tf.estimator.RunConfig instance (not used here)

        Return:
            if mode is TRAIN, EVAL or PREDICT:
                the result of ``self._build(mode, params)``
            if mode is None:
                tf.estimator.export.ServingInputReceiver instance
            for any other mode:
                None

        Raises:
            ValueError: if mode is None but ``create_input`` was called
                without an ``export_config``.
        """
        self._pre_build(mode, params)

        dataset_modes = (tf.estimator.ModeKeys.TRAIN,
                         tf.estimator.ModeKeys.EVAL,
                         tf.estimator.ModeKeys.PREDICT)
        if mode in dataset_modes:
            # build dataset from self._config.input_path
            self._mode = mode
            return self._build(mode, params)

        if mode is not None:
            # Unrecognised mode: fall through without a result, exactly as
            # the if/elif chain did before.
            return None

        # mode is None: serving_input_receiver_fn for export SavedModel
        if export_config is None:
            # Previously this surfaced as an AttributeError on
            # `None.multi_placeholder`; fail with an actionable message.
            raise ValueError(
                'export_config must be provided to create_input when the '
                'input_fn is called with mode=None (export)')

        place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
        # NOTE(review): eval() executes whatever expression the environment
        # variable holds. Kept for compatibility with existing settings;
        # consider a restricted parser if the environment is not trusted.
        place_on_cpu = eval(place_on_cpu) if place_on_cpu else False

        use_multi_placeholder = export_config.multi_placeholder
        # presumably `conditional` enters the device scope only when
        # place_on_cpu is truthy -- confirm against its definition.
        with conditional(place_on_cpu, ops.device('/CPU:0')):
            if use_multi_placeholder:
                inputs, features = self.create_multi_placeholders(
                    export_config)
            else:
                inputs, features = self.create_placeholders(export_config)

        if not use_multi_placeholder:
            # Only the single-placeholder path reports its feature keys.
            print('built feature placeholders. features: {}'.format(
                features.keys()))
        return tf.estimator.export.ServingInputReceiver(features, inputs)

    _input_fn.input_creator = self
    return _input_fn
|