easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0

--- /dev/null
+++ b/easy_rec/python/inference/csv_predictor.py
@@ -0,0 +1,189 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import SINGLE_PLACEHOLDER_FEATURE_KEY
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils.check_utils import check_split
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class CSVPredictor(Predictor):
+
+  def __init__(self,
+               model_path,
+               data_config,
+               with_header=False,
+               ds_vector_recall=False,
+               fg_json_path=None,
+               profiling_file=None,
+               selected_cols=None,
+               output_sep=chr(1)):
+    super(CSVPredictor, self).__init__(model_path, profiling_file, fg_json_path)
+    self._output_sep = output_sep
+    self._ds_vector_recall = ds_vector_recall
+    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+    self._with_header = with_header
+
+    if 'rtp' in input_type:
+      self._is_rtp = True
+      self._input_sep = data_config.rtp_separator
+    else:
+      self._is_rtp = False
+      self._input_sep = data_config.separator
+
+    if selected_cols and not ds_vector_recall:
+      self._selected_cols = [int(x) for x in selected_cols.split(',')]
+    elif ds_vector_recall:
+      self._selected_cols = selected_cols.split(',')
+    else:
+      self._selected_cols = None
+
+  def _get_reserved_cols(self, reserved_cols):
+    if reserved_cols == 'ALL_COLUMNS':
+      if self._is_rtp:
+        if self._with_header:
+          reserved_cols = self._all_fields
+        else:
+          idx = 0
+          reserved_cols = []
+          for x in range(len(self._record_defaults) - 1):
+            if not self._selected_cols or x in self._selected_cols[:-1]:
+              reserved_cols.append(self._input_fields[idx])
+              idx += 1
+            else:
+              reserved_cols.append('no_used_%d' % x)
+          reserved_cols.append(SINGLE_PLACEHOLDER_FEATURE_KEY)
+      else:
+        reserved_cols = self._all_fields
+    else:
+      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+    return reserved_cols
+
+  def _parse_line(self, line):
+    check_list = [
+        tf.py_func(
+            check_split, [line, self._input_sep,
+                          len(self._record_defaults)],
+            Tout=tf.bool)
+    ]
+    with tf.control_dependencies(check_list):
+      fields = tf.decode_csv(
+          line,
+          field_delim=self._input_sep,
+          record_defaults=self._record_defaults,
+          name='decode_csv')
+    if self._is_rtp:
+      if self._with_header:
+        inputs = dict(zip(self._all_fields, fields))
+      else:
+        inputs = {}
+        idx = 0
+        for x in range(len(self._record_defaults) - 1):
+          if not self._selected_cols or x in self._selected_cols[:-1]:
+            inputs[self._input_fields[idx]] = fields[x]
+            idx += 1
+          else:
+            inputs['no_used_%d' % x] = fields[x]
+        inputs[SINGLE_PLACEHOLDER_FEATURE_KEY] = fields[-1]
+    else:
+      inputs = {self._all_fields[x]: fields[x] for x in range(len(fields))}
+    return inputs
+
+  def _get_num_cols(self, file_paths):
+    # try to figure out number of fields from one file
+    num_cols = -1
+    with gfile.GFile(file_paths[0], 'r') as fin:
+      num_lines = 0
+      for line_str in fin:
+        line_tok = line_str.strip().split(self._input_sep)
+        if num_cols != -1:
+          assert num_cols == len(line_tok), (
+              'num selected cols is %d, not equal to %d, current line is: %s, please check input_sep and data.'
+              % (num_cols, len(line_tok), line_str))
+        num_cols = len(line_tok)
+        num_lines += 1
+        if num_lines > 10:
+          break
+    logging.info('num selected cols = %d' % num_cols)
+    return num_cols
+
+  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+                   slice_id):
+    file_paths = []
+    for path in input_path.split(','):
+      for x in gfile.Glob(path):
+        if not x.endswith('_SUCCESS'):
+          file_paths.append(x)
+    assert len(file_paths) > 0, 'match no files with %s' % input_path
+
+    if self._with_header:
+      with gfile.GFile(file_paths[0], 'r') as fin:
+        for line_str in fin:
+          line_str = line_str.strip()
+          self._field_names = line_str.split(self._input_sep)
+          break
+      print('field_names: %s' % ','.join(self._field_names))
+      self._all_fields = self._field_names
+    elif self._ds_vector_recall:
+      self._all_fields = self._selected_cols
+    else:
+      self._all_fields = self._input_fields
+    if self._is_rtp:
+      num_cols = self._get_num_cols(file_paths)
+      self._record_defaults = ['' for _ in range(num_cols)]
+      if not self._selected_cols:
+        self._selected_cols = list(range(num_cols))
+      for col_idx in self._selected_cols[:-1]:
+        col_name = self._input_fields[col_idx]
+        default_val = self._get_defaults(col_name)
+        self._record_defaults[col_idx] = default_val
+    else:
+      self._record_defaults = [
+          self._get_defaults(col_name) for col_name in self._all_fields
+      ]
+
+    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+    parallel_num = min(num_parallel_calls, len(file_paths))
+    dataset = dataset.interleave(
+        lambda x: tf.data.TextLineDataset(x).skip(int(self._with_header)),
+        cycle_length=parallel_num,
+        num_parallel_calls=parallel_num)
+    dataset = dataset.shard(slice_num, slice_id)
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.prefetch(buffer_size=64)
+    return dataset
+
+  def _get_writer(self, output_path, slice_id):
+    if not gfile.Exists(output_path):
+      gfile.MakeDirs(output_path)
+    res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
+    table_writer = gfile.GFile(res_path, 'w')
+    table_writer.write(
+        self._output_sep.join(self._output_cols + self._reserved_cols) + '\n')
+    return table_writer
+
+  def _write_lines(self, table_writer, outputs):
+    outputs = '\n'.join(
+        [self._output_sep.join([str(i) for i in output]) for output in outputs])
+    table_writer.write(outputs + '\n')
+
+  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+    reserve_vals = [outputs[x] for x in output_cols] + \
+        [all_vals[k] for k in reserved_cols]
+    return reserve_vals
+
+  @property
+  def out_of_range_exception(self):
+    return (tf.errors.OutOfRangeError)
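
CSVPredictor only fills in the base Predictor's template hooks (_get_dataset, _parse_line, _get_writer); the driver loop lives in easy_rec/python/inference/predictor.py, which is not part of this hunk. Below is a minimal construction sketch, assuming a pipeline config protobuf is at hand; predict_table is a hypothetical name for the base class's batch-prediction entry point, used here only for illustration and not confirmed by this diff:

    # Sketch only: 'pipeline.config' and predict_table(...) are assumptions.
    from easy_rec.python.inference.csv_predictor import CSVPredictor
    from easy_rec.python.utils import config_util

    pipeline_config = config_util.get_configs_from_pipeline_file('pipeline.config')
    predictor = CSVPredictor(
        'export/final_model',          # directory containing the saved_model
        pipeline_config.data_config,   # supplies separator / rtp_separator and input fields
        with_header=True)              # read column names from the first CSV line
    # predictor.predict_table('data/*.csv', 'out/')  # hypothetical driver call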

--- /dev/null
+++ b/easy_rec/python/inference/hive_parquet_predictor.py
@@ -0,0 +1,200 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import tf_utils
+from easy_rec.python.utils.hive_utils import HiveUtils
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class HiveParquetPredictor(Predictor):
+
+  def __init__(self,
+               model_path,
+               data_config,
+               hive_config,
+               fg_json_path=None,
+               profiling_file=None,
+               output_sep=chr(1),
+               all_cols=None,
+               all_col_types=None):
+    super(HiveParquetPredictor, self).__init__(model_path, profiling_file,
+                                               fg_json_path)
+
+    self._data_config = data_config
+    self._hive_config = hive_config
+    self._output_sep = output_sep
+    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+    if 'rtp' in input_type:
+      self._is_rtp = True
+    else:
+      self._is_rtp = False
+    self._all_cols = [x.strip() for x in all_cols if x != '']
+    self._all_col_types = [x.strip() for x in all_col_types if x != '']
+    self._record_defaults = [
+        self._get_defaults(col_name, col_type)
+        for col_name, col_type in zip(self._all_cols, self._all_col_types)
+    ]
+
+  def _get_reserved_cols(self, reserved_cols):
+    if reserved_cols == 'ALL_COLUMNS':
+      reserved_cols = self._all_cols
+    else:
+      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+    return reserved_cols
+
+  def _parse_line(self, *fields):
+    fields = list(fields)
+    field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))}
+    return field_dict
+
+  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+                   slice_id):
+    self._hive_util = HiveUtils(
+        data_config=self._data_config, hive_config=self._hive_config)
+    hdfs_path = self._hive_util.get_table_location(input_path)
+    self._input_hdfs_path = gfile.Glob(os.path.join(hdfs_path, '*'))
+    assert len(self._input_hdfs_path) > 0, 'match no files with %s' % input_path
+
+    list_type = []
+    input_field_type_map = {
+        x.input_name: x.input_type for x in self._data_config.input_fields
+    }
+    type_2_tftype = {
+        'string': tf.string,
+        'double': tf.double,
+        'float': tf.float32,
+        'bigint': tf.int32,
+        'boolean': tf.bool
+    }
+    for col_name, col_type in zip(self._all_cols, self._all_col_types):
+      if col_name in input_field_type_map:
+        list_type.append(get_tf_type(input_field_type_map[col_name]))
+      else:
+        list_type.append(type_2_tftype[col_type.lower()])
+    list_type = tuple(list_type)
+    list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+    list_shapes = tuple(list_shapes)
+
+    def parquet_read():
+      for input_path in self._input_hdfs_path:
+        if input_path.endswith('SUCCESS'):
+          continue
+        df = pd.read_parquet(input_path, engine='pyarrow')
+
+        df.replace('', np.nan, inplace=True)
+        df.replace('NULL', np.nan, inplace=True)
+        total_records_num = len(df)
+
+        for k, v in zip(self._all_cols, self._record_defaults):
+          df[k].fillna(v, inplace=True)
+
+        for start_idx in range(0, total_records_num, batch_size):
+          end_idx = min(total_records_num, start_idx + batch_size)
+          batch_data = df[start_idx:end_idx]
+          inputs = []
+          for k in self._all_cols:
+            inputs.append(batch_data[k].to_numpy())
+          yield tuple(inputs)
+
+    dataset = tf.data.Dataset.from_generator(
+        parquet_read, output_types=list_type, output_shapes=list_shapes)
+    dataset = dataset.shard(slice_num, slice_id)
+    dataset = dataset.prefetch(buffer_size=64)
+    return dataset
+
+  def get_table_info(self, output_path):
+    partition_name, partition_val = None, None
+    if len(output_path.split('/')) == 2:
+      table_name, partition = output_path.split('/')
+      partition_name, partition_val = partition.split('=')
+    else:
+      table_name = output_path
+    return table_name, partition_name, partition_val
+
+  def _get_writer(self, output_path, slice_id):
+    table_name, partition_name, partition_val = self.get_table_info(output_path)
+    is_exist = self._hive_util.is_table_or_partition_exist(
+        table_name, partition_name, partition_val)
+    assert not is_exist, '%s already exists. Please drop it.' % output_path
+
+    output_path = output_path.replace('.', '/')
+    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
+        self._hive_config.host, output_path)
+    if not gfile.Exists(self._hdfs_path):
+      gfile.MakeDirs(self._hdfs_path)
+    res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
+    table_writer = gfile.GFile(res_path, 'w')
+    return table_writer
+
+  def _write_lines(self, table_writer, outputs):
+    outputs = '\n'.join(
+        [self._output_sep.join([str(i) for i in output]) for output in outputs])
+    table_writer.write(outputs + '\n')
+
+  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+    reserve_vals = [outputs[x] for x in output_cols] + \
+        [all_vals[k] for k in reserved_cols]
+    return reserve_vals
+
+  def load_to_table(self, output_path, slice_num, slice_id):
+    res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
+    success_writer = gfile.GFile(res_path, 'w')
+    success_writer.write('')
+    success_writer.close()
+
+    if slice_id != 0:
+      return
+
+    for id in range(slice_num):
+      res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
+      while not gfile.Exists(res_path):
+        time.sleep(10)
+
+    table_name, partition_name, partition_val = self.get_table_info(output_path)
+    schema = ''
+    for output_col_name in self._output_cols:
+      tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
+      col_type = tf_utils.get_col_type(tf_type)
+      schema += output_col_name + ' ' + col_type + ','
+
+    for output_col_name in self._reserved_cols:
+      assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
+      idx = self._all_cols.index(output_col_name)
+      output_col_types = self._all_col_types[idx]
+      schema += output_col_name + ' ' + output_col_types + ','
+    schema = schema.rstrip(',')
+
+    if partition_name and partition_val:
+      sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
+            (table_name, schema, partition_name)
+      self._hive_util.run_sql(sql)
+      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
+            (self._hdfs_path, table_name, partition_name, partition_val)
+      self._hive_util.run_sql(sql)
+    else:
+      sql = 'create table if not exists %s (%s)' % \
+            (table_name, schema)
+      self._hive_util.run_sql(sql)
+      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
+            (self._hdfs_path, table_name)
+      self._hive_util.run_sql(sql)
+
+  @property
+  def out_of_range_exception(self):
+    return (tf.errors.OutOfRangeError)
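
One design choice worth noting in this file: instead of streaming text lines, HiveParquetPredictor loads each parquet file into pandas and re-emits it as pre-batched numpy columns, so tf.data.Dataset.from_generator already yields batch-shaped tensors and no .batch() call appears afterwards. A self-contained sketch of that pattern, with hypothetical file and column names:

    # Minimal sketch of the parquet_read pattern above; names are hypothetical.
    import pandas as pd
    import tensorflow as tf

    def batched_columns(paths, cols, batch_size):
      # Yield one tuple of column arrays per batch, mirroring parquet_read.
      for path in paths:
        df = pd.read_parquet(path, engine='pyarrow')
        for start in range(0, len(df), batch_size):
          chunk = df[start:start + batch_size]
          yield tuple(chunk[c].to_numpy() for c in cols)

    cols = ['user_id', 'item_id']  # hypothetical columns
    ds = tf.data.Dataset.from_generator(
        lambda: batched_columns(['part-0.parquet'], cols, 1024),
        output_types=(tf.string, tf.string),
        output_shapes=(tf.TensorShape([None]), tf.TensorShape([None])))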

--- /dev/null
+++ b/easy_rec/python/inference/hive_predictor.py
@@ -0,0 +1,166 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.inference.predictor import Predictor
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import tf_utils
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class HivePredictor(Predictor):
+
+  def __init__(self,
+               model_path,
+               data_config,
+               hive_config,
+               fg_json_path=None,
+               profiling_file=None,
+               output_sep=chr(1),
+               all_cols=None,
+               all_col_types=None):
+    super(HivePredictor, self).__init__(model_path, profiling_file,
+                                        fg_json_path)
+
+    self._data_config = data_config
+    self._hive_config = hive_config
+    self._output_sep = output_sep
+    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
+    if 'rtp' in input_type:
+      self._is_rtp = True
+    else:
+      self._is_rtp = False
+    self._all_cols = [x.strip() for x in all_cols if x != '']
+    self._all_col_types = [x.strip() for x in all_col_types if x != '']
+    self._record_defaults = [
+        self._get_defaults(col_name, col_type)
+        for col_name, col_type in zip(self._all_cols, self._all_col_types)
+    ]
+
+  def _get_reserved_cols(self, reserved_cols):
+    if reserved_cols == 'ALL_COLUMNS':
+      reserved_cols = self._all_cols
+    else:
+      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+    return reserved_cols
+
+  def _parse_line(self, line):
+    field_delim = self._data_config.rtp_separator if self._is_rtp else self._data_config.separator
+    fields = tf.decode_csv(
+        line,
+        field_delim=field_delim,
+        record_defaults=self._record_defaults,
+        name='decode_csv')
+    inputs = {self._all_cols[x]: fields[x] for x in range(len(fields))}
+    return inputs
+
+  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+                   slice_id):
+    self._hive_util = HiveUtils(
+        data_config=self._data_config, hive_config=self._hive_config)
+    self._input_hdfs_path = self._hive_util.get_table_location(input_path)
+    file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
+    assert len(file_paths) > 0, 'match no files with %s' % input_path
+
+    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+    parallel_num = min(num_parallel_calls, len(file_paths))
+    dataset = dataset.interleave(
+        tf.data.TextLineDataset,
+        cycle_length=parallel_num,
+        num_parallel_calls=parallel_num)
+    dataset = dataset.shard(slice_num, slice_id)
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.prefetch(buffer_size=64)
+    return dataset
+
+  def get_table_info(self, output_path):
+    partition_name, partition_val = None, None
+    if len(output_path.split('/')) == 2:
+      table_name, partition = output_path.split('/')
+      partition_name, partition_val = partition.split('=')
+    else:
+      table_name = output_path
+    return table_name, partition_name, partition_val
+
+  def _get_writer(self, output_path, slice_id):
+    table_name, partition_name, partition_val = self.get_table_info(output_path)
+    is_exist = self._hive_util.is_table_or_partition_exist(
+        table_name, partition_name, partition_val)
+    assert not is_exist, '%s already exists. Please drop it.' % output_path
+
+    output_path = output_path.replace('.', '/')
+    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
+        self._hive_config.host, output_path)
+    if not gfile.Exists(self._hdfs_path):
+      gfile.MakeDirs(self._hdfs_path)
+    res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
+    table_writer = gfile.GFile(res_path, 'w')
+    return table_writer
+
+  def _write_lines(self, table_writer, outputs):
+    outputs = '\n'.join(
+        [self._output_sep.join([str(i) for i in output]) for output in outputs])
+    table_writer.write(outputs + '\n')
+
+  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+    reserve_vals = [outputs[x] for x in output_cols] + \
+        [all_vals[k] for k in reserved_cols]
+    return reserve_vals
+
+  def load_to_table(self, output_path, slice_num, slice_id):
+    res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
+    success_writer = gfile.GFile(res_path, 'w')
+    success_writer.write('')
+    success_writer.close()
+
+    if slice_id != 0:
+      return
+
+    for id in range(slice_num):
+      res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
+      while not gfile.Exists(res_path):
+        time.sleep(10)
+
+    table_name, partition_name, partition_val = self.get_table_info(output_path)
+    schema = ''
+    for output_col_name in self._output_cols:
+      tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
+      col_type = tf_utils.get_col_type(tf_type)
+      schema += output_col_name + ' ' + col_type + ','
+
+    for output_col_name in self._reserved_cols:
+      assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
+      idx = self._all_cols.index(output_col_name)
+      output_col_types = self._all_col_types[idx]
+      schema += output_col_name + ' ' + output_col_types + ','
+    schema = schema.rstrip(',')
+
+    if partition_name and partition_val:
+      sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
+            (table_name, schema, partition_name)
+      self._hive_util.run_sql(sql)
+      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
+            (self._hdfs_path, table_name, partition_name, partition_val)
+      self._hive_util.run_sql(sql)
+    else:
+      sql = 'create table if not exists %s (%s)' % \
+            (table_name, schema)
+      self._hive_util.run_sql(sql)
+      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
+            (self._hdfs_path, table_name)
+      self._hive_util.run_sql(sql)
+
+  @property
+  def out_of_range_exception(self):
+    return (tf.errors.OutOfRangeError)
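
load_to_table (identical in both Hive predictors) doubles as a cheap file-based barrier: every worker writes an empty SUCCESS-<slice_id> marker, non-zero workers return immediately, and worker 0 polls until all markers exist before issuing the CREATE TABLE / LOAD DATA statements. The same barrier in isolation, as a minimal sketch with a hypothetical temp directory:

    # Minimal sketch of the SUCCESS-marker barrier above; tmp_dir is hypothetical.
    import os
    import time
    from tensorflow.python.platform import gfile

    def barrier(tmp_dir, slice_num, slice_id, poll_secs=10):
      # Each worker publishes its marker, then worker 0 waits for all of them.
      with gfile.GFile(os.path.join(tmp_dir, 'SUCCESS-%d' % slice_id), 'w') as f:
        f.write('')
      if slice_id != 0:
        return False  # non-leader workers return immediately
      for i in range(slice_num):
        while not gfile.Exists(os.path.join(tmp_dir, 'SUCCESS-%d' % i)):
          time.sleep(poll_secs)
      return True  # leader may now run the Hive LOAD DATA statements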

--- /dev/null
+++ b/easy_rec/python/inference/odps_predictor.py
@@ -0,0 +1,70 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from easy_rec.python.inference.predictor import Predictor
+
+
+class ODPSPredictor(Predictor):
+
+  def __init__(self,
+               model_path,
+               fg_json_path=None,
+               profiling_file=None,
+               all_cols='',
+               all_col_types=''):
+    super(ODPSPredictor, self).__init__(model_path, profiling_file,
+                                        fg_json_path)
+    self._all_cols = [x.strip() for x in all_cols.split(',') if x != '']
+    self._all_col_types = [
+        x.strip() for x in all_col_types.split(',') if x != ''
+    ]
+    self._record_defaults = [
+        self._get_defaults(col_name, col_type)
+        for col_name, col_type in zip(self._all_cols, self._all_col_types)
+    ]
+
+  def _get_reserved_cols(self, reserved_cols):
+    reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
+    return reserved_cols
+
+  def _parse_line(self, *fields):
+    fields = list(fields)
+    field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))}
+    return field_dict
+
+  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
+                   slice_id):
+    input_list = input_path.split(',')
+    dataset = tf.data.TableRecordDataset(
+        input_list,
+        record_defaults=self._record_defaults,
+        slice_id=slice_id,
+        slice_count=slice_num,
+        selected_cols=','.join(self._all_cols))
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.prefetch(buffer_size=64)
+    return dataset
+
+  def _get_writer(self, output_path, slice_id):
+    import common_io
+    table_writer = common_io.table.TableWriter(output_path, slice_id=slice_id)
+    return table_writer
+
+  def _write_lines(self, table_writer, outputs):
+    assert len(outputs) > 0
+    indices = list(range(0, len(outputs[0])))
+    table_writer.write(outputs, indices, allow_type_cast=False)
+
+  @property
+  def out_of_range_exception(self):
+    return (tf.python_io.OutOfRangeException, tf.errors.OutOfRangeError)
+
+  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
+    reserve_vals = [all_vals[k] for k in reserved_cols] + \
+        [outputs[x] for x in output_cols]
+    return reserve_vals
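
Unlike the CSV and Hive variants, ODPSPredictor leans on the PAI-TF runtime: tf.data.TableRecordDataset and common_io ship with PAI TensorFlow on MaxCompute/ODPS and are not part of stock TensorFlow, so sharding is done by the dataset itself (slice_id/slice_count) rather than by dataset.shard(). A sketch of the same read call with a hypothetical table path, runnable only inside PAI-TF:

    # PAI-TF only: TableRecordDataset does not exist in stock TensorFlow.
    import tensorflow as tf

    dataset = tf.data.TableRecordDataset(
        ['odps://my_project/tables/rec_input'],  # hypothetical ODPS table
        record_defaults=['', '', 0],             # one default per selected column
        slice_id=0,                              # this worker's shard index...
        slice_count=4,                           # ...out of 4 parallel workers
        selected_cols='user_id,item_id,label')
    dataset = dataset.batch(1024).prefetch(64)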