easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel has been flagged as potentially problematic; see the registry's advisory page for details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
import multiprocessing
|
|
5
|
+
import queue
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow.python.ops import array_ops
|
|
10
|
+
|
|
11
|
+
from easy_rec.python.compat import queues
|
|
12
|
+
from easy_rec.python.input import load_parquet
|
|
13
|
+
from easy_rec.python.input.input import Input
|
|
14
|
+
|
|
15
|
+
# Under TF2 the rest of this module uses the TF1-style API (tf.gfile,
# tf.estimator placeholders, ...), so rebind `tf` to the v1 compat module.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ParquetInput(Input):
  """Reads parquet files through a pool of loader processes.

  Input files (comma separated glob patterns in ``input_path``) are sharded
  across tasks round-robin by file index. ``_build`` starts the loader
  processes via ``load_parquet.start_data_proc``, hands them file names
  through ``_file_que`` and wraps ``_sample_generator`` (which drains
  ``_data_que``) into a ``tf.data.Dataset``.

  Supported feature types are IdFeature / TagFeature (delivered together as
  one ragged ``sparse_fea`` pair) and RawFeature (delivered concatenated as
  ``dense_fea``).
  """

  # Upper bound on the number of loader processes started per task.
  _MAX_NUM_PROC = 8

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None,
               **kwargs):
    super(ParquetInput,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config, **kwargs)
    self._need_pack = True

    # Everything below up to the `input_path is None` check does not depend
    # on input files. It is initialised first so that an object built with
    # input_path=None can still create the serving input_fn (create_input
    # with mode=None reads the feature name lists) and can be stop()-ed.
    self._proc_arr = None
    self._proc_start = False
    self._proc_stop = False

    self._reserve_fields = None
    self._reserve_types = None
    if 'reserve_fields' in kwargs and 'reserve_types' in kwargs:
      self._reserve_fields = kwargs['reserve_fields']
      self._reserve_types = kwargs['reserve_types']

    # indicator whether is called from Predictor, do not go pass
    self._is_predictor = kwargs.get('is_predictor', False)

    self._init_feature_names()

    if input_path is None:
      return

    self._input_files = []
    for sub_path in input_path.strip().split(','):
      self._input_files.extend(tf.gfile.Glob(sub_path))
    logging.info('parquet input_path=%s file_num=%d' %
                 (input_path, len(self._input_files)))
    mp_ctxt = multiprocessing.get_context('spawn')
    self._data_que = queues.Queue(
        name='data_que', ctx=mp_ctxt, maxsize=self._data_config.prefetch_size)

    file_num = len(self._input_files)
    logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
                 (task_index, file_num, task_num))

    # Round-robin sharding: this task reads every task_num-th file.
    self._my_files = [
        input_file for file_id, input_file in enumerate(self._input_files)
        if file_id % task_num == task_index
    ]
    logging.info('[task_index=%d] task_file_num=%d' %
                 (task_index, len(self._my_files)))
    self._file_que = queues.Queue(name='file_que', ctx=mp_ctxt)

    # Never start more loader processes than there are input files.
    self._num_proc = min(self._MAX_NUM_PROC, file_num)

    self._proc_start_que = queues.Queue(name='proc_start_que', ctx=mp_ctxt)
    self._proc_stop_que = queues.Queue(name='proc_stop_que', ctx=mp_ctxt)

  def _init_feature_names(self):
    """Split feature configs into sparse (Id/Tag) and dense (Raw) inputs.

    Sets _sparse_fea_names, _dense_fea_names, _dense_fea_cfgs and
    _total_dense_fea_dim (sum of raw_input_dim over RawFeatures).

    Raises:
      AssertionError: if a feature type other than IdFeature, TagFeature or
        RawFeature is configured.
    """
    self._sparse_fea_names = []
    self._dense_fea_names = []
    self._dense_fea_cfgs = []
    self._total_dense_fea_dim = 0
    for fc in self._feature_configs:
      feature_type = fc.feature_type
      if feature_type in [fc.IdFeature, fc.TagFeature]:
        self._sparse_fea_names.append(fc.input_names[0])
      elif feature_type in [fc.RawFeature]:
        self._dense_fea_names.append(fc.input_names[0])
        self._dense_fea_cfgs.append(fc)
        self._total_dense_fea_dim += fc.raw_input_dim
      else:
        assert False, 'feature_type[%s] not supported' % str(feature_type)

  def _get_shared_num_buckets(self):
    """Return the num_buckets shared by all feature configs that set one.

    When self._has_ev is False, the packed sparse ids are folded into one id
    space with ``ids % num_buckets``; this is used by both the training
    dataset and the serving input_fn so the two stay consistent.

    Raises:
      AssertionError: if two feature configs use different num_buckets, or
        if no feature config sets num_buckets > 0 (the modulo would
        otherwise be by a non-positive number).
    """
    num_buckets = -1
    for fc in self._feature_configs:
      if fc.num_buckets > 0:
        if num_buckets < 0:
          num_buckets = fc.num_buckets
        else:
          assert num_buckets == fc.num_buckets, (
              'all features must share the same buckets, but are %d and %s' %
              (num_buckets, str(fc)))
    assert num_buckets > 0, (
        'ParquetInput without ev requires num_buckets > 0 on the feature '
        'configs, got none')
    return num_buckets

  @staticmethod
  def _que_size(que):
    """Return que.qsize(), or -1 if the platform does not implement it.

    Only used for logging. multiprocessing-style queues may raise
    NotImplementedError from qsize(); presumably queues.Queue behaves the
    same -- TODO confirm. Without this guard the error would end the data
    stream in _sample_generator.
    """
    try:
      return que.qsize()
    except NotImplementedError:
      return -1

  def _rebuild_que(self):
    """Recreate the four inter-process queues (spawn context)."""
    mp_ctxt = multiprocessing.get_context('spawn')
    self._data_que = queues.Queue(
        name='data_que', ctx=mp_ctxt, maxsize=self._data_config.prefetch_size)
    self._file_que = queues.Queue(name='file_que', ctx=mp_ctxt)
    self._proc_start_que = queues.Queue(name='proc_start_que', ctx=mp_ctxt)
    self._proc_stop_que = queues.Queue(name='proc_stop_que', ctx=mp_ctxt)

  def _sample_generator(self):
    """Yield batches from data_que until every loader process has finished.

    On the first call one start signal per loader process is put on
    _proc_start_que. A loader reports completion by putting ``None`` on
    data_que. The generator stops once a 1 second get() times out after all
    processes have reported done, or on an unexpected queue error (logged as
    a warning).
    """
    if not self._proc_start:
      self._proc_start = True
      for proc in self._proc_arr:
        self._proc_start_que.put(True)
        logging.info('task[%s] data_proc=%s is_alive=%s' %
                     (self._task_index, proc, proc.is_alive()))

    done_proc_cnt = 0
    fetch_timeout_cnt = 0
    fetch_good_cnt = 0
    while True:
      try:
        sample = self._data_que.get(timeout=1)
      except queue.Empty:
        fetch_timeout_cnt += 1
        if done_proc_cnt >= len(self._proc_arr):
          logging.info('all sample finished, fetch_timeout_cnt=%d' %
                       fetch_timeout_cnt)
          break
        continue
      except Exception as ex:
        logging.warning('task[%d] get from data_que exception: %s' %
                        (self._task_index, str(ex)))
        break

      if sample is None:
        # end-of-data marker from one loader process
        done_proc_cnt += 1
        continue

      fetch_good_cnt += 1
      yield sample
      if fetch_good_cnt % 200 == 0:
        logging.info(
            'task[%d] fetch_batch_cnt=%d, fetch_timeout_cnt=%d, qsize=%d' %
            (self._task_index, fetch_good_cnt, fetch_timeout_cnt,
             self._que_size(self._data_que)))
    logging.info('task[%d] sample_generator: total_batches=%d' %
                 (self._task_index, fetch_good_cnt))

  def stop(self):
    """Shut down the loader processes and reset state for the next _build.

    Does nothing if no processes exist. If they have been started, each one
    gets a stop signal, data_que is drained until all of them have exited,
    they are joined, and fresh queues are created so the input can be built
    again (e.g. for the next evaluation round).
    """
    if self._proc_arr is None or len(self._proc_arr) == 0:
      return
    logging.info('task[%d] will stop dataset procs, proc_num=%d' %
                 (self._task_index, len(self._proc_arr)))
    self._file_que.close()
    if not self._proc_start:
      # NOTE(review): processes were created but never sent a start signal;
      # they are not stopped or joined here and file_que stays closed (no
      # _rebuild_que). Confirm how load_parquet handles this path.
      return

    logging.info('try close data que')
    for _ in range(len(self._proc_arr)):
      self._proc_stop_que.put(1)
    self._proc_stop_que.close()

    # to ensure the sender part of the python Queue could exit, keep
    # draining data_que while any process is still alive (best effort)
    while any(proc.is_alive() for proc in self._proc_arr):
      try:
        self._data_que.get(timeout=1)
      except Exception:
        pass
    time.sleep(1)
    self._data_que.close()
    logging.info('data que closed')
    for proc in self._proc_arr:
      proc.join()
    logging.info('join proc done')

    # rebuild for next run, which is necessary for evaluation
    self._rebuild_que()
    self._proc_arr = None
    self._proc_start = False
    self._proc_stop = False

  def _to_fea_dict(self, input_dict):
    """Convert one generator element into {'feature', 'label', 'reserve'}.

    The generator delivers sparse_fea as (lens, vals); it is re-ordered to
    (vals, lens) here. Without ev (self._has_ev False) the ids are folded
    with ``% num_buckets`` (see _get_shared_num_buckets).
    """
    fea_dict = {}

    if len(self._sparse_fea_names) > 0:
      tmp_lens = input_dict['sparse_fea'][0]
      tmp_vals = input_dict['sparse_fea'][1]
      if not self._has_ev:
        tmp_vals = tmp_vals % self._get_shared_num_buckets()
      fea_dict['sparse_fea'] = (tmp_vals, tmp_lens)

    if len(self._dense_fea_names) > 0:
      fea_dict['dense_fea'] = input_dict['dense_fea']

    output_dict = {'feature': fea_dict}

    lbl_dict = {}
    for lbl_name in self._label_fields:
      if lbl_name in input_dict:
        lbl_dict[lbl_name] = input_dict[lbl_name]

    if len(lbl_dict) > 0:
      output_dict['label'] = lbl_dict

    if self._reserve_fields is not None:
      output_dict['reserve'] = input_dict['reserve']

    return output_dict

  def add_fea_type_and_shape(self, out_types, out_shapes):
    """Register feature dtypes/shapes of the generator output in place.

    All sparse features are packed into one tuple ``sparse_fea``:
    first field: field lengths (int32), second field: field values (int64).
    Dense features are one float32 matrix [batch, total_dense_fea_dim].
    """
    if len(self._sparse_fea_names) > 0:
      out_types['sparse_fea'] = (tf.int32, tf.int64)
      out_shapes['sparse_fea'] = (tf.TensorShape([None]),
                                  tf.TensorShape([None]))
    if len(self._dense_fea_names) > 0:
      out_types['dense_fea'] = tf.float32
      out_shapes['dense_fea'] = tf.TensorShape(
          [None, self._total_dense_fea_dim])

  def _build(self, mode, params):
    """Start the loader processes and return the tf.data.Dataset for `mode`.

    TRAIN repeats this task's file list num_epochs times (when > 1) and uses
    data_config.drop_remainder; other modes read the files once and keep the
    remainder. PREDICT reads no labels.
    """
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    if is_train and self._data_config.num_epochs > 1:
      logging.info('will repeat train data for %d epochs' %
                   self._data_config.num_epochs)
      my_files = self._my_files * self._data_config.num_epochs
    else:
      my_files = self._my_files

    drop_remainder = self._data_config.drop_remainder if is_train else False
    if mode == tf.estimator.ModeKeys.PREDICT:
      lbl_fields = None
    else:
      lbl_fields = self._label_fields

    self._proc_arr = load_parquet.start_data_proc(
        self._task_index,
        self._task_num,
        self._num_proc,
        self._file_que,
        self._data_que,
        self._proc_start_que,
        self._proc_stop_que,
        self._batch_size,
        lbl_fields,
        self._sparse_fea_names,
        self._dense_fea_names,
        self._dense_fea_cfgs,
        self._reserve_fields,
        drop_remainder,
        need_pack=self._need_pack)

    for input_file in my_files:
      self._file_que.put(input_file)

    # add end signal: one None per loader process
    for _ in self._proc_arr:
      self._file_que.put(None)
    logging.info('add input_files to file_que, qsize=%d' %
                 self._que_size(self._file_que))

    out_types = {}
    out_shapes = {}

    if mode != tf.estimator.ModeKeys.PREDICT:
      # NOTE(review): labels are always declared int32 here.
      for k in self._label_fields:
        out_types[k] = tf.int32
        out_shapes[k] = tf.TensorShape([None])

    if self._reserve_fields is not None:
      out_types['reserve'] = {}
      out_shapes['reserve'] = {}
      for k, t in zip(self._reserve_fields, self._reserve_types):
        out_types['reserve'][k] = t
        out_shapes['reserve'][k] = tf.TensorShape([None])

    self.add_fea_type_and_shape(out_types, out_shapes)

    dataset = tf.data.Dataset.from_generator(
        self._sample_generator,
        output_types=out_types,
        output_shapes=out_shapes)
    num_parallel_calls = self._data_config.num_parallel_calls
    dataset = dataset.map(
        self._to_fea_dict, num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    # Note: Input._preprocess is currently not supported as all features
    # are concatenated together.

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
      # initial test show that prefetch to gpu has no performance gain
    elif self._is_predictor:
      dataset = dataset.map(self._get_for_predictor)
    else:
      dataset = dataset.map(self._get_features)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    return dataset

  def _get_for_predictor(self, fea_dict):
    """Rename features to the serving input names used by create_input.

    sparse_fea (vals, lens) becomes 'ragged_ids' / 'ragged_lens' and
    dense_fea is passed through as 'dense_fea', each only when that kind of
    feature is configured; reserve fields are forwarded when requested.
    """
    features = {}
    if len(self._sparse_fea_names) > 0:
      sparse_fea = fea_dict['feature']['sparse_fea']
      features['ragged_ids'] = sparse_fea[0]
      features['ragged_lens'] = sparse_fea[1]
    if len(self._dense_fea_names) > 0:
      features['dense_fea'] = fea_dict['feature']['dense_fea']
    out_dict = {'feature': features}
    if self._is_predictor and self._reserve_fields is not None:
      out_dict['reserve'] = fea_dict['reserve']
    return out_dict

  def create_input(self, export_config=None):
    """Return an estimator input_fn; see _input_fn for its modes."""

    def _input_fn(mode=None, params=None, config=None):
      """Build input_fn for estimator.

      Args:
        mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
        params: `dict` of hyper parameters, from Estimator
        config: tf.estimator.RunConfig instance

      Return:
        if mode is not None, return:
            features: inputs to the model.
            labels: groundtruth
        else, return:
            tf.estimator.export.ServingInputReceiver instance
      """
      if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                  tf.estimator.ModeKeys.PREDICT):
        # build dataset from self._config.input_path
        self._mode = mode
        return self._build(mode, params)
      elif mode is None:  # serving_input_receiver_fn for export SavedModel
        inputs, features = {}, {}
        if len(self._sparse_fea_names) > 0:
          ragged_ids = array_ops.placeholder(
              tf.int64, [None], name='ragged_ids')
          ragged_lens = array_ops.placeholder(
              tf.int32, [None], name='ragged_lens')
          inputs = {'ragged_ids': ragged_ids, 'ragged_lens': ragged_lens}
          if self._has_ev:
            features = {'ragged_ids': ragged_ids, 'ragged_lens': ragged_lens}
          else:
            # same folding as the training path in _to_fea_dict
            features = {
                'ragged_ids': ragged_ids % self._get_shared_num_buckets(),
                'ragged_lens': ragged_lens
            }
        if len(self._dense_fea_names) > 0:
          inputs['dense_fea'] = array_ops.placeholder(
              tf.float32, [None, self._total_dense_fea_dim], name='dense_fea')
          features['dense_fea'] = inputs['dense_fea']
        return tf.estimator.export.ServingInputReceiver(features, inputs)

    _input_fn.input_creator = self
    return _input_fn
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
# import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
# import numpy as np
|
|
7
|
+
# import pandas as pd
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tensorflow.python.framework import dtypes
|
|
10
|
+
from tensorflow.python.framework import ops
|
|
11
|
+
# from tensorflow.python.ops import math_ops
|
|
12
|
+
# from tensorflow.python.ops import logging_ops
|
|
13
|
+
from tensorflow.python.ops import array_ops
|
|
14
|
+
from tensorflow.python.ops import string_ops
|
|
15
|
+
|
|
16
|
+
from easy_rec.python.input.parquet_input import ParquetInput
|
|
17
|
+
from easy_rec.python.utils import conditional
|
|
18
|
+
|
|
19
|
+
# from easy_rec.python.utils.tf_utils import get_tf_type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ParquetInputV2(ParquetInput):
  """Parquet input that feeds each sparse feature as its own (lens, ids) pair.

  Every IdFeature/TagFeature input is expected as a `(lens, ids)` tuple and is
  turned into a `tf.RaggedTensor`; every RawFeature input is a float32 tensor.
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None,
               **kwargs):
    super(ParquetInputV2,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config, **kwargs)
    # NOTE(review): presumably tells the ParquetInput base class not to pack
    # per-feature values into shared tensors; confirm against ParquetInput.
    self._need_pack = False

  def _predictor_preprocess(self, input_dict):
    """Flatten (lens, ids) pairs into '<name>/lens' and '<name>/ids' keys.

    When the ParquetInputV2 is built from ParquetPredictorV2, the feature
    preprocess stage is skipped and this flattening is used instead. Any
    2-tuple value (except under key 'reserve') is split; every other value is
    passed through unchanged.
    """
    fea_dict = {}
    for k, vals in input_dict.items():
      if isinstance(vals, tuple) and len(vals) == 2 and k != 'reserve':
        fea_dict[k + '/lens'] = vals[0]
        fea_dict[k + '/ids'] = vals[1]
      else:
        fea_dict[k] = vals
    return fea_dict

  def _to_fea_dict(self, input_dict):
    """Split one parsed batch into {'feature', ['label'], ['reserve']}.

    'label' is added only when at least one configured label field is present
    in input_dict; 'reserve' is copied when reserve fields are configured.
    """
    if self._is_predictor:
      fea_dict = self._predictor_preprocess(input_dict)
    else:
      fea_dict = self._preprocess(input_dict)

    output_dict = {'feature': fea_dict}

    lbl_dict = {}
    for lbl_name in self._label_fields:
      if lbl_name in input_dict:
        lbl_dict[lbl_name] = input_dict[lbl_name]

    if len(lbl_dict) > 0:
      output_dict['label'] = lbl_dict

    if self._reserve_fields is not None:
      output_dict['reserve'] = input_dict['reserve']

    return output_dict

  def add_fea_type_and_shape(self, out_types, out_shapes):
    """Register feature dtypes/shapes into out_types/out_shapes (in place).

    Overrides the ParquetInput hook (presumably invoked from
    ParquetInput.build_type_and_shape). Sparse features are
    (lens: int32 [None], ids: int64 [None]) pairs; dense features are
    float32 [None, raw_input_dim] tensors keyed by their first input name.
    """
    for k in self._sparse_fea_names:
      out_types[k] = (tf.int32, tf.int64)
      out_shapes[k] = (tf.TensorShape([None]), tf.TensorShape([None]))
    for fc in self._dense_fea_cfgs:
      k = fc.input_names[0]
      out_types[k] = tf.float32
      out_shapes[k] = tf.TensorShape([None, fc.raw_input_dim])

  def _preprocess(self, inputs=None):
    """Build model features from parsed inputs, or from new placeholders.

    Args:
      inputs: dict keyed by input name: a (lens, ids) tuple for
        IdFeature/TagFeature inputs, a float tensor for RawFeature inputs.
        When None, placeholders are created instead (serving export);
        feature configs that share an input name share one placeholder.

    Returns:
      the features dict when `inputs` is given; otherwise a
      (features, inputs) tuple where the second item maps '<name>/lens',
      '<name>/ids' (sparse) or '<name>' (raw) to the created placeholders.

    Raises:
      ValueError: for a feature type other than IdFeature, TagFeature or
        RawFeature.
    """
    features = {}
    placeholders = {}
    for fc in self._feature_configs:
      feature_name = (
          fc.feature_name if fc.feature_name != '' else fc.input_names[0])
      feature_type = fc.feature_type
      if feature_type in [fc.IdFeature, fc.TagFeature]:
        input_name0 = fc.input_names[0]
        if inputs is not None:
          input_lens, input_vals = inputs[input_name0]
        elif input_name0 in placeholders:
          input_lens, input_vals = placeholders[input_name0]
        else:
          input_vals = array_ops.placeholder(
              dtypes.int64, [None], name=input_name0 + '/ids')
          input_lens = array_ops.placeholder(
              dtypes.int64, [None], name=input_name0 + '/lens')
          # Store the raw placeholders, before any bucketing below.
          placeholders[input_name0] = (input_lens, input_vals)
        if not self._has_ev:
          if fc.num_buckets > 0:
            input_vals = input_vals % fc.num_buckets
          else:
            # NOTE(review): presumably hashed as strings downstream; confirm.
            input_vals = string_ops.as_string(input_vals)
        features[feature_name] = tf.RaggedTensor.from_row_lengths(
            values=input_vals, row_lengths=input_lens)
      elif feature_type in [fc.RawFeature]:
        input_name0 = fc.input_names[0]
        if inputs is not None:
          input_vals = inputs[input_name0]
        elif input_name0 in placeholders:
          input_vals = placeholders[input_name0]
        else:
          if fc.raw_input_dim > 1:
            input_vals = array_ops.placeholder(
                dtypes.float32, [None, fc.raw_input_dim], name=input_name0)
          else:
            input_vals = array_ops.placeholder(
                dtypes.float32, [None], name=input_name0)
          placeholders[input_name0] = input_vals
        features[feature_name] = input_vals
      else:
        # An explicit raise (not `assert`) so the check survives `python -O`.
        raise ValueError('feature_type[%s] not supported' % str(feature_type))

    if inputs is not None:
      return features

    receiver_inputs = {}
    for key, vals in placeholders.items():
      if isinstance(vals, tuple):
        receiver_inputs[key + '/lens'] = vals[0]
        receiver_inputs[key + '/ids'] = vals[1]
      else:
        receiver_inputs[key] = vals
    return features, receiver_inputs

  def _get_for_predictor(self, fea_dict):
    """Return fea_dict unchanged for the predictor path.

    Called by ParquetInputV2._build, format:
      {
        "feature": {"user_id/ids":..., "user_id/lens":..., ... },
        "reserve": {"sample_id":..., ...}
      }
    """
    return fea_dict

  def create_input(self, export_config=None):
    """Return an estimator input_fn bound to this input object.

    Args:
      export_config: accepted for interface compatibility; not read here.

    Returns:
      the closure `_input_fn(mode=None, params=None, config=None)`, whose
      `input_creator` attribute is set to this object.
    """

    def _input_fn(mode=None, params=None, config=None):
      """Build input_fn for estimator.

      Args:
        mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT), or None to build
          the serving input receiver used when exporting a SavedModel.
        params: `dict` of hyper parameters, from Estimator
        config: tf.estimator.RunConfig instance (not read here)

      Return:
        TRAIN / EVAL / PREDICT: the dataset built by self._build(mode, params).
        None: tf.estimator.export.ServingInputReceiver instance.
        any other mode: None.
      """
      if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                  tf.estimator.ModeKeys.PREDICT):
        # build dataset from self._config.input_path
        self._mode = mode
        dataset = self._build(mode, params)
        return dataset
      elif mode is None:  # serving_input_receiver_fn for export SavedModel
        # bool(str) is True for any non-empty string, including 'False' and
        # '0', so the flag is parsed explicitly: unset, empty and the usual
        # false spellings disable CPU placement; any other value enables it.
        env_flag = os.getenv('place_embedding_on_cpu') or ''
        place_on_cpu = env_flag.strip().lower() not in ('', '0', 'false', 'no',
                                                        'off', 'none')
        with conditional(place_on_cpu, ops.device('/CPU:0')):
          features, inputs = self._preprocess()
        return tf.estimator.export.ServingInputReceiver(features, inputs)

    _input_fn.input_creator = self
    return _input_fn
|