easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/input/kafka_dataset.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset."""
+
+import logging
+import traceback
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+try:
+  from easy_rec.python.ops import gen_kafka_ops
+except ImportError:
+  logging.warning('failed to import gen_kafka_ops: %s' % traceback.format_exc())
+
+
+class KafkaDataset(dataset_ops.Dataset):
+  """A Kafka Dataset that consumes the message."""
+
+  def __init__(self,
+               topics,
+               servers='localhost',
+               group='',
+               eof=False,
+               timeout=1000,
+               config_global=None,
+               config_topic=None,
+               message_key=False,
+               message_offset=False):
+    """Create a KafkaReader.
+
+    Args:
+      topics: A `tf.string` tensor containing one or more subscriptions,
+        in the format of [topic:partition:offset:length],
+        by default length is -1 for unlimited.
+      servers: A list of bootstrap servers.
+      group: The consumer group id.
+      eof: If True, the kafka reader will stop on EOF.
+      timeout: The timeout value for the Kafka Consumer to wait
+        (in milliseconds).
+      config_global: A `tf.string` tensor containing global configuration
+        properties in [Key=Value] format,
+        eg. ["enable.auto.commit=false",
+        "heartbeat.interval.ms=2000"],
+        please refer to 'Global configuration properties'
+        in librdkafka doc.
+      config_topic: A `tf.string` tensor containing topic configuration
+        properties in [Key=Value] format,
+        eg. ["auto.offset.reset=earliest"],
+        please refer to 'Topic configuration properties'
+        in librdkafka doc.
+      message_key: If True, the kafka reader will output both message value and key.
+      message_offset: If True, the kafka reader will output both message value and offset.
+    """
+    self._topics = ops.convert_to_tensor(
+        topics, dtype=dtypes.string, name='topics')
+    self._servers = ops.convert_to_tensor(
+        servers, dtype=dtypes.string, name='servers')
+    self._group = ops.convert_to_tensor(
+        group, dtype=dtypes.string, name='group')
+    self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name='eof')
+    self._timeout = ops.convert_to_tensor(
+        timeout, dtype=dtypes.int64, name='timeout')
+    config_global = config_global if config_global else []
+    self._config_global = ops.convert_to_tensor(
+        config_global, dtype=dtypes.string, name='config_global')
+    config_topic = config_topic if config_topic else []
+    self._config_topic = ops.convert_to_tensor(
+        config_topic, dtype=dtypes.string, name='config_topic')
+    self._message_key = message_key
+    self._message_offset = message_offset
+    super(KafkaDataset, self).__init__()
+
+  def _inputs(self):
+    return []
+
+  def _as_variant_tensor(self):
+    return gen_kafka_ops.io_kafka_dataset_v2(
+        self._topics,
+        self._servers,
+        self._group,
+        self._eof,
+        self._timeout,
+        self._config_global,
+        self._config_topic,
+        self._message_key,
+        self._message_offset,
+    )
+
+  @property
+  def output_classes(self):
+    if self._message_key ^ self._message_offset:
+      return (ops.Tensor, ops.Tensor)
+    elif self._message_key and self._message_offset:
+      return (ops.Tensor, ops.Tensor, ops.Tensor)
+    return (ops.Tensor)
+
+  @property
+  def output_shapes(self):
+    if self._message_key ^ self._message_offset:
+      return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([])))
+    elif self._message_key and self._message_offset:
+      return ((tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
+               tensor_shape.TensorShape([])))
+    return ((tensor_shape.TensorShape([])))
+
+  @property
+  def output_types(self):
+    if self._message_key ^ self._message_offset:
+      return ((dtypes.string, dtypes.string))
+    elif self._message_key and self._message_offset:
+      return ((dtypes.string, dtypes.string, dtypes.string))
+    return ((dtypes.string))
+
+
+def write_kafka_v2(message, topic, servers='localhost', name=None):
+  """Write kafka.
+
+  Args:
+    message: A `Tensor` of type `string`. 0-D.
+    topic: A `tf.string` tensor containing one subscription,
+      in the format of topic:partition.
+    servers: A list of bootstrap servers.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of type `string`. 0-D.
+  """
+  return gen_kafka_ops.io_write_kafka_v2(
+      message=message, topic=topic, servers=servers, name=name)
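For orientation, a minimal usage sketch of the KafkaDataset added above (not part of the diff): it assumes a TF 1.x-style runtime, a reachable broker at localhost:9092, that the bundled kafka custom-op library imports cleanly, and a hypothetical topic name `clicks`.

import tensorflow as tf

from easy_rec.python.input.kafka_dataset import KafkaDataset

# Partition 0 of topic `clicks`, from offset 0, unlimited length (-1);
# with message_key and message_offset both set, each element is a
# (value, key, offset) triple of tf.string scalars.
dataset = KafkaDataset(
    topics=['clicks:0:0:-1'],
    servers='localhost:9092',
    group='demo_group',
    eof=True,  # stop at end of partition instead of blocking
    message_key=True,
    message_offset=True)

batch = tf.compat.v1.data.make_one_shot_iterator(dataset.batch(4)).get_next()
with tf.compat.v1.Session() as sess:
  values, keys, offsets = sess.run(batch)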
easy_rec/python/input/kafka_input.py
@@ -0,0 +1,235 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import traceback
+
+import six
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.input.kafka_dataset import KafkaDataset
+from easy_rec.python.utils.config_util import parse_time
+
+if tf.__version__.startswith('1.'):
+  from tensorflow.python.platform import gfile
+else:
+  import tensorflow.io.gfile as gfile
+
+try:
+  from kafka import KafkaConsumer, TopicPartition
+except ImportError:
+  logging.warning(
+      'kafka-python is not installed[%s]. You can install it by: pip install kafka-python'
+      % traceback.format_exc())
+
+if tf.__version__ >= '2.0':
+  ignore_errors = tf.data.experimental.ignore_errors()
+  tf = tf.compat.v1
+else:
+  ignore_errors = tf.contrib.data.ignore_errors()
+
+
+class KafkaInput(Input):
+
+  DATA_OFFSET = 'DATA_OFFSET'
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               kafka_config,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(KafkaInput,
+          self).__init__(data_config, feature_config, '', task_index, task_num,
+                         check_mode, pipeline_config)
+    self._kafka = kafka_config
+    self._offset_dict = {}
+    if self._kafka is not None:
+      consumer = KafkaConsumer(
+          group_id='kafka_dataset_consumer',
+          bootstrap_servers=[self._kafka.server],
+          api_version_auto_timeout_ms=60000)  # in milliseconds
+      partitions = consumer.partitions_for_topic(self._kafka.topic)
+      self._num_partition = len(partitions)
+      logging.info('all partitions[%d]: %s' % (self._num_partition, partitions))
+
+      # determine kafka offsets for each partition
+      offset_type = self._kafka.WhichOneof('offset')
+      if offset_type is not None:
+        if offset_type == 'offset_time':
+          ts = parse_time(self._kafka.offset_time)
+          input_map = {
+              TopicPartition(partition=part_id, topic=self._kafka.topic):
+              ts * 1000 for part_id in partitions
+          }
+          part_offsets = consumer.offsets_for_times(input_map)
+          # part_offsets is a dictionary:
+          # {
+          #    TopicPartition(topic=u'kafka_data_20220408', partition=0):
+          #        OffsetAndTimestamp(offset=2, timestamp=1650611437895)
+          # }
+          for part in part_offsets:
+            self._offset_dict[part.partition] = part_offsets[part].offset
+            logging.info(
+                'Find offset by time, topic[%s], partition[%d], timestamp[%ss], offset[%d], offset_timestamp[%dms]'
+                % (self._kafka.topic, part.partition, ts,
+                   part_offsets[part].offset, part_offsets[part].timestamp))
+        elif offset_type == 'offset_info':
+          offset_dict = json.loads(self._kafka.offset_info)
+          for part in offset_dict:
+            part_id = int(part)
+            self._offset_dict[part_id] = offset_dict[part]
+        else:
+          assert False, 'invalid offset_type: %s' % offset_type
+    self._task_offset_dict = {}
+
+  def _preprocess(self, field_dict):
+    output_dict = super(KafkaInput, self)._preprocess(field_dict)
+
+    # append offset fields
+    if Input.DATA_OFFSET in field_dict:
+      output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+    # for _get_features to include DATA_OFFSET
+    if Input.DATA_OFFSET not in self._appended_fields:
+      self._appended_fields.append(Input.DATA_OFFSET)
+
+    return output_dict
+
+  def _parse_csv(self, line, message_key, message_offset):
+    record_defaults = [
+        self.get_type_defaults(t, v)
+        for t, v in zip(self._input_field_types, self._input_field_defaults)
+    ]
+
+    fields = tf.decode_csv(
+        line,
+        use_quote_delim=False,
+        field_delim=self._data_config.separator,
+        record_defaults=record_defaults,
+        name='decode_csv')
+
+    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+
+    for x in self._label_fids:
+      inputs[self._input_fields[x]] = fields[x]
+
+    # record current offset
+    def _parse_offset(message_offset):
+      for kv in message_offset:
+        if six.PY3:
+          kv = kv.decode('utf-8')
+        k, v = kv.split(':')
+        k = int(k)
+        v = int(v)
+        if k not in self._task_offset_dict or v > self._task_offset_dict[k]:
+          self._task_offset_dict[k] = v
+      return json.dumps(self._task_offset_dict)
+
+    inputs[Input.DATA_OFFSET] = tf.py_func(_parse_offset, [message_offset],
+                                           tf.string)
+    return inputs
+
+  def restore(self, checkpoint_path):
+    if checkpoint_path is None:
+      return
+
+    offset_path = checkpoint_path + '.offset'
+    if not gfile.Exists(offset_path):
+      return
+
+    logging.info('will restore kafka offset from %s' % offset_path)
+    with gfile.GFile(offset_path, 'r') as fin:
+      offset_dict = json.load(fin)
+      self._offset_dict = {}
+      for k in offset_dict:
+        v = offset_dict[k]
+        k = int(k)
+        if k not in self._offset_dict or v > self._offset_dict[k]:
+          self._offset_dict[k] = v
+
+  def _get_topics(self):
+    task_num = self._task_num
+    task_index = self._task_index
+    if self._data_config.chief_redundant and self._mode == tf.estimator.ModeKeys.TRAIN:
+      task_index = max(task_index - 1, 0)
+      task_num = max(task_num - 1, 1)
+
+    topics = []
+    self._task_offset_dict = {}
+    for part_id in range(self._num_partition):
+      if (part_id % task_num) == task_index:
+        offset = self._offset_dict.get(part_id, 0)
+        topics.append('%s:%d:%d' % (self._kafka.topic, part_id, offset))
+        self._task_offset_dict[part_id] = offset
+    logging.info('assigned topic partitions: %s' % (','.join(topics)))
+    assert len(
+        topics) > 0, 'no partitions are assigned for this task(%d/%d)' % (
+            self._task_index, self._task_num)
+    return topics
+
+  def _build(self, mode, params):
+    num_parallel_calls = self._data_config.num_parallel_calls
+    task_topics = self._get_topics()
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      assert self._kafka is not None, 'kafka_train_input is not set.'
+      train_kafka = self._kafka
+      logging.info(
+          'train kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
+          % (train_kafka.server, train_kafka.topic, self._task_num,
+             self._task_index, task_topics))
+
+      dataset = KafkaDataset(
+          task_topics,
+          servers=train_kafka.server,
+          group=train_kafka.group,
+          eof=False,
+          config_global=list(self._kafka.config_global),
+          config_topic=list(self._kafka.config_topic),
+          message_key=True,
+          message_offset=True)
+
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+    else:
+      eval_kafka = self._kafka
+      assert self._kafka is not None, 'kafka_eval_input is not set.'
+
+      logging.info(
+          'eval kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
+          % (eval_kafka.server, eval_kafka.topic, self._task_num,
+             self._task_index, task_topics))
+
+      dataset = KafkaDataset(
+          task_topics,
+          servers=self._kafka.server,
+          group=eval_kafka.group,
+          eof=False,
+          config_global=list(self._kafka.config_global),
+          config_topic=list(self._kafka.config_topic),
+          message_key=True,
+          message_offset=True)
+
+    dataset = dataset.batch(self._data_config.batch_size)
+    dataset = dataset.map(
+        self._parse_csv, num_parallel_calls=num_parallel_calls)
+    if self._data_config.ignore_error:
+      dataset = dataset.apply(ignore_errors)
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+    dataset = dataset.map(
+        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
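A note on _get_topics above: partitions are assigned round-robin, partition part_id going to the task where part_id % task_num == task_index, and each subscription string carries the restored offset in topic:partition:offset form. A standalone sketch (not part of the diff) with made-up partition counts and offsets:

# Hypothetical illustration of the assignment rule in _get_topics.
num_partition, task_num = 8, 3
offset_dict = {0: 1200, 3: 880}  # offsets restored from a checkpoint

for task_index in range(task_num):
  topics = [
      'my_topic:%d:%d' % (part_id, offset_dict.get(part_id, 0))
      for part_id in range(num_partition)
      if part_id % task_num == task_index
  ]
  print('task %d -> %s' % (task_index, topics))
# task 0 -> ['my_topic:0:1200', 'my_topic:3:880', 'my_topic:6:0']
# task 1 -> ['my_topic:1:0', 'my_topic:4:0', 'my_topic:7:0']
# task 2 -> ['my_topic:2:0', 'my_topic:5:0']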
easy_rec/python/input/load_parquet.py
@@ -0,0 +1,317 @@
+import logging
+import multiprocessing
+import queue
+
+import numpy as np
+import pandas as pd
+
+
+def start_data_proc(task_index,
+                    task_num,
+                    num_proc,
+                    file_que,
+                    data_que,
+                    proc_start_que,
+                    proc_stop_que,
+                    batch_size,
+                    label_fields,
+                    sparse_fea_names,
+                    dense_fea_names,
+                    dense_fea_cfgs,
+                    reserve_fields,
+                    drop_remainder,
+                    need_pack=True):
+  mp_ctxt = multiprocessing.get_context('spawn')
+  proc_arr = []
+  for proc_id in range(num_proc):
+    proc = mp_ctxt.Process(
+        target=load_data_proc,
+        args=(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
+              batch_size, label_fields, sparse_fea_names, dense_fea_names,
+              dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
+              task_num, need_pack),
+        name='task_%d_data_proc_%d' % (task_index, proc_id))
+    proc.daemon = True
+    proc.start()
+    proc_arr.append(proc)
+  return proc_arr
+
+
+def _should_stop(proc_stop_que):
+  try:
+    proc_stop_que.get(block=False)
+    logging.info('data_proc stop signal received')
+    proc_stop_que.close()
+    return True
+  except queue.Empty:
+    return False
+  except ValueError:
+    return True
+  except AssertionError:
+    return True
+
+
+def _add_to_que(data_dict, data_que, proc_stop_que):
+  while True:
+    try:
+      data_que.put(data_dict, timeout=5)
+      return True
+    except queue.Full:
+      logging.warning('data_que is full')
+      if _should_stop(proc_stop_que):
+        return False
+    except ValueError:
+      logging.warning('data_que is closed')
+      return False
+    except AssertionError:
+      logging.warning('data_que is closed')
+      return False
+
+
+def _get_one_file(file_que, proc_stop_que):
+  while True:
+    try:
+      input_file = file_que.get(timeout=1)
+      return input_file
+    except queue.Empty:
+      pass
+  return None
+
+
+def _pack_sparse_feas(data_dict, sparse_fea_names):
+  fea_val_arr = []
+  fea_len_arr = []
+  for fea_name in sparse_fea_names:
+    fea_len_arr.append(data_dict[fea_name][0])
+    fea_val_arr.append(data_dict[fea_name][1])
+    del data_dict[fea_name]
+  fea_lens = np.concatenate(fea_len_arr, axis=0)
+  fea_vals = np.concatenate(fea_val_arr, axis=0)
+  data_dict['sparse_fea'] = (fea_lens, fea_vals)
+
+
+def _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
+  fea_val_arr = []
+  for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
+    fea_val_arr.append(data_dict[fea_name].reshape([-1, fea_cfg.raw_input_dim]))
+    del data_dict[fea_name]
+  fea_vals = np.concatenate(fea_val_arr, axis=1)
+  data_dict['dense_fea'] = fea_vals
+
+
+def _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
+  for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
+    data_dict[fea_name] = data_dict[fea_name].reshape(
+        [-1, fea_cfg.raw_input_dim])
+
+
+def _load_dense(input_data, field_names, sid, eid, dense_dict):
+  for k in field_names:
+    if isinstance(input_data[k][0], np.ndarray):
+      np_dtype = type(input_data[k][sid][0])
+      dense_dict[k] = np.array([x[0] for x in input_data[k][sid:eid]],
+                               dtype=np_dtype)
+    else:
+      dense_dict[k] = input_data[k][sid:eid].to_numpy()
+
+
+def _load_and_pad_dense(input_data, field_names, sid, dense_dict,
+                        part_dense_dict, part_dense_dict_n, batch_size):
+  for k in field_names:
+    if isinstance(input_data[k][0], np.ndarray):
+      np_dtype = type(input_data[k][sid][0])
+      tmp_lbls = np.array([x[0] for x in input_data[k][sid:]], dtype=np_dtype)
+    else:
+      tmp_lbls = input_data[k][sid:].to_numpy()
+    if part_dense_dict is not None and k in part_dense_dict:
+      tmp_lbls = np.concatenate([part_dense_dict[k], tmp_lbls], axis=0)
+      if len(tmp_lbls) > batch_size:
+        dense_dict[k] = tmp_lbls[:batch_size]
+        part_dense_dict_n[k] = tmp_lbls[batch_size:]
+      elif len(tmp_lbls) == batch_size:
+        dense_dict[k] = tmp_lbls
+      else:
+        part_dense_dict_n[k] = tmp_lbls
+    else:
+      part_dense_dict_n[k] = tmp_lbls
+
+
+def load_data_proc(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
+                   batch_size, label_fields, sparse_fea_names, dense_fea_names,
+                   dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
+                   task_num, need_pack):
+  logging.info('data proc %d start, proc_start_que=%s' %
+               (proc_id, proc_start_que.qsize()))
+  proc_start_que.get()
+  effective_fields = sparse_fea_names + dense_fea_names
+  all_fields = effective_fields
+  if label_fields is not None:
+    all_fields = all_fields + label_fields
+  if reserve_fields is not None:
+    for tmp in reserve_fields:
+      if tmp not in all_fields:
+        all_fields.append(tmp)
+  logging.info('data proc %d start, file_que.qsize=%d' %
+               (proc_id, file_que.qsize()))
+  num_files = 0
+  part_data_dict = {}
+
+  is_good = True
+  total_batch_cnt = 0
+  total_sample_cnt = 0
+  while is_good:
+    if _should_stop(proc_stop_que):
+      is_good = False
+      break
+    input_file = _get_one_file(file_que, proc_stop_que)
+    if input_file is None:
+      break
+    num_files += 1
+    input_data = pd.read_parquet(input_file, columns=all_fields)
+    data_len = len(input_data[all_fields[0]])
+    total_sample_cnt += data_len
+    batch_num = int(data_len / batch_size)
+    res_num = data_len % batch_size
+
+    sid = 0
+    for batch_id in range(batch_num):
+      eid = sid + batch_size
+      data_dict = {}
+
+      if label_fields is not None and len(label_fields) > 0:
+        _load_dense(input_data, label_fields, sid, eid, data_dict)
+
+      if reserve_fields is not None and len(reserve_fields) > 0:
+        data_dict['reserve'] = {}
+        _load_dense(input_data, reserve_fields, sid, eid, data_dict['reserve'])
+
+      if len(sparse_fea_names) > 0:
+        for k in sparse_fea_names:
+          val = input_data[k][sid:eid]
+          if isinstance(input_data[k][sid], np.ndarray):
+            all_lens = np.array([len(x) for x in val], dtype=np.int32)
+            all_vals = np.concatenate(val.to_numpy())
+          else:
+            all_lens = np.ones([len(val)], dtype=np.int32)
+            all_vals = val.to_numpy()
+          assert np.sum(all_lens) == len(
+              all_vals), 'len(all_vals)=%d np.sum(all_lens)=%d' % (
+                  len(all_vals), np.sum(all_lens))
+          data_dict[k] = (all_lens, all_vals)
+
+      if len(dense_fea_names) > 0:
+        _load_dense(input_data, dense_fea_names, sid, eid, data_dict)
+
+      if need_pack:
+        if len(sparse_fea_names) > 0:
+          _pack_sparse_feas(data_dict, sparse_fea_names)
+        if len(dense_fea_names) > 0:
+          _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+      else:
+        if len(dense_fea_names) > 0:
+          _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+      # logging.info('task_index=%d sid=%d eid=%d total_len=%d' % (task_index, sid, eid,
+      #     len(data_dict['sparse_fea'][1])))
+      if not _add_to_que(data_dict, data_que, proc_stop_que):
+        logging.info('add to que failed')
+        is_good = False
+        break
+      total_batch_cnt += 1
+      sid += batch_size
+
+    if res_num > 0 and is_good:
+      data_dict = {}
+      part_data_dict_n = {}
+
+      if label_fields is not None and len(label_fields) > 0:
+        _load_and_pad_dense(input_data, label_fields, sid, data_dict,
+                            part_data_dict, part_data_dict_n, batch_size)
+
+      if reserve_fields is not None and len(reserve_fields) > 0:
+        data_dict['reserve'] = {}
+        part_data_dict_n['reserve'] = {}
+        _load_and_pad_dense(input_data, reserve_fields, sid,
+                            data_dict['reserve'],
+                            part_data_dict.get('reserve'),
+                            part_data_dict_n['reserve'], batch_size)
+
+      if len(dense_fea_names) > 0:
+        _load_and_pad_dense(input_data, dense_fea_names, sid, data_dict,
+                            part_data_dict, part_data_dict_n, batch_size)
+
+      if len(sparse_fea_names) > 0:
+        for k in sparse_fea_names:
+          val = input_data[k][sid:]
+
+          if isinstance(input_data[k][sid], np.ndarray):
+            all_lens = np.array([len(x) for x in val], dtype=np.int32)
+            all_vals = np.concatenate(val.to_numpy())
+          else:
+            all_lens = np.ones([len(val)], dtype=np.int32)
+            all_vals = val.to_numpy()
+
+          if part_data_dict is not None and k in part_data_dict:
+            tmp_lens = np.concatenate([part_data_dict[k][0], all_lens], axis=0)
+            tmp_vals = np.concatenate([part_data_dict[k][1], all_vals], axis=0)
+            if len(tmp_lens) > batch_size:
+              tmp_res_lens = tmp_lens[batch_size:]
+              tmp_lens = tmp_lens[:batch_size]
+              tmp_num_elems = np.sum(tmp_lens)
+              tmp_res_vals = tmp_vals[tmp_num_elems:]
+              tmp_vals = tmp_vals[:tmp_num_elems]
+              part_data_dict_n[k] = (tmp_res_lens, tmp_res_vals)
+              data_dict[k] = (tmp_lens, tmp_vals)
+            elif len(tmp_lens) == batch_size:
+              data_dict[k] = (tmp_lens, tmp_vals)
+            else:
+              part_data_dict_n[k] = (tmp_lens, tmp_vals)
+          else:
+            part_data_dict_n[k] = (all_lens, all_vals)
+
+      if effective_fields[0] in data_dict:
+        if need_pack:
+          if len(sparse_fea_names) > 0:
+            _pack_sparse_feas(data_dict, sparse_fea_names)
+          if len(dense_fea_names) > 0:
+            _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+        else:
+          if len(dense_fea_names) > 0:
+            _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
+        if not _add_to_que(data_dict, data_que, proc_stop_que):
+          logging.info('add to que failed')
+          is_good = False
+          break
+        total_batch_cnt += 1
+      part_data_dict = part_data_dict_n
+  if len(part_data_dict) > 0 and is_good:
+    batch_len = len(part_data_dict[effective_fields[0]][0])
+    if not drop_remainder:
+      if need_pack:
+        if len(sparse_fea_names) > 0:
+          _pack_sparse_feas(part_data_dict, sparse_fea_names)
+        if len(dense_fea_names) > 0:
+          _pack_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
+      else:
+        if len(dense_fea_names) > 0:
+          _reshape_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
+      logging.info('remainder batch: %s sample_num=%d' %
+                   (','.join(part_data_dict.keys()), batch_len))
+      _add_to_que(part_data_dict, data_que, proc_stop_que)
+      total_batch_cnt += 1
+    else:
+      logging.warning('drop remain %d samples as drop_remainder is set' %
+                      batch_len)
+  if is_good:
+    is_good = _add_to_que(None, data_que, proc_stop_que)
+  logging.info(
+      'data_proc_id[%d]: is_good = %s, total_batch_cnt=%d, total_sample_cnt=%d'
+      % (proc_id, is_good, total_batch_cnt, total_sample_cnt))
+  data_que.close(wait_send_finish=is_good)
+
+  while not is_good:
+    try:
+      if file_que.get(timeout=1) is None:
+        break
+    except queue.Empty:
+      pass
+  file_que.close()
+  logging.info('data proc %d done, file_num=%d' % (proc_id, num_files))
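To make the packed layout concrete: _pack_sparse_feas collapses every sparse feature into a single 'sparse_fea' pair of arrays, per-row value counts followed by the flattened values, concatenated feature by feature. A small NumPy sketch (not part of the diff; feature names and numbers are made up):

import numpy as np

# Two sparse features over a batch of 3 rows, each stored as
# (row_lengths, flattened_values), the encoding built in load_data_proc.
data_dict = {
    'tags': (np.array([2, 1, 3], dtype=np.int32),
             np.array([5, 9, 4, 7, 7, 8])),
    'cate': (np.array([1, 1, 1], dtype=np.int32),
             np.array([0, 2, 1])),
}

# _pack_sparse_feas concatenates lengths and values across features:
fea_lens = np.concatenate([data_dict['tags'][0], data_dict['cate'][0]])
fea_vals = np.concatenate([data_dict['tags'][1], data_dict['cate'][1]])
assert fea_lens.sum() == len(fea_vals)  # same invariant asserted per batch
# packed result: data_dict['sparse_fea'] = (fea_lens, fea_vals)
# fea_lens -> [2 1 3 1 1 1], fea_vals -> [5 9 4 7 7 8 0 2 1]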
|