easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
# from kafka import KafkaConsumer
|
|
8
|
+
from kafka import KafkaAdminClient
|
|
9
|
+
from kafka import KafkaProducer
|
|
10
|
+
from kafka.admin import NewTopic
|
|
11
|
+
|
|
12
|
+
# from kafka.structs import TopicPartition
|
|
13
|
+
|
|
14
|
+
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

if __name__ == '__main__':
  # Send every line of --input_path as one message to --topic, creating the
  # topic (1 partition, replication factor 1) if it does not exist yet.
  parser = argparse.ArgumentParser()
  parser.add_argument('--servers', type=str, default='localhost:9092')
  parser.add_argument('--topic', type=str, default=None)
  # NOTE(review): --group and --partitions are parsed but never used by this
  # producer; kept so existing command lines keep working.
  parser.add_argument('--group', type=str, default='consumer')
  parser.add_argument('--partitions', type=str, default=None)
  parser.add_argument('--timeout', type=float, default=float('inf'))
  # file to send
  parser.add_argument('--input_path', type=str, default=None)
  args = parser.parse_args()

  if args.input_path is None:
    logging.error('input_path is not set')
    sys.exit(1)

  if args.topic is None:
    logging.error('topic is not set')
    sys.exit(1)

  servers = args.servers.split(',')

  # Create the topic on first use; release the admin connection even if
  # listing or creating topics fails.
  admin_clt = KafkaAdminClient(bootstrap_servers=servers)
  try:
    if args.topic not in admin_clt.list_topics():
      admin_clt.create_topics(
          new_topics=[
              NewTopic(
                  name=args.topic,
                  num_partitions=1,
                  replication_factor=1,
                  topic_configs={'max.message.bytes': 1024 * 1024 * 1024})
          ],
          validate_only=False)
      logging.info('create increment save topic: %s' % args.topic)
  finally:
    admin_clt.close()

  # NOTE(review): with the default --timeout of inf this passes
  # request_timeout_ms=float('inf'); confirm the kafka-python version in use
  # accepts a non-integer / infinite value here.
  producer = KafkaProducer(
      bootstrap_servers=servers,
      request_timeout_ms=args.timeout * 1000,
      api_version=(0, 10, 1))

  # Number of lines handed to the producer so far.
  sent = 0
  try:
    with open(args.input_path, 'r', encoding='utf-8') as fin:
      # Every line is sent; there is deliberately no early exit from the loop.
      for line_str in fin:
        producer.send(args.topic, line_str.encode('utf-8'))
        sent += 1
        if sent % 100 == 0:
          logging.info('progress: %d' % sent)
  finally:
    producer.close()
  logging.info('total sent: %d' % sent)
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import sys

# Make the package root importable when this file is executed directly.
# The root is this file's path with its last three components stripped
# (three os.path.dirname applications), converted to an absolute path.
py_root_dir_path = os.path.abspath(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
print(f"py_root_dir_path:{py_root_dir_path}")
sys.path.append(py_root_dir_path)
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
import tensorflow as tf
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from easy_rec.python.main import _train_and_evaluate_impl
|
|
17
|
+
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
|
|
18
|
+
from easy_rec.python.protos.train_pb2 import DistributionStrategy
|
|
19
|
+
from easy_rec.python.utils import config_util
|
|
20
|
+
from easy_rec.python.utils import ds_util
|
|
21
|
+
from easy_rec.python.utils import estimator_utils
|
|
22
|
+
from easy_rec.python.utils import fg_util
|
|
23
|
+
from easy_rec.python.utils import hpo_util
|
|
24
|
+
from easy_rec.python.utils.config_util import process_neg_sampler_data_path
|
|
25
|
+
from easy_rec.python.utils.config_util import set_eval_input_path
|
|
26
|
+
from easy_rec.python.utils.config_util import set_train_input_path
|
|
27
|
+
|
|
28
|
+
# Configure root logging for this entry point. logging.basicConfig does
# nothing once the root logger has a handler, so the detailed format must be
# supplied on this first call for it to take effect.
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)
warnings.filterwarnings('ignore')

