easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/input/odps_input.py
@@ -0,0 +1,101 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import odps_util
+
+try:
+  import pai
+except Exception:
+  pass
+
+
+class OdpsInput(Input):
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(OdpsInput,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+
+  def _build(self, mode, params):
+    # check data_config are consistent with odps tables
+    odps_util.check_input_field_and_types(self._data_config)
+
+    selected_cols = ','.join(self._input_fields)
+    if self._data_config.chief_redundant and \
+        mode == tf.estimator.ModeKeys.TRAIN:
+      reader = tf.TableRecordReader(
+          csv_delimiter=self._data_config.separator,
+          selected_cols=selected_cols,
+          slice_count=max(self._task_num - 1, 1),
+          slice_id=max(self._task_index - 1, 0))
+    else:
+      reader = tf.TableRecordReader(
+          csv_delimiter=self._data_config.separator,
+          selected_cols=selected_cols,
+          slice_count=self._task_num,
+          slice_id=self._task_index)
+
+    if type(self._input_path) != list:
+      self._input_path = self._input_path.split(',')
+    assert len(
+        self._input_path) > 0, 'match no files with %s' % self._input_path
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      if self._data_config.pai_worker_queue:
+        work_queue = pai.data.WorkQueue(
+            self._input_path,
+            num_epochs=self.num_epochs,
+            shuffle=self._data_config.shuffle,
+            num_slices=self._data_config.pai_worker_slice_num * self._task_num)
+        work_queue.add_summary()
+        file_queue = work_queue.input_producer()
+        reader = tf.TableRecordReader()
+      else:
+        file_queue = tf.train.string_input_producer(
+            self._input_path,
+            num_epochs=self.num_epochs,
+            capacity=1000,
+            shuffle=self._data_config.shuffle)
+    else:
+      file_queue = tf.train.string_input_producer(
+          self._input_path, num_epochs=1, capacity=1000, shuffle=False)
+    key, value = reader.read_up_to(file_queue, self._batch_size)
+
+    record_defaults = [
+        self.get_type_defaults(t, v)
+        for t, v in zip(self._input_field_types, self._input_field_defaults)
+    ]
+    fields = tf.decode_csv(
+        value,
+        record_defaults=record_defaults,
+        field_delim=self._data_config.separator,
+        name='decode_csv')
+
+    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+    for x in self._label_fids:
+      inputs[self._input_fields[x]] = fields[x]
+
+    fields = self._preprocess(inputs)
+
+    features = self._get_features(fields)
+    # import pai
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      labels = self._get_labels(fields)
+      # features, labels = pai.data.prefetch(features=(features, labels),
+      #     capacity=self._prefetch_size, num_threads=2,
+      #     closed_exception_types=(tuple([tf.errors.InternalError])))
+      return features, labels
+    else:
+      # features = pai.data.prefetch(features=(features,),
+      #     capacity=self._prefetch_size, num_threads=2,
+      #     closed_exception_types=(tuple([tf.errors.InternalError])))
+      return features
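OdpsInput is the queue-based TF 1.x variant: tf.TableRecordReader and pai.data.WorkQueue exist only in PAI-TF, not in stock TensorFlow. Below is a rough sketch of the same read_up_to/decode_csv flow using only stock TF 1.x ops, with tf.TextLineReader standing in for the table reader; the file name 'data.csv' and its three columns (label, price, title) are made-up assumptions.

import tensorflow as tf

BATCH_SIZE = 32

# file names feed a queue; the reader dequeues records from the files in it
file_queue = tf.train.string_input_producer(
    ['data.csv'], num_epochs=1, capacity=1000, shuffle=False)
reader = tf.TextLineReader()
# read_up_to returns at most BATCH_SIZE serialized records per call
keys, values = reader.read_up_to(file_queue, BATCH_SIZE)

# one default per column, playing the role of get_type_defaults above
record_defaults = [[0], [0.0], ['']]
label, price, title = tf.decode_csv(
    values, record_defaults=record_defaults, field_delim=',')

with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    print(sess.run([label, price, title]))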
easy_rec/python/input/odps_input_v2.py
@@ -0,0 +1,110 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import odps_util
+
+try:
+  import pai
+except Exception:
+  pass
+
+
+class OdpsInputV2(Input):
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(OdpsInputV2,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+
+  def _parse_table(self, *fields):
+    fields = list(fields)
+    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+    for x in self._label_fids:
+      inputs[self._input_fields[x]] = fields[x]
+    return inputs
+
+  def _build(self, mode, params):
+    if type(self._input_path) != list:
+      self._input_path = self._input_path.split(',')
+    assert len(
+        self._input_path) > 0, 'match no files with %s' % self._input_path
+    # check data_config are consistent with odps tables
+    odps_util.check_input_field_and_types(self._data_config)
+
+    selected_cols = ','.join(self._input_fields)
+    record_defaults = [
+        self.get_type_defaults(x, v)
+        for x, v in zip(self._input_field_types, self._input_field_defaults)
+    ]
+
+    if self._data_config.pai_worker_queue and \
+        mode == tf.estimator.ModeKeys.TRAIN:
+      logging.info('pai_worker_slice_num = %d' %
+                   self._data_config.pai_worker_slice_num)
+      work_queue = pai.data.WorkQueue(
+          self._input_path,
+          num_epochs=self.num_epochs,
+          shuffle=self._data_config.shuffle,
+          num_slices=self._data_config.pai_worker_slice_num * self._task_num)
+      que_paths = work_queue.input_dataset()
+      dataset = tf.data.TableRecordDataset(
+          que_paths,
+          record_defaults=record_defaults,
+          selected_cols=selected_cols)
+    elif self._data_config.chief_redundant and \
+        mode == tf.estimator.ModeKeys.TRAIN:
+      dataset = tf.data.TableRecordDataset(
+          self._input_path,
+          record_defaults=record_defaults,
+          selected_cols=selected_cols,
+          slice_id=max(self._task_index - 1, 0),
+          slice_count=max(self._task_num - 1, 1))
+    else:
+      dataset = tf.data.TableRecordDataset(
+          self._input_path,
+          record_defaults=record_defaults,
+          selected_cols=selected_cols,
+          slice_id=self._task_index,
+          slice_count=self._task_num)
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.batch(batch_size=self._data_config.batch_size)
+
+    dataset = dataset.map(
+        self._parse_table,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    # preprocess is necessary to transform data
+    # so that it can be fed into FeatureColumns
+    dataset = dataset.map(
+        map_func=self._preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
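The pipeline ordering in OdpsInputV2 is the part worth noting: slice at the source, then shuffle, repeat, batch, and only then parse, so _parse_table and _preprocess operate on whole batches rather than single rows. A minimal sketch of that ordering with stock tf.data follows; tf.data.TableRecordDataset exists only in PAI-TF, so an in-memory dataset stands in, and the field names are illustrative assumptions.

import tensorflow as tf

input_fields = ['clk', 'user_id', 'item_id']

def parse_batch(clk, user_id, item_id):
  # like _parse_table: map positional columns to a {name: tensor} dict
  return dict(zip(input_fields, [clk, user_id, item_id]))

dataset = tf.data.Dataset.from_tensor_slices(
    ([0, 1, 0, 1], ['u1', 'u2', 'u3', 'u4'], ['i1', 'i2', 'i3', 'i4']))
dataset = dataset.shuffle(32, seed=2020, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)            # num_epochs
dataset = dataset.batch(batch_size=2)  # batch BEFORE parsing
dataset = dataset.map(parse_batch, num_parallel_calls=2)
dataset = dataset.prefetch(buffer_size=4)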
easy_rec/python/input/odps_input_v3.py
@@ -0,0 +1,132 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import logging
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import odps_util
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+try:
+  import common_io
+except Exception:
+  common_io = None
+
+
+class OdpsInputV3(Input):
+  """Common IO based interface, could run at local or on data science."""
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(OdpsInputV3,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+    self._num_epoch = 0
+    if common_io is None:
+      logging.error('''
+please install common_io: pip install
+https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/common_io-0.4.2%2Btunnel-py2.py3-none-any.whl'''
+                    )
+      sys.exit(1)
+
+  def _parse_table(self, *fields):
+    fields = list(fields)
+    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+    for x in self._label_fids:
+      inputs[self._input_fields[x]] = fields[x]
+    return inputs
+
+  def _odps_read(self):
+    logging.info('start epoch[%d]' % self._num_epoch)
+    self._num_epoch += 1
+    if type(self._input_path) != list:
+      self._input_path = self._input_path.split(',')
+    assert len(
+        self._input_path) > 0, 'match no files with %s' % self._input_path
+
+    # check data_config are consistent with odps tables
+    odps_util.check_input_field_and_types(self._data_config)
+
+    record_defaults = [
+        self.get_type_defaults(x, v)
+        for x, v in zip(self._input_field_types, self._input_field_defaults)
+    ]
+
+    selected_cols = ','.join(self._input_fields)
+    for table_path in self._input_path:
+      reader = common_io.table.TableReader(
+          table_path,
+          selected_cols=selected_cols,
+          slice_id=self._task_index,
+          slice_count=self._task_num)
+      total_records_num = reader.get_row_count()
+      batch_num = int(total_records_num / self._data_config.batch_size)
+      res_num = total_records_num - batch_num * self._data_config.batch_size
+      batch_defaults = [
+          [x] * self._data_config.batch_size for x in record_defaults
+      ]
+      for batch_id in range(batch_num):
+        batch_data_np = [x.copy() for x in batch_defaults]
+        for row_id, one_data in enumerate(
+            reader.read(self._data_config.batch_size)):
+          for col_id in range(len(record_defaults)):
+            if one_data[col_id] not in ['', 'NULL', None]:
+              batch_data_np[col_id][row_id] = one_data[col_id]
+        yield tuple(batch_data_np)
+      if res_num > 0:
+        batch_data_np = [x[:res_num] for x in batch_defaults]
+        for row_id, one_data in enumerate(reader.read(res_num)):
+          for col_id in range(len(record_defaults)):
+            if one_data[col_id] not in ['', 'NULL', None]:
+              batch_data_np[col_id][row_id] = one_data[col_id]
+        yield tuple(batch_data_np)
+      reader.close()
+    logging.info('finish epoch[%d]' % self._num_epoch)
+
+  def _build(self, mode, params):
+    # get input type
+    list_type = [get_tf_type(x) for x in self._input_field_types]
+    list_type = tuple(list_type)
+    list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+    list_shapes = tuple(list_shapes)
+
+    # read odps tables
+    dataset = tf.data.Dataset.from_generator(
+        self._odps_read, output_types=list_type, output_shapes=list_shapes)
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      dataset = dataset.shuffle(
+          self._data_config.shuffle_buffer_size,
+          seed=2020,
+          reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.map(
+        self._parse_table,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    # preprocess is necessary to transform data
+    # so that it can be fed into FeatureColumns
+    dataset = dataset.map(
+        map_func=self._preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
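OdpsInputV3 sidesteps PAI-only ops entirely: a Python generator pulls rows through common_io and yields pre-assembled batches, so from_generator declares every column as shape [None] and the pipeline applies no further .batch(). A minimal self-contained sketch of that pattern, with made-up data standing in for common_io.table.TableReader:

import tensorflow as tf

BATCH_SIZE = 4
TOTAL_ROWS = 10

def read_batches():
  # stands in for the TableReader loop above: yields whole batches,
  # one list per selected column
  for start in range(0, TOTAL_ROWS, BATCH_SIZE):
    n = min(BATCH_SIZE, TOTAL_ROWS - start)
    labels = [0] * n
    ids = ['id_%d' % (start + i) for i in range(n)]
    yield labels, ids

dataset = tf.data.Dataset.from_generator(
    read_batches,
    output_types=(tf.int32, tf.string),
    output_shapes=(tf.TensorShape([None]), tf.TensorShape([None])))
# no .batch() here: each generator element is already a full batch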
easy_rec/python/input/odps_rtp_input.py
@@ -0,0 +1,187 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.input_utils import string_to_number
+
+try:
+  import pai
+except Exception:
+  pass
+
+
+class OdpsRTPInput(Input):
+  """RTPInput for parsing rtp fg new input format on odps.
+
+  Our new format(csv in table) of rtp output:
+     label0, item_id, ..., user_id, features
+  For the feature column, features are separated by ,
+  multiple values of one feature are separated by , such as:
+     ...20beautysmartParis...
+  The features column and labels are specified by data_config.selected_cols,
+  columns are selected by names in the table
+  such as: clk,features, the last selected column is features, the first
+  selected columns are labels
+  """
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(OdpsRTPInput,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+    logging.info('input_fields: %s label_fields: %s' %
+                 (','.join(self._input_fields), ','.join(self._label_fields)))
+
+  def _parse_table(self, *fields):
+    fields = list(fields)
+    labels = fields[:-1]
+
+    selected_cols = self._data_config.selected_cols \
+        if self._data_config.selected_cols else None
+    non_feature_cols = self._label_fields
+    if selected_cols:
+      cols = [c.strip() for c in selected_cols.split(',')]
+      non_feature_cols = cols[:-1]
+    # only for features, labels and sample_weight excluded
+    record_types = [
+        t for x, t in zip(self._input_fields, self._input_field_types)
+        if x not in non_feature_cols
+    ]
+    record_defaults = [
+        self.get_type_defaults(t, v)
+        for x, t, v in zip(self._input_fields, self._input_field_types,
+                           self._input_field_defaults)
+        if x not in non_feature_cols
+    ]
+
+    feature_num = len(record_types)
+    # assume that the last field is the generated feature column
+    print('field_delim = %s, feature_num = %d' %
+          (self._data_config.separator, feature_num))
+    logging.info('field_delim = %s, input_field_name = %d' %
+                 (self._data_config.separator, len(record_types)))
+
+    check_list = [
+        tf.py_func(
+            check_split,
+            [fields[-1], self._data_config.separator,
+             len(record_types)],
+            Tout=tf.bool)
+    ] if self._check_mode else []
+    with tf.control_dependencies(check_list):
+      fields = str_split_by_chr(
+          fields[-1], self._data_config.separator, skip_empty=False)
+    tmp_fields = tf.reshape(fields.values, [-1, feature_num])
+    fields = labels[len(self._label_fields):]
+    for i in range(feature_num):
+      field = string_to_number(tmp_fields[:, i], record_types[i],
+                               record_defaults[i], i)
+      fields.append(field)
+
+    field_keys = [x for x in self._input_fields if x not in self._label_fields]
+    effective_fids = [field_keys.index(x) for x in self._effective_fields]
+    inputs = {field_keys[x]: fields[x] for x in effective_fids}
+
+    for x in range(len(self._label_fields)):
+      inputs[self._label_fields[x]] = labels[x]
+    print('effective field num = %d, input_num = %d' %
+          (len(fields), len(inputs)))
+    return inputs
+
+  def _build(self, mode, params):
+    if type(self._input_path) != list:
+      self._input_path = self._input_path.split(',')
+    assert len(
+        self._input_path) > 0, 'match no files with %s' % self._input_path
+
+    selected_cols = self._data_config.selected_cols \
+        if self._data_config.selected_cols else None
+    if selected_cols:
+      cols = [c.strip() for c in selected_cols.split(',')]
+      record_defaults = [
+          self.get_type_defaults(t, v)
+          for x, t, v in zip(self._input_fields, self._input_field_types,
+                             self._input_field_defaults)
+          if x in cols[:-1]
+      ]
+      print('selected_cols: %s; defaults num: %d' %
+            (','.join(cols), len(record_defaults)))
+    else:
+      record_defaults = [
+          self.get_type_defaults(t, v)
+          for x, t, v in zip(self._input_fields, self._input_field_types,
+                             self._input_field_defaults)
+          if x in self._label_fields
+      ]
+    # the actual features are in one single column
+    record_defaults.append(
+        self._data_config.separator.join([
+            str(self.get_type_defaults(t, v))
+            for x, t, v in zip(self._input_fields, self._input_field_types,
+                               self._input_field_defaults)
+            if x not in self._label_fields
+        ]))
+
+    if self._data_config.pai_worker_queue and \
+        mode == tf.estimator.ModeKeys.TRAIN:
+      logging.info('pai_worker_slice_num = %d' %
+                   self._data_config.pai_worker_slice_num)
+      work_queue = pai.data.WorkQueue(
+          self._input_path,
+          num_epochs=self.num_epochs,
+          shuffle=self._data_config.shuffle,
+          num_slices=self._data_config.pai_worker_slice_num * self._task_num)
+      que_paths = work_queue.input_dataset()
+      dataset = tf.data.TableRecordDataset(
+          que_paths,
+          record_defaults=record_defaults,
+          selected_cols=selected_cols)
+    else:
+      dataset = tf.data.TableRecordDataset(
+          self._input_path,
+          record_defaults=record_defaults,
+          selected_cols=selected_cols,
+          slice_id=self._task_index,
+          slice_count=self._task_num)
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.batch(batch_size=self._data_config.batch_size)
+
+    dataset = dataset.map(
+        self._parse_table,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    # preprocess is necessary to transform data
+    # so that it can be fed into FeatureColumns
+    dataset = dataset.map(
+        map_func=self._preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
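The core of _parse_table above is turning one packed features string per row into feature_num typed columns: split, reshape the flat token list to [batch, feature_num], then convert each column. Here is a sketch of that step with stock TF 1.x ops; str_split_by_chr is EasyRec's custom AVX op, so tf.string_split stands in, and the ';' delimiter plus the sample rows are illustrative (RTP normally uses non-printable separator characters, which is why the docstring examples above render without visible delimiters). The reshape is only valid because every row carries exactly feature_num fields, which is what check_split asserts in check_mode.

import tensorflow as tf

feature_num = 3
# one packed string per row; values and delimiter are made-up stand-ins
features_col = tf.constant(['1.0;a;x', '2.0;b;y'])

sp = tf.string_split(features_col, delimiter=';', skip_empty=False)
dense = tf.reshape(sp.values, [-1, feature_num])  # [batch, feature_num]
price = tf.string_to_number(dense[:, 0])          # typed column, as string_to_number does

with tf.Session() as sess:
  print(sess.run([dense, price]))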
easy_rec/python/input/odps_rtp_input_v2.py
@@ -0,0 +1,104 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
+
+if tf.__version__.startswith('1.'):
+  from tensorflow.python.platform import gfile
+else:
+  import tensorflow.io.gfile as gfile
+try:
+  import pai
+  import rtp_fg
+except Exception:
+  pai = None
+  rtp_fg = None
+
+
+class OdpsRTPInputV2(OdpsRTPInput):
+  """RTPInput for parsing rtp fg new input format on odps.
+
+  Our new format(csv in table) of rtp output:
+     label0, item_id, ..., user_id, features
+  Where features is in default RTP-tensorflow format.
+  The features column and labels are specified by data_config.selected_cols,
+  columns are selected by names in the table
+  such as: clk,features, the last selected column is features, the first
+  selected columns are labels
+  """
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               fg_json_path=None,
+               pipeline_config=None):
+    super(OdpsRTPInputV2,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+    if fg_json_path is None:
+      raise ValueError('fg_json_path is not set')
+    if fg_json_path.startswith('!'):
+      fg_json_path = fg_json_path[1:]
+    self._fg_config_path = fg_json_path
+    logging.info('fg config path: {}'.format(self._fg_config_path))
+    with gfile.GFile(self._fg_config_path, 'r') as f:
+      self._fg_config = json.load(f)
+
+  def _parse_table(self, *fields):
+    self.check_rtp()
+
+    fields = list(fields)
+    labels = fields[:-1]
+
+    # assume that the last field is the generated feature column
+    features = rtp_fg.parse_genreated_fg(self._fg_config, fields[-1])
+
+    field_keys = [x for x in self._input_fields if x not in self._label_fields]
+    # iterate over a copy of the keys since entries are deleted in the loop
+    for feature_key in list(features.keys()):
+      if feature_key not in field_keys or feature_key not in self._effective_fields:
+        del features[feature_key]
+    inputs = {x: features[x] for x in features.keys()}
+
+    for x in range(len(self._label_fields)):
+      inputs[self._label_fields[x]] = labels[x]
+    return inputs
+
+  def create_placeholders(self, *args, **kwargs):
+    """Create serving placeholders with rtp_fg."""
+    self.check_rtp()
+    self._mode = tf.estimator.ModeKeys.PREDICT
+    inputs_placeholder = tf.placeholder(tf.string, [None], name='features')
+    print('[OdpsRTPInputV2] building placeholders.')
+    print('[OdpsRTPInputV2] fg_config: {}'.format(self._fg_config))
+    features = rtp_fg.parse_genreated_fg(self._fg_config, inputs_placeholder)
+    print('[OdpsRTPInputV2] built features: {}'.format(features.keys()))
+    features = self._preprocess(features)
+    print('[OdpsRTPInputV2] processed features: {}'.format(features.keys()))
+    return {'features': inputs_placeholder}, features['feature']
+
+  def create_multi_placeholders(self, *args, **kwargs):
+    """Create serving multi-placeholders with rtp_fg."""
+    raise NotImplementedError(
+        'create_multi_placeholders is not supported for OdpsRTPInputV2')
+
+  def check_rtp(self):
+    if rtp_fg is None:
+      raise NotImplementedError(
+          'OdpsRTPInputV2 cannot run without rtp_fg, which is not installed')
+
+  def _pre_build(self, mode, params):
+    try:
+      # Prevent TF from replacing the shape tensor with a constant tensor;
+      # that would fix the batch size, and RTP would not be able to
+      # recognize the input shape.
+      tf.get_default_graph().set_shape_optimize(False)
+    except AttributeError as e:
+      logging.warning('failed to disable shape optimization: %s', e)