easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
@@ -0,0 +1,85 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+from easy_rec.python.utils import estimator_utils
+from easy_rec.python.utils.dag import DAG
+from easy_rec.python.utils.expr_util import get_expression
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+gfile = tf.gfile
+
+
+class UtilTest(tf.test.TestCase):
+
+  def test_get_ckpt_version(self):
+    ver = estimator_utils.get_ckpt_version(
+        'oss://easyrec/ckpts/model.ckpt-6500.meta')
+    assert ver == 6500, 'invalid version: %s' % str(ver)
+    ver = estimator_utils.get_ckpt_version(
+        'oss://easyrec/ckpts/model.ckpt-6500')
+    assert ver == 6500, 'invalid version: %s' % str(ver)
+
+  def test_get_expression_greater(self):
+    result = get_expression('age_level>item_age_level',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_greater_equal(self):
+    result = get_expression('age_level>=item_age_level',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.greater_equal(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_less(self):
+    result = get_expression('age_level<item_age_level',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.less(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_less_equal(self):
+    result = get_expression('age_level<=item_age_level',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.less_equal(parsed_dict['age_level'], parsed_dict['item_age_level'])"
+
+  def test_get_expression_and(self):
+    result = get_expression('(age_level>3)&(item_age_level<1)',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], 3) & tf.less(parsed_dict['item_age_level'], 1)"
+
+    result = get_expression(
+        '(age_level>item_age_level) & (age_level<item_age_level*3)',
+        ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], parsed_dict['item_age_level']) &" \
+        " tf.less(parsed_dict['age_level'], parsed_dict['item_age_level']*3)"
+
+  def test_get_expression_or(self):
+    result = get_expression('(age_level>3)|(item_age_level<1)',
+                            ['age_level', 'item_age_level'])
+    assert result == "tf.greater(parsed_dict['age_level'], 3) | tf.less(parsed_dict['item_age_level'], 1)"
+
+  def test_dag(self):
+    dag = DAG()
+    dag.add_node('a')
+    dag.add_node('b')
+    dag.add_node('c')
+    dag.add_node('d')
+    dag.add_edge('a', 'b')
+    dag.add_edge('a', 'd')
+    dag.add_edge('b', 'c')
+    order = dag.topological_sort()
+    idx_a = order.index('a')
+    idx_b = order.index('b')
+    idx_c = order.index('c')
+    idx_d = order.index('d')
+    assert idx_a < idx_b
+    assert idx_a < idx_d
+    assert idx_b < idx_c
+    c = dag.all_downstreams('b')
+    assert c == ['c']
+    leaf = dag.all_leaves()
+    assert leaf == ['c', 'd']
+
+
+if __name__ == '__main__':
+  tf.test.main()
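Judging by the file list and the +85 line count, the hunk above is easy_rec/python/test/util_test.py. The tests show that get_expression only produces a string of TensorFlow code keyed on a parsed_dict lookup; below is a minimal sketch of how such a string could be consumed, assuming the caller evaluates it with Python's eval against a dict of tensors (the feature values are made up for illustration).

import tensorflow as tf

from easy_rec.python.utils.expr_util import get_expression

# Hypothetical consumption of a get_expression result: the returned string
# references parsed_dict[...], so it is assumed to be eval()'d in a scope
# where parsed_dict maps feature names to tensors.
parsed_dict = {
    'age_level': tf.constant([4, 1]),
    'item_age_level': tf.constant([2, 3]),
}
expr = get_expression('age_level>item_age_level',
                      ['age_level', 'item_age_level'])
mask = eval(expr)  # boolean mask; [True, False] for these values (eager mode)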
@@ -0,0 +1,53 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import tensorflow as tf
+from scipy import stats
+
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss  # NOQA
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+# Absolute error tolerance in asserting array near.
+_ERR_TOL = 1e-6
+
+
+# softplus function that calculates log(1+exp(x))
+def _softplus(x):
+  return np.log(1.0 + np.exp(x))
+
+
+# sigmoid function that calculates 1/(1+exp(-x))
+def _sigmoid(x):
+  return 1 / (1 + np.exp(-x))
+
+
+class ZeroInflatedLognormalLossTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(ZeroInflatedLognormalLossTest, self).setUp()
+    self.logits = np.array([[.1, .2, .3], [.4, .5, .6]])
+    self.labels = np.array([[0.], [1.5]])
+
+  def zero_inflated_lognormal(self, labels, logits):
+    positive_logits = logits[..., :1]
+    loss_zero = _softplus(positive_logits)
+    loc = logits[..., 1:2]
+    scale = np.maximum(
+        _softplus(logits[..., 2:]), np.sqrt(tf.keras.backend.epsilon()))
+    log_prob_non_zero = stats.lognorm.logpdf(
+        x=labels, s=scale, loc=0, scale=np.exp(loc))
+    loss_non_zero = _softplus(-positive_logits) - log_prob_non_zero
+    return np.mean(np.where(labels == 0., loss_zero, loss_non_zero), axis=-1)
+
+  def test_loss_value(self):
+    expected_loss = self.zero_inflated_lognormal(self.labels, self.logits)
+    expected_loss = np.average(expected_loss)
+    loss = zero_inflated_lognormal_loss(self.labels, self.logits)
+    self.assertNear(self.evaluate(loss), expected_loss, _ERR_TOL)
+
+
+if __name__ == '__main__':
+  tf.enable_eager_execution()
+  tf.test.main()
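The +53 hunk above matches easy_rec/python/test/zero_inflated_lognormal_test.py. Its reference implementation pins down the tensor layout the loss expects; the small sketch below restates it (it mirrors the test and is not a new API).

import numpy as np

from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss

# labels: [batch, 1]; logits: [batch, 3], where column 0 is the
# "label is positive" classification logit, column 1 the lognormal loc and
# column 2 (after softplus) the lognormal scale.
labels = np.array([[0.], [1.5]])
logits = np.array([[.1, .2, .3], [.4, .5, .6]])
loss = zero_inflated_lognormal_loss(labels, logits)  # scalar mean loss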
File without changes
@@ -0,0 +1,67 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import os
+import sys
+
+import common_io
+import tensorflow as tf
+
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+logging.basicConfig(
+    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+    level=logging.INFO)
+tf.app.flags.DEFINE_string('template_config_path', None,
+                           'Path to template pipeline config '
+                           'file.')
+tf.app.flags.DEFINE_string('output_config_path', None,
+                           'Path to output pipeline config '
+                           'file.')
+tf.app.flags.DEFINE_string('tables', '', 'quantile binning table')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(argv):
+  pipeline_config = config_util.get_configs_from_pipeline_file(
+      FLAGS.template_config_path)
+
+  feature_boundaries_info = {}
+  reader = common_io.table.TableReader(
+      FLAGS.tables, selected_cols='feature,json')
+  while True:
+    try:
+      record = reader.read()
+      raw_info = json.loads(record[0][1])
+      bin_info = []
+      for info in raw_info['bin']['norm'][:-1]:
+        split_point = float(info['value'].split(',')[1][:-1])
+        bin_info.append(split_point)
+      feature_boundaries_info[record[0][0]] = bin_info
+    except common_io.exception.OutOfRangeException:
+      reader.close()
+      break
+
+  logging.info('feature boundaries: %s' % feature_boundaries_info)
+
+  for feature_config in pipeline_config.feature_configs:
+    feature_name = feature_config.input_names[0]
+    if feature_name in feature_boundaries_info:
+      feature_config.feature_type = feature_config.RawFeature
+      feature_config.hash_bucket_size = 0
+      feature_config.boundaries.extend(feature_boundaries_info[feature_name])
+      logging.info('edited %s' % feature_name)
+
+  config_dir, config_name = os.path.split(FLAGS.output_config_path)
+  config_util.save_pipeline_config(pipeline_config, config_dir, config_name)
+
+
+if __name__ == '__main__':
+  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+  tf.app.run()
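The +67 hunk above corresponds to easy_rec/python/tools/add_boundaries_to_config.py. It expects each row of --tables to hold a feature name plus a JSON blob of quantile bins whose 'value' field is an interval string; the exact schema is not documented in the diff, so the sketch below is an illustration inferred from the parsing code (the feature name and split points are made up).

import json

# One hypothetical row of the quantile-binning table: column 'feature' and
# column 'json'. The script keeps the upper bound of every bin but the last.
feature = 'price'
bin_json = json.dumps({
    'bin': {
        'norm': [
            {'value': '(-inf,9.9]'},
            {'value': '(9.9,19.9]'},
            {'value': '(19.9,inf)'},  # last bin: skipped by [:-1]
        ]
    }
})

raw_info = json.loads(bin_json)
boundaries = [
    float(info['value'].split(',')[1][:-1])
    for info in raw_info['bin']['norm'][:-1]
]
assert boundaries == [9.9, 19.9]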
@@ -0,0 +1,145 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+import os
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import io_util
+from easy_rec.python.utils.hive_utils import HiveUtils
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+logging.basicConfig(
+    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+    level=logging.INFO)
+tf.app.flags.DEFINE_string('template_config_path', None,
+                           'Path to template pipeline config '
+                           'file.')
+tf.app.flags.DEFINE_string('output_config_path', None,
+                           'Path to output pipeline config '
+                           'file.')
+tf.app.flags.DEFINE_string('config_table', '', 'config table')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(argv):
+  pipeline_config = config_util.get_configs_from_pipeline_file(
+      FLAGS.template_config_path)
+  sels = 'feature,feature_info,message'
+  feature_info_map = {}
+  drop_feature_names = []
+
+  if pipeline_config.WhichOneof('train_path') == 'hive_train_input':
+    hive_util = HiveUtils(
+        data_config=pipeline_config.data_config,
+        hive_config=pipeline_config.hive_train_input,
+        selected_cols=sels,
+        record_defaults=['', '', ''])
+    reader = hive_util.hive_read_line(FLAGS.config_table)
+    for record in reader:
+      feature_name = record[0][0]
+      feature_info_map[feature_name] = json.loads(record[0][1])
+      if 'DROP IT' in record[0][2]:
+        drop_feature_names.append(feature_name)
+
+  else:
+    import common_io
+    reader = common_io.table.TableReader(FLAGS.config_table, selected_cols=sels)
+    while True:
+      try:
+        record = reader.read()
+        feature_name = record[0][0]
+        feature_info_map[feature_name] = json.loads(record[0][1])
+        if 'DROP IT' in record[0][2]:
+          drop_feature_names.append(feature_name)
+      except common_io.exception.OutOfRangeException:
+        reader.close()
+        break
+
+  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
+  if drop_feature_names:
+    tmp_feature_configs = feature_configs[:]
+    for fea_cfg in tmp_feature_configs:
+      fea_name = fea_cfg.input_names[0]
+      if fea_name in drop_feature_names:
+        feature_configs.remove(fea_cfg)
+  for feature_config in feature_configs:
+    feature_name = feature_config.input_names[0]
+    if feature_name in feature_info_map:
+      logging.info('edited %s' % feature_name)
+      feature_config.embedding_dim = int(
+          feature_info_map[feature_name]['embedding_dim'])
+      logging.info('modify embedding_dim to %s' % feature_config.embedding_dim)
+      if 'boundary' in feature_info_map[feature_name]:
+        feature_config.ClearField('boundaries')
+        feature_config.boundaries.extend(
+            [float(i) for i in feature_info_map[feature_name]['boundary']])
+        logging.info('modify boundaries to %s' % feature_config.boundaries)
+      elif 'hash_bucket_size' in feature_info_map[feature_name]:
+        feature_config.hash_bucket_size = int(
+            feature_info_map[feature_name]['hash_bucket_size'])
+        logging.info('modify hash_bucket_size to %s' %
+                     feature_config.hash_bucket_size)
+  # modify num_steps
+  pipeline_config.train_config.num_steps = feature_info_map['__NUM_STEPS__'][
+      'num_steps']
+  logging.info('modify num_steps to %s' %
+               pipeline_config.train_config.num_steps)
+  # modify decay_steps
+  optimizer_configs = pipeline_config.train_config.optimizer_config
+  for optimizer_config in optimizer_configs:
+    optimizer = optimizer_config.WhichOneof('optimizer')
+    optimizer = getattr(optimizer_config, optimizer)
+    learning_rate = optimizer.learning_rate.WhichOneof('learning_rate')
+    learning_rate = getattr(optimizer.learning_rate, learning_rate)
+    if hasattr(learning_rate, 'decay_steps'):
+      learning_rate.decay_steps = feature_info_map['__DECAY_STEPS__'][
+          'decay_steps']
+      logging.info('modify decay_steps to %s' % learning_rate.decay_steps)
+
+  for feature_group in pipeline_config.model_config.feature_groups:
+    feature_names = feature_group.feature_names
+    reserved_features = []
+    for feature_name in feature_names:
+      if feature_name not in drop_feature_names:
+        reserved_features.append(feature_name)
+      else:
+        logging.info('drop feature: %s' % feature_name)
+    feature_group.ClearField('feature_names')
+    feature_group.feature_names.extend(reserved_features)
+    for sequence_feature in feature_group.sequence_features:
+      seq_att_maps = sequence_feature.seq_att_map
+      for seq_att in seq_att_maps:
+        keys = seq_att.key
+        reserved_keys = []
+        for key in keys:
+          if key not in drop_feature_names:
+            reserved_keys.append(key)
+          else:
+            logging.info('drop sequence feature key: %s' % key)
+        seq_att.ClearField('key')
+        seq_att.key.extend(reserved_keys)
+
+        hist_seqs = seq_att.hist_seq
+        reserved_hist_seqs = []
+        for hist_seq in hist_seqs:
+          if hist_seq not in drop_feature_names:
+            reserved_hist_seqs.append(hist_seq)
+          else:
+            logging.info('drop sequence feature hist_seq: %s' % hist_seq)
+        seq_att.ClearField('hist_seq')
+        seq_att.hist_seq.extend(reserved_hist_seqs)
+
+  config_dir, config_name = os.path.split(FLAGS.output_config_path)
+  config_util.save_pipeline_config(pipeline_config, config_dir, config_name)
+
+
+if __name__ == '__main__':
+  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
+  tf.app.run()
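The +145 hunk above corresponds to easy_rec/python/tools/add_feature_info_to_config.py. Each row of --config_table supplies (feature, feature_info, message); the shape of feature_info is not documented in the diff, so the rows below are only a hypothetical illustration of what the code above would accept.

# Hypothetical rows for --config_table (schema inferred from the code above):
# a regular feature row carries embedding_dim plus either boundary or
# hash_bucket_size; a message containing 'DROP IT' removes the feature; the
# reserved rows __NUM_STEPS__ and __DECAY_STEPS__ override train_config.
rows = [
    ('user_age', '{"embedding_dim": 8, "boundary": [18, 25, 35]}', ''),
    ('item_tag', '{"embedding_dim": 16, "hash_bucket_size": 100000}', ''),
    ('stale_feature', '{"embedding_dim": 8}', 'DROP IT: low importance'),
    ('__NUM_STEPS__', '{"num_steps": 20000}', ''),
    ('__DECAY_STEPS__', '{"decay_steps": 5000}', ''),
]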
@@ -0,0 +1,48 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+from google.protobuf import json_format
+from google.protobuf import text_format
+
+from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
+
+
+def load_config(input_config):
+  pipeline_config = EasyRecConfig()
+  with open(input_config, 'r') as fin:
+    tmp_str = fin.read()
+    if input_config.endswith('.config'):
+      text_format.Merge(tmp_str, pipeline_config)
+    elif input_config.endswith('.json'):
+      json_format.Parse(tmp_str, pipeline_config)
+    else:
+      assert False, 'only .config/.json are supported(%s)' % input_config
+  return pipeline_config
+
+
+def save_config(pipeline_config, save_path):
+  with open(save_path, 'w') as fout:
+    if save_path.endswith('.config'):
+      fout.write(text_format.MessageToString(pipeline_config, as_utf8=True))
+    elif save_path.endswith('.json'):
+      fout.write(
+          json_format.MessageToJson(
+              pipeline_config, preserving_proto_field_name=True))
+    else:
+      assert False, 'only .config/.json are supported(%s)' % save_path
+
+
+if __name__ == '__main__':
+  import argparse
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input_config', type=str, help='input_config path', default=None)
+  parser.add_argument(
+      '--output_config', type=str, help='output_config path', default=None)
+  args = parser.parse_args()
+
+  assert os.path.exists(args.input_config)
+  pipeline_config = load_config(args.input_config)
+  save_config(pipeline_config, args.output_config)
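The +48 hunk above is easy_rec/python/tools/convert_config_format.py: it dispatches on the file extension, treating .config as protobuf text format and .json as protobuf JSON. A minimal sketch of using the two helpers directly (the file names are placeholders):

from easy_rec.python.tools.convert_config_format import load_config, save_config

# Convert a protobuf-text pipeline config into its JSON form.
pipeline_config = load_config('pipeline.config')  # .config -> text_format.Merge
save_config(pipeline_config, 'pipeline.json')     # .json -> json_format.MessageToJson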
@@ -0,0 +1,79 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convert the original rtp data format to csv format.
+
+The original data format is not suggested to use with EasyRec.
+In the original format: features are in kv format, if a feature has
+more than one value, there will be multiple kvs, such as:
+   ...tagbeautytagsmart...
+In our new format:
+   ...beautysmart...
+"""
+import argparse
+import csv
+import json
+import logging
+import sys
+
+import tensorflow as tf
+
+logging.basicConfig(
+    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+    level=logging.INFO)
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--rtp_fg', type=str, default='', help='rtp fg path(.json)')
+  parser.add_argument('--input_path', type=str, default='', help='input path')
+  parser.add_argument('--output_path', type=str, default='', help='output path')
+  parser.add_argument('--label', type=str, default='', help='label for train')
+  args = parser.parse_args()
+
+  if not args.rtp_fg:
+    logging.error('rtp_fg is not set')
+    sys.exit(1)
+
+  if not args.input_path:
+    logging.error('input_path is not set')
+    sys.exit(1)
+
+  if not args.output_path:
+    logging.error('output_path is not set')
+    sys.exit(1)
+
+  if not args.label:
+    logging.error('label is not set')
+    sys.exit(1)
+
+  with open(args.rtp_fg, 'r') as fin:
+    rtp_fg = json.load(fin)
+
+  feature_names = [args.label]
+  for feature in rtp_fg['features']:
+    feature_name = feature['feature_name']
+    feature_names.append(feature_name)
+
+  with open(args.input_path, 'r') as fin:
+    with open(args.output_path, 'w') as fout:
+      writer = csv.writer(fout)
+      for line_str in fin:
+        line_str = line_str.strip()
+        line_toks = line_str.split('\002')
+        temp_dict = {}
+        for line_tok in line_toks:
+          k, v = line_tok.split('\003')
+          if k not in temp_dict:
+            temp_dict[k] = [v]
+          else:
+            temp_dict[k].append(v)
+        temp_vs = []
+        for feature_name in feature_names:
+          if feature_name in temp_dict:
+            temp_vs.append('|'.join(temp_dict[feature_name]))
+          else:
+            temp_vs.append('')
+        writer.writerow(temp_vs)
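The +79 hunk above is easy_rec/python/tools/convert_rtp_data.py. The separator characters in its docstring examples are non-printing and do not survive this view, so here is a worked example of the same conversion using the separators the code actually splits on; the feature and label names are made up.

# '\002' separates key-value pairs, '\003' separates key from value; repeated
# keys are collected into one CSV column whose values are joined with '|'.
line_str = 'clk\0031\002tag\003beauty\002tag\003smart\002price\0039.9'
feature_names = ['clk', 'tag', 'price']  # label column first, then fg features

temp_dict = {}
for tok in line_str.split('\002'):
  k, v = tok.split('\003')
  temp_dict.setdefault(k, []).append(v)
row = ['|'.join(temp_dict.get(name, [''])) for name in feature_names]
assert row == ['1', 'beauty|smart', '9.9']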
@@ -0,0 +1,106 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convert rtp fg feature config to easy_rec data_config and feature_config."""
+import argparse
+import logging
+import sys
+
+import tensorflow as tf
+
+from easy_rec.python.utils.config_util import save_message
+from easy_rec.python.utils.convert_rtp_fg import convert_rtp_fg
+
+logging.basicConfig(
+    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
+    level=logging.INFO)
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+model_types = ['deepfm', 'multi_tower', 'wide_and_deep', 'esmm', 'dbmtl', '']
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--model_type',
+      type=str,
+      choices=model_types,
+      default='',
+      help='model type, currently support: %s' % ','.join(model_types))
+  parser.add_argument('--rtp_fg', type=str, help='rtp fg path')
+  parser.add_argument(
+      '--embedding_dim', type=int, default=16, help='embedding_dimension')
+  parser.add_argument(
+      '--batch_size', type=int, default=1024, help='batch_size for train')
+  parser.add_argument(
+      '--label',
+      type=str,
+      default='',
+      nargs='+',
+      required=True,
+      help='label fields')
+  parser.add_argument(
+      '--num_steps',
+      type=int,
+      default=1000,
+      help='number of train steps = num_samples * num_epochs / batch_size / num_workers'
+  )
+  parser.add_argument('--output_path', type=str, help='generated config path')
+  parser.add_argument(
+      '--incol_separator',
+      type=str,
+      default='\003',
+      help='separator for multi_value features')
+  parser.add_argument(
+      '--separator',
+      type=str,
+      default='\002',
+      help='separator between different features')
+  parser.add_argument(
+      '--train_input_path', type=str, default=None, help='train data path')
+  parser.add_argument(
+      '--eval_input_path', type=str, default=None, help='eval data path')
+  parser.add_argument(
+      '--selected_cols',
+      type=str,
+      default=None,
+      help='selected cols, for csv input, it is in the format of: label_col_id0,...,label_col_idn,feature_col_id '
+      'for odps table input, it is in the format of: label_col_name0,...,label_col_namen,feature_col_name '
+  )
+  parser.add_argument(
+      '--rtp_separator', type=str, default=';', help='separator')
+  parser.add_argument(
+      '--input_type',
+      type=str,
+      default='OdpsRTPInput',
+      help='default to OdpsRTPInput, if test local, change it to RTPInput')
+  parser.add_argument(
+      '--is_async', action='store_true', help='async mode, debug to false')
+
+  args = parser.parse_args()
+
+  if not args.rtp_fg:
+    logging.error('rtp_fg is not set')
+    sys.exit(1)
+
+  if not args.output_path:
+    logging.error('output_path is not set')
+    sys.exit(1)
+
+  pipeline_config = convert_rtp_fg(args.rtp_fg, args.embedding_dim,
+                                   args.batch_size, args.label, args.num_steps,
+                                   args.model_type, args.separator,
+                                   args.incol_separator, args.train_input_path,
+                                   args.eval_input_path, args.selected_cols,
+                                   args.input_type, args.is_async)
+  save_message(pipeline_config, args.output_path)
+  logging.info('Conversion done.')
+  logging.info('Tips:')
+  logging.info(
+      'if run on local, please change data_config.input_type to RTPInput, '
+      'and model_dir/train_input_path/eval_input_path must also be set, ')
+  logging.info(
+      'if run local, please set data_config.selected_cols in the format '
+      'label_col_id0,label_col_id1,...,label_col_idn,feature_col_id')
+  logging.info(
+      'if run on odps, selected_cols must be set, which are label0_col,'
+      'label1_col, ..., feature_col_name')