# Reduce TensorFlow C++ log output.
# NOTE(review): tensorflow is imported above this line; confirm the setting
# still takes effect for messages emitted after import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
32
|
+
|
|
33
|
+
if tf.__version__.startswith('1.'):
|
|
34
|
+
from tensorflow.python.platform import gfile
|
|
35
|
+
else:
|
|
36
|
+
import tensorflow.io.gfile as gfile
|
|
37
|
+
|
|
38
|
+
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
|
|
39
|
+
|
|
40
|
+
# On TF 2.x, rebind `tf` to the v1 compatibility module used by this script.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

# No second logging.basicConfig here: the call near the top of this module
# already installs a root handler, and basicConfig is a no-op once a handler
# exists, so a repeat call could never change the format or level.
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _get_file_path(root_path, file_list):
|
|
51
|
+
# 获取该目录下所有的文件名称和目录名称
|
|
52
|
+
dir_or_files = os.listdir(root_path)
|
|
53
|
+
for dir_file in dir_or_files:
|
|
54
|
+
# 获取目录或者文件的路径
|
|
55
|
+
dir_file_path = os.path.join(root_path, dir_file)
|
|
56
|
+
# 判断该路径为文件还是路径
|
|
57
|
+
if os.path.isdir(dir_file_path):
|
|
58
|
+
# 递归获取所有文件和目录的路径
|
|
59
|
+
_get_file_path(dir_file_path, file_list)
|
|
60
|
+
else:
|
|
61
|
+
if not str(dir_file_path).__contains__('_SUCCESS'):
|
|
62
|
+
file_list.append(dir_file_path)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_vocab_list(vocab_path):
  """Read a vocabulary file and return its entries, one per line.

  Args:
    vocab_path: path opened with ``gfile.GFile`` in text mode.

  Returns:
    list of str: every line of the file, converted with ``str`` and stripped
    of surrounding whitespace. Blank lines become empty strings.
  """
  vocabulary = []
  with gfile.GFile(vocab_path, 'r') as reader:
    for raw_line in reader:
      vocabulary.append(str(raw_line).strip())
  return vocabulary
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_file_path_list(root_path):
  """Return the file paths ``_get_file_path`` collects under ``root_path``.

  Convenience wrapper that supplies a fresh accumulator list.
  """
  found_paths = []
  _get_file_path(root_path, found_paths)
  return found_paths
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def change_pipeline_config(pipeline_config: EasyRecConfig):
  """Rewrite ``pipeline_config`` in place with the paths and settings for this job.

  This function reads the module-level names ``data_root_path``,
  ``batch_size``, ``num_epochs``, ``train_sample_cnt`` and
  ``initial_learning_rate``. They are assigned from the command line in the
  ``__main__`` section of this script. Calling the function from another
  module without defining them raises ``NameError``.

  Changes made:
    * Vocab lists: for each feature with a non-empty ``vocab_list``, the last
      entry is popped and treated as a directory relative to
      ``data_root_path``. The lines of the first file ``get_file_path_list``
      returns for that directory are appended to ``vocab_list``.
    * Model dir: ``model_dir`` is prefixed with ``data_root_path``.
    * Input paths: ``train_input_path`` and ``eval_input_path`` are prefixed
      with ``data_root_path`` and expanded to a comma-separated list of every
      file found beneath them.
    * Data config: ``data_config.batch_size`` and ``num_epochs`` are
      overwritten.
    * Step intervals: ``log_step_count_steps`` and ``save_checkpoints_steps``
      are set to one epoch's worth of steps (``train_sample_cnt // batch_size``,
      clamped to at least 1).
    * Learning rate: the initial learning rate is applied only when
      ``optimizer_config[0]`` is an Adam optimizer with an exponential-decay
      schedule. Otherwise a warning is logged and the optimizer is left
      unchanged.

  Args:
    pipeline_config: the parsed EasyRecConfig; mutated in place.

  Raises:
    ValueError: if a vocab directory contains no files.
  """
  for feature in pipeline_config.feature_config.features:
    if feature.vocab_list:
      vocab_dir = f"{data_root_path}/{feature.vocab_list.pop()}"
      vocab_files = get_file_path_list(vocab_dir)
      if not vocab_files:
        raise ValueError('no vocab file found under %s' % vocab_dir)
      feature.vocab_list.extend(get_vocab_list(vocab_files[0]))

  pipeline_config.model_dir = f"{data_root_path}/{pipeline_config.model_dir}"

  train_input_path = f"{data_root_path}/{pipeline_config.train_input_path}"
  pipeline_config.train_input_path = ','.join(
      get_file_path_list(train_input_path))

  eval_input_path = f"{data_root_path}/{pipeline_config.eval_input_path}"
  pipeline_config.eval_input_path = ','.join(
      get_file_path_list(eval_input_path))

  pipeline_config.data_config.batch_size = batch_size
  pipeline_config.data_config.num_epochs = num_epochs

  # One epoch's worth of steps. Clamp to 1 so a batch larger than the training
  # set cannot produce a zero logging/checkpoint interval.
  steps_per_epoch = max(1, train_sample_cnt // batch_size)
  pipeline_config.train_config.log_step_count_steps = steps_per_epoch
  pipeline_config.train_config.save_checkpoints_steps = steps_per_epoch

  # Assigning a scalar inside a protobuf sub-message marks that sub-message as
  # set. If it belongs to a oneof, the other members are cleared. The optimizer
  # and learning-rate choices are presumably oneofs in optimizer.proto (verify).
  # Blind assignment would therefore silently replace any other optimizer or
  # schedule with a default Adam exponential-decay one, so check first.
  optimizer_configs = pipeline_config.train_config.optimizer_config
  if (optimizer_configs and optimizer_configs[0].HasField('adam_optimizer') and
      optimizer_configs[0].adam_optimizer.learning_rate.HasField(
          'exponential_decay_learning_rate')):
    decay_lr = (
        optimizer_configs[0].adam_optimizer.learning_rate
        .exponential_decay_learning_rate)
    decay_lr.initial_learning_rate = initial_learning_rate
  else:
    logging.warning(
        'initial_learning_rate=%s not applied: optimizer_config[0] is not an '
        'adam_optimizer with exponential_decay_learning_rate',
        initial_learning_rate)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _build_arg_parser():
  """Create the argument parser for this custom-model train/eval script.

  Returns:
    argparse.ArgumentParser. Arguments it does not recognise are returned by
    ``parse_known_args`` and passed on as extra pipeline-config edits.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--pipeline_config_path',
      type=str,
      # NOTE(review): default points at a developer's local checkout; pass the
      # flag explicitly on any other machine.
      default='/Users/chensheng/PycharmProjects/EasyRec/samples/model_config/custom_model.config',
      help='Path to pipeline config file.')
  parser.add_argument(
      '--data_root_path',
      type=str,
      # NOTE(review): developer-local default, see --pipeline_config_path.
      default='/Users/chensheng/PycharmProjects/EasyRec/data/test/cs_data')
  parser.add_argument(
      '--train_sample_cnt',
      type=int,
      default=27000,
      help='number of training samples; log_step_count_steps and '
      'save_checkpoints_steps are set to train_sample_cnt // batch_size')
  parser.add_argument('--batch_size', type=int, default=3000)
  parser.add_argument('--num_epochs', type=int, default=10)
  parser.add_argument('--initial_learning_rate', type=float, default=0.001)
  parser.add_argument(
      '--continue_train',
      action='store_true',
      default=False,
      help='continue train using existing model_dir')
  parser.add_argument(
      '--hpo_param_path',
      type=str,
      default=None,
      help='hyperparam tuning param path')
  parser.add_argument(
      '--hpo_metric_save_path',
      type=str,
      default=None,
      help='hyperparameter save metric path')
  parser.add_argument(
      '--model_dir',
      type=str,
      default=None,
      help='will update the model_dir in pipeline_config')
  parser.add_argument(
      '--train_input_path',
      type=str,
      nargs='*',
      default=None,
      help='train data input path')
  parser.add_argument(
      '--eval_input_path',
      type=str,
      nargs='*',
      default=None,
      help='eval data input path')
  parser.add_argument(
      '--fit_on_eval',
      action='store_true',
      default=False,
      help='Fit evaluation data after fitting and evaluating train data')
  parser.add_argument(
      '--fit_on_eval_steps',
      type=int,
      default=None,
      help='Fit evaluation data steps')
  parser.add_argument(
      '--fine_tune_checkpoint',
      type=str,
      default=None,
      help='will update the train_config.fine_tune_checkpoint in pipeline_config'
  )
  parser.add_argument(
      '--edit_config_json',
      type=str,
      default=None,
      help='edit pipeline config str, example: {"model_dir":"experiments/",'
      '"feature_config.feature[0].boundaries":[4,5,6,7]}')
  parser.add_argument(
      '--ignore_finetune_ckpt_error',
      action='store_true',
      default=False,
      help='During incremental training, ignore the problem of missing fine_tune_checkpoint files'
  )
  parser.add_argument(
      '--odps_config', type=str, default=None, help='odps config path')
  parser.add_argument(
      '--is_on_ds', action='store_true', default=False, help='is on ds')
  parser.add_argument(
      '--check_mode',
      action='store_true',
      default=False,
      help='is use check mode')
  parser.add_argument(
      '--selected_cols', type=str, default=None, help='select input columns')
  parser.add_argument('--gpu', type=str, default=None, help='gpu id')
  return parser


if __name__ == '__main__':
  import shutil

  args, extra_args = _build_arg_parser().parse_known_args()

  # Module-level settings read by change_pipeline_config().
  data_root_path = args.data_root_path
  train_sample_cnt = args.train_sample_cnt
  batch_size = args.batch_size
  num_epochs = args.num_epochs
  initial_learning_rate = args.initial_learning_rate

  if args.gpu is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

  edit_config_json = {}
  if args.edit_config_json:
    edit_config_json = json.loads(args.edit_config_json)

  # Unrecognised command-line options are treated as extra config edits.
  if extra_args is not None and len(extra_args) > 0:
    config_util.parse_extra_config_param(extra_args, edit_config_json)

  if args.pipeline_config_path is None:
    raise ValueError('pipeline_config_path should not be empty when training!')

  pipeline_config = config_util.get_configs_from_pipeline_file(
      args.pipeline_config_path, False)
  if args.selected_cols:
    pipeline_config.data_config.selected_cols = args.selected_cols
  if args.model_dir:
    pipeline_config.model_dir = args.model_dir
    logging.info('update model_dir to %s' % pipeline_config.model_dir)
  if args.train_input_path:
    set_train_input_path(pipeline_config, args.train_input_path)
  if args.eval_input_path:
    set_eval_input_path(pipeline_config, args.eval_input_path)

  if args.fine_tune_checkpoint:
    ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
        args.fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
    if ckpt_path:
      pipeline_config.train_config.fine_tune_checkpoint = ckpt_path

  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)

  if args.odps_config:
    os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config

  if len(edit_config_json) > 0:
    fine_tune_checkpoint = edit_config_json.get('train_config', {}).get(
        'fine_tune_checkpoint', None)
    if fine_tune_checkpoint:
      # Resolve the checkpoint named in the config edit itself;
      # args.fine_tune_checkpoint may be None on this path.
      ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
          fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
      edit_config_json['train_config']['fine_tune_checkpoint'] = ckpt_path
    config_util.edit_config(pipeline_config, edit_config_json)

  process_neg_sampler_data_path(pipeline_config)

  if args.is_on_ds:
    ds_util.set_on_ds()
    set_tf_config_and_get_train_worker_num_on_ds()
    if pipeline_config.train_config.fine_tune_checkpoint:
      ds_util.cache_ckpt(pipeline_config)

  train_distribute = pipeline_config.train_config.train_distribute
  if train_distribute in [DistributionStrategy.HorovodStrategy]:
    estimator_utils.init_hvd()
  elif train_distribute in [
      DistributionStrategy.EmbeddingParallelStrategy,
      DistributionStrategy.SokStrategy
  ]:
    estimator_utils.init_hvd()
    estimator_utils.init_sok()

  if args.hpo_param_path:
    # NOTE(review): this branch does not call change_pipeline_config(), so
    # data_root_path-relative paths are not resolved here; confirm intended.
    with gfile.GFile(args.hpo_param_path, 'r') as fin:
      hpo_config = json.load(fin)
    hpo_params = hpo_config['param']
    config_util.edit_config(pipeline_config, hpo_params)
    config_util.auto_expand_share_feature_configs(pipeline_config)
    _train_and_evaluate_impl(pipeline_config, args.continue_train,
                             args.check_mode)
    hpo_util.save_eval_metrics(
        pipeline_config.model_dir,
        metric_save_path=args.hpo_metric_save_path,
        has_evaluator=False)
  else:
    change_pipeline_config(pipeline_config)

    if not args.continue_train:
      # Start from scratch by deleting the previous model_dir. This is done
      # in-process rather than through a shell `rm -rf`, so paths with spaces
      # or shell metacharacters are handled safely.
      model_dir = pipeline_config.model_dir
      print(f'model_dir:{model_dir}')
      model_dir_abs = os.path.abspath(model_dir)
      data_root_abs = os.path.abspath(data_root_path)
      # An empty or '..'-style model_dir resolves to data_root_path or one of
      # its ancestors; deleting it would destroy the training data.
      if os.path.commonpath([model_dir_abs, data_root_abs]) == model_dir_abs:
        raise ValueError(
            'refusing to delete model_dir %s: it contains data_root_path %s' %
            (model_dir, data_root_path))
      if os.path.islink(model_dir) or os.path.isfile(model_dir):
        os.remove(model_dir)
      elif os.path.isdir(model_dir):
        shutil.rmtree(model_dir)

    config_util.auto_expand_share_feature_configs(pipeline_config)
    _train_and_evaluate_impl(
        pipeline_config,
        args.continue_train,
        args.check_mode,
        fit_on_eval=args.fit_on_eval,
        fit_on_eval_steps=args.fit_on_eval_steps)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
class conditional(object):
  """Context manager that delegates to another one only when enabled.

  If `condition` is truthy, entering/exiting is forwarded to
  `contextmanager`. Otherwise entering yields None and exiting does nothing,
  so exceptions raised in the `with` body propagate unchanged.
  """

  def __init__(self, condition, contextmanager):
    self.condition = condition
    self.contextmanager = contextmanager

  def __enter__(self):
    """Enter the wrapped manager when enabled; otherwise return None."""
    if not self.condition:
      return None
    return self.contextmanager.__enter__()

  def __exit__(self, *args):
    """Forward the exit arguments to the wrapped manager when enabled."""
    if not self.condition:
      return None
    return self.contextmanager.__exit__(*args)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# -*- encoding: utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import six
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.utils.load_class import load_by_path
|
|
9
|
+
|
|
10
|
+
# Under TensorFlow 2.x, route `tf.*` through the TF1 compatibility API so the
# graph-mode calls in this module (get_variable, layers.batch_normalization,
# name_scope) keep working.
# NOTE(review): this is a lexicographic string comparison of version strings.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def dice(_x, axis=-1, epsilon=1e-9, name='dice', training=True):
  """Dice: the Data Adaptive Activation Function from DIN.

  Dice generalizes PReLU. Instead of a fixed rectification point at zero, a
  soft gate `p = sigmoid(batch_norm(x))` is derived from the input's
  statistics, and the output is `alpha * (1 - p) * x + p * x`, where `alpha`
  is a learnable vector.

  Args:
    _x: input tensor. Its last static dimension (`get_shape()[-1]`) sizes the
      `alpha` variable.
    axis: axis passed to batch normalization, i.e. the axis used to compute
      the data distribution (typically the features axis).
    epsilon: small float added to the variance to avoid dividing by zero.
    name: suffix of the created variable, which is named `'alpha_' + name`.
    training: forwarded to `tf.layers.batch_normalization`.

  Returns:
    `alpha * (1 - p) * _x + p * _x`.

  References:
    - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
      Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining.
      ACM, 2018: 1059-1068.] (https://arxiv.org/pdf/1706.06978.pdf)
  """
  num_units = _x.get_shape()[-1]
  # Learnable negative-side slope, zero-initialised (pure gating at start).
  alphas = tf.get_variable(
      'alpha_' + name,
      num_units,
      initializer=tf.constant_initializer(0.0),
      dtype=tf.float32)
  # Batch norm without learnable offset/scale: it only standardises the gate
  # input. NOTE(review): tf.layers.batch_normalization registers its
  # moving-average updates in UPDATE_OPS; the caller is expected to run them.
  normalized = tf.layers.batch_normalization(
      inputs=_x,
      axis=axis,
      epsilon=epsilon,
      center=False,
      scale=False,
      training=training)
  gate = tf.sigmoid(normalized)
  return alphas * (1.0 - gate) * _x + gate * _x
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def gelu(x, name='gelu'):
  """Gaussian Error Linear Unit (tanh approximation).

  A smoother alternative to ReLU, computed as
  `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.
    name: name scope for the created ops.

  Returns:
    `x` with the GELU activation applied.
  """
  # sqrt(2 / pi) is a Python/NumPy scalar, so computing it here adds no ops.
  coeff = np.sqrt(2 / np.pi)
  with tf.name_scope(name):
    cubic_term = 0.044715 * tf.pow(x, 3)
    tanh_arg = coeff * (x + cubic_term)
    cdf = 0.5 * (1.0 + tf.tanh(tanh_arg))
    return x * cdf
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def swish(x, name='swish'):
  """Swish activation, `x * sigmoid(x)`, built from basic TensorFlow ops.

  Args:
    x: input Tensor.
    name: name scope for the created ops.

  Returns:
    `x * sigmoid(x)`.
  """
  with tf.name_scope(name):
    gate = tf.sigmoid(x)
    return x * gate
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_activation(activation_string, **kwargs):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: Name of the activation function (matched
      case-insensitively), or a non-string value that is returned unchanged.
    **kwargs: Extra keyword arguments. Used only by "prelu" (forwarded to
      `tf.keras.layers.PReLU`; with no kwargs `tf.nn.leaky_relu` is returned
      instead) and by "dice" (forwarded to `dice`). Ignored for other names.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.
    Any other unrecognized name is resolved with `load_by_path` using the
    original (not lowercased) string, and that result is returned.

  Raises:
    Unknown names are not rejected here; any error for a name that cannot be
    resolved comes from `load_by_path`.
  """
  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == 'linear':
    return None
  elif act == 'relu':
    return tf.nn.relu
  elif act == 'gelu':
    return gelu
  elif act == 'leaky_relu':
    return tf.nn.leaky_relu
  elif act == 'prelu':
    if len(kwargs) == 0:
      return tf.nn.leaky_relu
    return tf.keras.layers.PReLU(**kwargs)
  elif act == 'dice':
    return lambda x, name='dice': dice(x, name=name, **kwargs)
  elif act == 'elu':
    return tf.nn.elu
  elif act == 'selu':
    return tf.nn.selu
  elif act == 'tanh':
    return tf.tanh
  elif act == 'swish':
    # Prefer the built-in op when this TensorFlow build provides it, else fall
    # back to the local implementation. Feature detection replaces a
    # lexicographic version-string comparison, which misorders releases such
    # as '1.4.0' vs '1.13.0'.
    return getattr(tf.nn, 'swish', swish)
  elif act == 'sigmoid':
    return tf.nn.sigmoid
  else:
    return load_by_path(activation_string)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
|
|
6
|
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
|
|
7
|
+
|
|
8
|
+
# Under TensorFlow 2.x, expose the TF1 compatibility API as `tf`.
# NOTE(review): this is a lexicographic string comparison of version strings.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def check_split(line, sep, requried_field_num, field_name=''):
  """Assert that every string in `line` splits into the expected field count.

  Args:
    line: iterable of strings to check, each split with `sep`.
    sep: field separator; must be non-empty.
    requried_field_num: expected number of fields per string (the misspelled
      name is kept so keyword callers keep working).
    field_name: optional field name included in error messages.

  Returns:
    True if every string has exactly `requried_field_num` fields.

  Raises:
    AssertionError: if `sep` is empty or any string has a different number of
      fields (assert-based, so these checks are disabled under `python -O`).
  """
  # The conditional suffix must be parenthesized: without the parentheses the
  # whole message evaluates to '' whenever `field_name` is empty.
  assert sep, 'must have separator.' + (
      (' field: %s.' % field_name) if field_name else '')

  for one_line in line:
    field_num = len(one_line.split(sep))
    if field_num == requried_field_num:
      continue
    # Only format the (possibly long) error message for a failing line.
    if field_name:
      assert_info = ('sep[%s] maybe invalid. field_num=%d, required_num=%d, '
                     'field: %s, value: %s, please check separator and data.' %
                     (sep, field_num, requried_field_num, field_name, one_line))
    else:
      assert_info = ('sep[%s] maybe invalid. field_num=%d, required_num=%d, '
                     'current line is: %s, please check separator and data.' %
                     (sep, field_num, requried_field_num, one_line))
    assert False, assert_info
  return True
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def check_string_to_number(field_vals, field_name):
  """Assert that every value in `field_vals` can be converted with `float()`.

  Args:
    field_vals: iterable of values to check.
    field_name: field name included in the error message.

  Returns:
    True if every value converts.

  Raises:
    AssertionError: on the first value `float()` rejects (assert-based, so the
      check is disabled under `python -O`).
  """
  for val in field_vals:
    try:
      float(val)
    except (TypeError, ValueError):
      # Catch only conversion failures; a bare except would also swallow
      # KeyboardInterrupt / SystemExit.
      assert False, 'StringToNumber ERROR: cannot convert string_to_number, field: %s, value: %s. ' \
          'please check data.' % (field_name, val)
  return True
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def check_sequence(pipeline_config_path, features):
  """Assert that history sequences grouped in one seq_att_map have equal sizes.

  For every `seq_att_map` of every group in `model_config.seq_att_groups`,
  checks that `key` and `hist_seq` have the same number of entries, and that
  every feature named in `hist_seq` has the same number of `values`.

  Args:
    pipeline_config_path: the pipeline config object (despite the name, not a
      path) whose `model_config.seq_att_groups` is inspected.
    features: mapping from feature name to an object exposing `values`
      supporting `len()`; presumably the parsed batch of input features —
      TODO confirm against callers.

  Raises:
    AssertionError: on a key/hist_seq count mismatch or inconsistent sequence
      sizes (assert-based, so disabled under `python -O`).
  """
  seq_att_groups = pipeline_config_path.model_config.seq_att_groups
  if not seq_att_groups:
    return
  for seq_att_group in seq_att_groups:
    seq_att_maps = seq_att_group.seq_att_map
    if not seq_att_maps:
      # Nothing to check in this group; keep checking the remaining groups
      # (returning here would silently skip them).
      continue
    for seq_att_map in seq_att_maps:
      assert len(seq_att_map.key) == len(seq_att_map.hist_seq), \
          'The size of hist_seq must equal to the size of key in one seq_att_map.'
      size_list = [
          len(features[hist_seq].values) for hist_seq in seq_att_map.hist_seq
      ]
      hist_seqs = ' '.join(seq_att_map.hist_seq)
      assert len(set(size_list)) == 1, \
          'SequenceFeature Error: The size in [%s] should be consistent. Please check input: [%s].' % \
          (hist_seqs, hist_seqs)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def check_env_and_input_path(pipeline_config, input_path):
  """Verify that the input paths are consistent with the configured input type.

  Input types whose name starts with 'Odps' require every comma-separated
  path to start with 'odps://'. Every other checked input type must not
  receive any 'odps://' path. The input types listed below are skipped.

  Args:
    pipeline_config: pipeline config whose `data_config.input_type` is read.
    input_path: comma-separated string of input paths.

  Returns:
    True when the check passes or is skipped.

  Raises:
    AssertionError: when a path does not match the input type (assert-based,
      so disabled under `python -O`).
  """
  input_type = pipeline_config.data_config.input_type
  input_type_name = DatasetConfig.InputType.Name(input_type)
  skipped_input_types = (
      DatasetConfig.InputType.TFRecordInput,
      DatasetConfig.InputType.BatchTFRecordInput,
      DatasetConfig.InputType.KafkaInput,
      DatasetConfig.InputType.DataHubInput,
      DatasetConfig.InputType.HiveInput,
      DatasetConfig.InputType.DummyInput,
  )
  if input_type in skipped_input_types:
    return True

  assert_info = 'Current InputType is %s, InputPath is %s. Please check InputType and InputPath.' % \
      (input_type_name, input_path)
  paths = input_path.split(',')
  if input_type_name.startswith('Odps'):
    # ODPS inputs (e.g. on PAI): every path must be an odps:// table.
    assert all(p.startswith('odps://') for p in paths), assert_info
  else:
    # Local or DS inputs: no path may be an odps:// table.
    assert not any(p.startswith('odps://') for p in paths), assert_info
  return True
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
# Date: 2019-10-12
|
|
4
|
+
# util to hanlde python2 python3 compatibility
|
|
5
|
+
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def in_python2():
  """Return True when the running interpreter's major version is 2."""
  return sys.version_info.major == 2
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def in_python3():
  """Return True when the running interpreter's major version is 3."""
  return sys.version_info.major == 3
|