easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import threading
|
|
8
|
+
import time
|
|
9
|
+
import traceback
|
|
10
|
+
import unittest
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import six
|
|
14
|
+
import tensorflow as tf
|
|
15
|
+
from tensorflow.python.data.ops import iterator_ops
|
|
16
|
+
from tensorflow.python.platform import gfile
|
|
17
|
+
|
|
18
|
+
from easy_rec.python.inference.predictor import Predictor
|
|
19
|
+
from easy_rec.python.input.kafka_dataset import KafkaDataset
|
|
20
|
+
from easy_rec.python.utils import numpy_utils
|
|
21
|
+
from easy_rec.python.utils import test_utils
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
import kafka
|
|
25
|
+
from kafka import KafkaProducer, KafkaAdminClient
|
|
26
|
+
from kafka.admin import NewTopic
|
|
27
|
+
except ImportError:
|
|
28
|
+
logging.warning('kafka-python is not installed: %s' % traceback.format_exc())
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class KafkaTest(tf.test.TestCase):
  """Integration tests for Kafka input ops, Kafka training and processor export.

  Every real test case is skipped unless the ``kafka_install_dir`` environment
  variable is set. In that case setUp launches a private ZooKeeper and Kafka
  broker (via the scripts under ``kafka_install_dir``) whose data directories
  live under the per-test temporary directory, and creates
  ``self._test_topic`` with two partitions.
  """

  def setUp(self):
    self._success = True
    self._test_dir = test_utils.get_tmp_dir()
    # Define everything tearDown touches up front so an early return (or a
    # skipped Kafka setup) never leaves tearDown with missing attributes.
    self._kafka_server_proc = None
    self._zookeeper_proc = None
    self._should_stop = False
    self._producer = None
    if self._testMethodName == 'test_session':
      return

    logging.info('Testing %s.%s, test_dir=%s' %
                 (type(self).__name__, self._testMethodName, self._test_dir))
    self._log_dir = os.path.join(self._test_dir, 'logs')
    if not gfile.IsDirectory(self._log_dir):
      gfile.MakeDirs(self._log_dir)

    self._kafka_servers = ['127.0.0.1:9092']
    self._test_topic = 'kafka_op_test_topic'

    if 'kafka_install_dir' in os.environ:
      kafka_install_dir = os.environ['kafka_install_dir']
      self._start_zookeeper(kafka_install_dir)
      self._start_kafka_server(kafka_install_dir)
      self._create_topic()

  @staticmethod
  def _rewrite_properties(raw_path, dst_path, key_prefix, new_line):
    """Copy ``raw_path`` to ``dst_path``, replacing lines starting with ``key_prefix``."""
    with open(raw_path, 'r') as fin, open(dst_path, 'w') as fout:
      for line_str in fin:
        if line_str.startswith(key_prefix):
          fout.write(new_line)
        else:
          fout.write(line_str)

  def _start_zookeeper(self, kafka_install_dir):
    """Launch ZooKeeper with its dataDir redirected into the test dir."""
    zookeeper_config_raw = '%s/config/zookeeper.properties' % kafka_install_dir
    zookeeper_config = os.path.join(self._test_dir, 'zookeeper.properties')
    self._rewrite_properties(zookeeper_config_raw, zookeeper_config,
                             'dataDir=',
                             'dataDir=%s/zookeeper\n' % self._test_dir)
    cmd = 'bash %s/bin/zookeeper-server-start.sh %s' % (kafka_install_dir,
                                                        zookeeper_config)
    log_file = os.path.join(self._log_dir, 'zookeeper.log')
    self._zookeeper_proc = test_utils.run_cmd(cmd, log_file)

  def _start_kafka_server(self, kafka_install_dir):
    """Launch the Kafka broker and block until it answers admin requests."""
    kafka_config_raw = '%s/config/server.properties' % kafka_install_dir
    kafka_config = os.path.join(self._test_dir, 'server.properties')
    self._rewrite_properties(kafka_config_raw, kafka_config, 'log.dirs=',
                             'log.dirs=%s/kafka\n' % self._test_dir)
    cmd = 'bash %s/bin/kafka-server-start.sh %s' % (kafka_install_dir,
                                                    kafka_config)
    log_file = os.path.join(self._log_dir, 'kafka_server.log')
    self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)

    started = False
    while not started:
      # poll() is None while the process is alive; ANY exit status (including
      # 0) means the broker is gone and must be restarted, otherwise the
      # connection attempts below would spin forever.
      if self._kafka_server_proc.poll() is not None:
        logging.warning('start kafka server failed, will retry.')
        os.system('cat %s' % log_file)
        self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
        time.sleep(5)
      else:
        try:
          admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
          try:
            logging.info('old topics: %s' %
                         (','.join(admin_clt.list_topics())))
          finally:
            admin_clt.close()
          started = True
        except kafka.errors.NoBrokersAvailable:
          time.sleep(2)

  def _create_topic(self, num_partitions=2):
    """Create ``self._test_topic`` with ``num_partitions`` partitions."""
    admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
    try:
      logging.info('create topic: %s' % self._test_topic)
      topic_list = [
          NewTopic(
              name=self._test_topic,
              num_partitions=num_partitions,
              replication_factor=1)
      ]
      admin_clt.create_topics(new_topics=topic_list, validate_only=False)
      logging.info('all topics: %s' % (','.join(admin_clt.list_topics())))
    finally:
      # Always release the admin connection, even if topic creation fails.
      admin_clt.close()

  def _create_producer(self, generate_func):
    """Start ``generate_func`` on a background thread and return the thread."""
    prod = threading.Thread(target=generate_func)
    # Daemon so a producer that never observes _should_stop cannot keep the
    # test process alive after the runner finishes.
    prod.daemon = True
    prod.start()
    return prod

  def _stop_producer(self):
    """Ask the producer thread (if any) to stop and wait for it to exit."""
    if self._producer is not None:
      self._should_stop = True
      self._producer.join()

  def tearDown(self):
    # Each cleanup step is isolated so one failure cannot leak the others
    # (e.g. a producer error must not leave the broker running).
    try:
      self._stop_producer()
    except Exception as ex:
      logging.warning('exception stop producer: %s' % str(ex))

    try:
      if self._kafka_server_proc is not None:
        self._kafka_server_proc.terminate()
    except Exception as ex:
      logging.warning('exception terminate kafka proc: %s' % str(ex))

    try:
      if self._zookeeper_proc is not None:
        self._zookeeper_proc.terminate()
    except Exception as ex:
      logging.warning('exception terminate zookeeper proc: %s' % str(ex))

    test_utils.set_gpu_id(None)
    if self._success:
      test_utils.clean_up(self._test_dir)

  @staticmethod
  def _to_str_list(values):
    """Decode bytes elements (TF string outputs on py3) into native str."""
    return [v.decode('utf-8') if isinstance(v, bytes) else v for v in values]

  @unittest.skipIf('kafka_install_dir' not in os.environ,
                   'Only execute when kafka is available')
  def test_kafka_ops(self):
    try:
      test_utils.set_gpu_id(None)

      def _generate():
        producer = KafkaProducer(
            bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
        try:
          i = 0
          while not self._should_stop:
            msg = 'user_id_%d' % i
            # Without a value_serializer kafka-python only accepts bytes; a
            # py3 str would kill this thread and starve the dataset.
            if six.PY3:
              msg = msg.encode('utf-8')
            producer.send(self._test_topic, msg)
            i += 1
        finally:
          producer.close()

      self._producer = self._create_producer(_generate)

      group = 'dataset_consumer'
      k = KafkaDataset(
          servers=self._kafka_servers[0],
          topics=[self._test_topic + ':0', self._test_topic + ':1'],
          group=group,
          eof=True,
          # control the maximal read of each partition
          config_global=['max.partition.fetch.bytes=1048576'],
          message_key=True,
          message_offset=True)

      batch_dataset = k.batch(5)

      iterator = iterator_ops.Iterator.from_structure(
          batch_dataset.output_types)
      init_batch_op = iterator.make_initializer(batch_dataset)
      get_next = iterator.get_next()

      with tf.Session() as sess:
        sess.run(init_batch_op)

        p = sess.run(get_next)
        # (message, key, offset) because message_key/message_offset are on.
        self.assertEqual(len(p), 3)
        offset = self._to_str_list(p[2])
        self.assertEqual(offset[0], '0:0')
        self.assertEqual(offset[1], '0:1')

        p = sess.run(get_next)
        offset = self._to_str_list(p[2])
        self.assertEqual(offset[0], '0:5')
        self.assertEqual(offset[1], '0:6')

        for _ in range(300):
          sess.run(get_next)
    except tf.errors.OutOfRangeError:
      # eof=True: running out of messages ends the drain loop normally.
      # NOTE(review): this handler also covers the first two reads above;
      # consider narrowing it to the drain loop only.
      pass
    except Exception:
      self._success = False
      raise

  def _run_kafka_train(self, train_eval_fn, config_path, **kwargs):
    """Start the CSV producer thread, then run ``train_eval_fn`` on a config.

    Args:
      train_eval_fn: test_utils.test_single_train_eval or
        test_utils.test_distributed_train_eval.
      config_path: pipeline config to train with.
      **kwargs: extra keyword arguments forwarded to ``train_eval_fn``.
    """
    try:
      self._producer = self._create_producer(self._generate)
      test_utils.set_gpu_id(None)
      self._success = train_eval_fn(config_path, self._test_dir, **kwargs)
      self.assertTrue(self._success)
    except Exception:
      self._success = False
      raise

  @unittest.skipIf('kafka_install_dir' not in os.environ,
                   'Only execute when kafka is available')
  def test_kafka_train(self):
    self._run_kafka_train(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka.config')

  def _generate(self):
    """Replay the avazu CSV into the test topic until _should_stop is set."""
    producer = KafkaProducer(
        bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
    try:
      while not self._should_stop:
        with open('data/test/dwd_avazu_ctr_deepmodel_10w.csv', 'r') as fin:
          for line_str in fin:
            if self._should_stop:
              break
            line_str = line_str.strip()
            if six.PY3:
              line_str = line_str.encode('utf-8')
            producer.send(self._test_topic, line_str)
    finally:
      producer.close()
    logging.info('data generation thread done.')

  @unittest.skipIf('kafka_install_dir' not in os.environ,
                   'Only execute when kafka is available')
  def test_kafka_train_chief_redundant(self):
    self._run_kafka_train(
        test_utils.test_distributed_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config',
        num_evaluator=1)

  @unittest.skipIf('kafka_install_dir' not in os.environ,
                   'Only execute when kafka is available')
  def test_kafka_train_v2(self):
    self._run_kafka_train(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_time_offset.config')

  # Every env var read by _test_kafka_processor is required; the previous
  # `A or B and C or D` mix let the test run with only one of
  # oss_endpoint/oss_ak set and then fail with KeyError.
  @unittest.skipIf(
      any(key not in os.environ for key in ('kafka_install_dir', 'oss_path',
                                            'oss_endpoint', 'oss_ak',
                                            'oss_sk')),
      'Only execute when kafka is available')
  def test_kafka_processor(self):
    self._test_kafka_processor(
        'samples/model_config/taobao_fg_incr_save.config')

  @unittest.skipIf(
      any(key not in os.environ for key in ('kafka_install_dir', 'oss_path',
                                            'oss_endpoint', 'oss_ak',
                                            'oss_sk')),
      'Only execute when kafka is available')
  def test_kafka_processor_ev(self):
    self._test_kafka_processor(
        'samples/model_config/taobao_fg_incr_save_ev.config')

  def _test_kafka_processor(self, config_path):
    """Train, export to export/sep, and compare processor vs Predictor output.

    Runs distributed training for 500 steps, exports with
    ``easy_rec.python.export`` (using the oss_* env vars), scores
    ``data/test/rtp/taobao_test_feature.txt`` through
    ``easy_rec.python.inference.processor.test``, and asserts each ``probs``
    is within 1e-4 of the Predictor loaded from ``train/export/final``.
    """
    self._success = False
    success = test_utils.test_distributed_train_eval(
        config_path, self._test_dir, total_steps=500)
    self.assertTrue(success)
    export_cmd = """
      python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
          --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
          --asset_files ./samples/rtp_fg/fg.json
          --checkpoint_path %s/train/model.ckpt-0
    """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
           os.environ['oss_ak'], os.environ['oss_sk'],
           os.environ['oss_endpoint'], self._test_dir)
    proc = test_utils.run_cmd(export_cmd,
                              '%s/log_export_sep.txt' % self._test_dir)
    proc.wait()
    self.assertEqual(proc.returncode, 0)
    files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
    self.assertTrue(files, 'no exported model found under export/sep')
    export_sep_dir = files[0]

    predict_cmd = """
      python -m easy_rec.python.inference.processor.test --saved_model_dir %s
             --input_path data/test/rtp/taobao_test_feature.txt
             --output_path %s/processor.out --test_dir %s
    """ % (export_sep_dir, self._test_dir, self._test_dir)
    envs = dict(os.environ)
    envs['PROCESSOR_TEST'] = '1'
    proc = test_utils.run_cmd(
        predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
    proc.wait()
    self.assertEqual(proc.returncode, 0)

    with open('%s/processor.out' % self._test_dir, 'r') as fin:
      # One JSON object per non-blank line.
      processor_out = [
          json.loads(line_str) for line_str in (x.strip() for x in fin)
          if line_str
      ]

    predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
    with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
      # Keep only the last ';'-separated field, split on the \x02 separator.
      inputs = [
          line_str.strip().split(';')[-1].split(chr(2)) for line_str in fin
      ]
    output_res = predictor.predict(inputs, batch_size=1024)

    with open('%s/predictor.out' % self._test_dir, 'w') as fout:
      for res in output_res:
        fout.write(json.dumps(res, cls=numpy_utils.NumpyEncoder) + '\n')

    # Index into processor_out (instead of zip) so a short processor output
    # still fails loudly rather than silently comparing fewer rows.
    for i, res in enumerate(output_res):
      val0 = res['probs']
      val1 = processor_out[i]['probs']
      diff = np.abs(val0 - val1)
      self.assertLess(diff, 1e-4, 'too much difference[%.6f] >= 1e-4' % diff)
    self._success = True

  @unittest.skipIf('kafka_install_dir' not in os.environ,
                   'Only execute when kafka is available')
  def test_kafka_train_v3(self):
    self._run_kafka_train(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config')
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
if __name__ == '__main__':
  # Script entry point: hand discovery and execution to TensorFlow's runner.
  tf.test.main()
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import unittest
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import tensorflow as tf
|
|
11
|
+
from tensorflow.python.platform import gfile
|
|
12
|
+
|
|
13
|
+
from easy_rec.python.inference.predictor import Predictor
|
|
14
|
+
from easy_rec.python.utils import numpy_utils
|
|
15
|
+
from easy_rec.python.utils import test_utils
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LocalIncrTest(tf.test.TestCase):
  """Incremental-save tests comparing processor output with the final export.

  Each test trains a model with incremental saving to a local mount path,
  exports the step-0 checkpoint as a separated saved model, scores
  data/test/rtp/taobao_test_feature.txt with
  easy_rec.python.inference.processor.test and with a Predictor over the
  final export, and checks that the two sets of `probs` agree.
  """

  # The export command reads all four variables from os.environ
  # unconditionally, so the tests are skipped unless every one of them is set.
  _OSS_UNAVAILABLE = not all(
      key in os.environ
      for key in ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk'))

  def setUp(self):
    """Create a temporary test dir with a logs/ subdirectory."""
    self._success = True
    self._test_dir = test_utils.get_tmp_dir()

    logging.info('Testing %s.%s, test_dir=%s' %
                 (type(self).__name__, self._testMethodName, self._test_dir))
    self._log_dir = os.path.join(self._test_dir, 'logs')
    if not gfile.IsDirectory(self._log_dir):
      gfile.MakeDirs(self._log_dir)

  @unittest.skipIf(_OSS_UNAVAILABLE, 'Only execute when oss is available')
  def test_incr_save(self):
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_local.config')

  @unittest.skipIf(_OSS_UNAVAILABLE, 'Only execute when oss is available')
  def test_incr_save_ev(self):
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_ev_local.config')

  @unittest.skipIf(_OSS_UNAVAILABLE, 'Only execute when oss is available')
  def test_incr_save_share_ev(self):
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_share_ev_local.config')

  def _test_incr_save(self, config_path):
    """Train with incremental saving and check processor/predictor parity.

    Args:
      config_path: pipeline config used for distributed train/eval.
    """
    self._success = False
    success = test_utils.test_distributed_train_eval(
        config_path,
        self._test_dir,
        total_steps=100,
        edit_config_json={
            'train_config.incr_save_config.fs.mount_path':
                os.path.join(self._test_dir, 'train/incr_save/')
        })
    self.assertTrue(success)

    export_sep_dir = self._export_sep_model()
    processor_out = self._run_processor(export_sep_dir)
    output_res = self._predict_with_final_export()
    self._assert_probs_close(output_res, processor_out)
    self._success = True

  def _export_sep_model(self):
    """Export train/model.ckpt-0 as a separated saved model.

    Returns:
      the numbered export directory created under <test_dir>/export/sep/.
    """
    export_cmd = """
      python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
        --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
        --asset_files ./samples/rtp_fg/fg.json
        --checkpoint_path %s/train/model.ckpt-0
    """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
           os.environ['oss_ak'], os.environ['oss_sk'],
           os.environ['oss_endpoint'], self._test_dir)
    proc = test_utils.run_cmd(export_cmd,
                              '%s/log_export_sep.txt' % self._test_dir)
    proc.wait()
    self.assertEqual(proc.returncode, 0)

    pattern = os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*')
    files = gfile.Glob(pattern)
    self.assertTrue(files, 'no export directory matches %s' % pattern)
    # Glob order is unspecified; pick the highest-numbered export.
    return sorted(files)[-1]

  def _run_processor(self, saved_model_dir):
    """Score the test feature file with the processor test harness.

    Args:
      saved_model_dir: separated saved model to load.

    Returns:
      list of dicts, one per non-empty line of <test_dir>/processor.out.
    """
    predict_cmd = """
      python -m easy_rec.python.inference.processor.test --saved_model_dir %s
        --input_path data/test/rtp/taobao_test_feature.txt
        --output_path %s/processor.out --test_dir %s
    """ % (saved_model_dir, self._test_dir, self._test_dir)
    envs = dict(os.environ)
    envs['PROCESSOR_TEST'] = '1'
    proc = test_utils.run_cmd(
        predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
    proc.wait()
    self.assertEqual(proc.returncode, 0)

    processor_out = []
    with open('%s/processor.out' % self._test_dir, 'r') as fin:
      for line_str in fin:
        line_str = line_str.strip()
        # Skip blank lines (e.g. a trailing newline) instead of failing in
        # json.loads('').
        if line_str:
          processor_out.append(json.loads(line_str))
    return processor_out

  def _predict_with_final_export(self):
    """Score the test feature file with the final export's Predictor.

    Also writes the results to <test_dir>/predictor.out, one JSON per line.

    Returns:
      the list returned by Predictor.predict.
    """
    predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
    inputs = []
    with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
      for line_str in fin:
        line_str = line_str.strip()
        # Keep the text after the last ';' and split it on chr(2).
        line_tok = line_str.split(';')[-1]
        inputs.append(line_tok.split(chr(2)))
    output_res = predictor.predict(inputs, batch_size=1024)

    with open('%s/predictor.out' % self._test_dir, 'w') as fout:
      for res in output_res:
        fout.write(json.dumps(res, cls=numpy_utils.NumpyEncoder) + '\n')
    return output_res

  def _assert_probs_close(self, output_res, processor_out, tol=1e-4):
    """Assert processor `probs` match predictor `probs` record by record.

    Args:
      output_res: per-record dicts from Predictor.predict.
      processor_out: per-record dicts parsed from processor.out.
      tol: maximum allowed absolute difference.
    """
    # Records are matched by position, so the processor must have produced at
    # least one record per predictor record.
    self.assertGreaterEqual(len(processor_out), len(output_res))
    for i, pred_res in enumerate(output_res):
      # np.max handles both scalar and vector-valued probs.
      diff = np.max(
          np.abs(
              np.asarray(pred_res['probs']) -
              np.asarray(processor_out[i]['probs'])))
      self.assertLess(
          diff, tol,
          'record %d: too much difference[%.6f] >= %g' % (i, diff, tol))
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  tf.test.main()
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
import tensorflow as tf
|
|
3
|
+
|
|
4
|
+
from easy_rec.python.loss.circle_loss import circle_loss
|
|
5
|
+
from easy_rec.python.loss.circle_loss import get_anchor_positive_triplet_mask
|
|
6
|
+
|
|
7
|
+
from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
|
|
8
|
+
|
|
9
|
+
from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
|
|
10
|
+
|
|
11
|
+
# On TF 2.x, run these graph/session-based tests through the v1 compat API.
# Compare the numeric major version: a plain string comparison is
# lexicographic (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LossTest(tf.test.TestCase):
  """Numeric regression tests for the custom loss helpers."""

  def test_f1_reweighted_loss(self):
    print('test_f1_reweighted_loss')
    logits = tf.constant([0.1, 0.5, 0.3, 0.8, -0.1, 0.3])
    labels = tf.constant([1, 1, 0, 0, 1, 1])
    loss_op = f1_reweight_sigmoid_cross_entropy(
        labels=labels, logits=logits, beta_square=4)
    with self.test_session() as session:
      actual = session.run(loss_op)
    self.assertAlmostEqual(actual, 0.47844395, delta=1e-5)

  def test_softmax_loss_with_negative_mining(self):
    print('test_softmax_loss_with_negative_mining')
    user_vectors = tf.constant([
        [0.1, 0.5, 0.3],
        [0.8, -0.1, 0.3],
        [0.28, 0.3, 0.9],
        [0.37, 0.45, 0.93],
        [-0.7, 0.15, 0.03],
        [0.18, 0.9, -0.3],
    ])
    item_vectors = tf.constant([
        [0.1, -0.5, 0.3],
        [0.8, -0.31, 0.3],
        [0.7, -0.45, 0.15],
        [0.08, -0.31, -0.9],
        [-0.7, 0.85, 0.03],
        [0.18, 0.89, -0.3],
    ])
    targets = tf.constant([1, 1, 0, 0, 1, 1])
    loss_op = softmax_loss_with_negative_mining(
        user_vectors, item_vectors, targets, num_negative_samples=2, seed=1)
    with self.test_session() as session:
      actual = session.run(loss_op)
    self.assertAlmostEqual(actual, 0.48577175, delta=1e-5)

  def test_circle_loss(self):
    print('test_circle_loss')
    embeddings = tf.constant(
        [
            [0.1, 0.2, 0.15, 0.1],
            [0.3, 0.6, 0.45, 0.3],
            [0.13, 0.6, 0.45, 0.3],
            [0.3, 0.26, 0.45, 0.3],
            [0.3, 0.6, 0.5, 0.13],
            [0.08, 0.43, 0.21, 0.6],
        ],
        dtype=tf.float32)
    classes = tf.constant([1, 1, 2, 2, 3, 3])
    loss_op = circle_loss(embeddings, classes, classes, margin=0.25, gamma=64)
    with self.test_session() as session:
      actual = session.run(loss_op)
    self.assertAlmostEqual(actual, 52.75707, delta=1e-5)

  def test_triplet_mask(self):
    print('test_triplet_mask')
    raw_labels = [1, 1, 2, 2, 3, 3, 4, 5]
    label = tf.constant(raw_labels)
    # Expected masks built by brute force over every (anchor, other) pair:
    # a positive pair shares the label but is not the anchor itself; a
    # negative pair carries a different label.
    positive_mask = tf.constant(
        [[1. if a == b and i != j else 0.
          for j, b in enumerate(raw_labels)]
         for i, a in enumerate(raw_labels)],
        dtype=tf.float32)
    negative_mask = tf.constant(
        [[1. if a != b else 0. for b in raw_labels] for a in raw_labels],
        dtype=tf.float32)
    with self.test_session():
      pos_mask = get_anchor_positive_triplet_mask(label, label)
      self.assertAllEqual(positive_mask, pos_mask)

      neg_mask = _get_anchor_negative_triplet_mask(label, label)
      self.assertAllEqual(negative_mask, neg_mask)

      # Every off-diagonal pair is either positive or negative, so the
      # negative mask is the complement of positives minus the diagonal.
      complement = 1 - pos_mask - tf.eye(len(raw_labels))
      self.assertAllEqual(neg_mask, complement)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _get_anchor_negative_triplet_mask(labels, sessions):
  """Return a 2D mask where mask[a, n] is 1.0 iff a and n differ in session or label.

  Args:
    labels: a `Tensor` with shape [batch_size]
    sessions: a `Tensor` with shape [batch_size]

  Returns:
    mask: tf.float32 `Tensor` with shape [batch_size, batch_size], 1.0 where
      the pair has a different session or a different label, 0.0 otherwise.
  """
  # Check if sessions[i] != sessions[k]. Comparing a (1, batch_size) row with
  # a (batch_size, 1) column broadcasts to (batch_size, batch_size).
  session_not_equal = tf.not_equal(
      tf.expand_dims(sessions, 0), tf.expand_dims(sessions, 1))

  # The same tensor passed twice would give an identical label comparison,
  # so the session mask alone is the answer.
  if labels is sessions:
    return tf.cast(session_not_equal, tf.float32)

  # Check if labels[i] != labels[k], with the same broadcasting as above.
  label_not_equal = tf.not_equal(
      tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

  mask = tf.logical_or(session_not_equal, label_not_equal)
  return tf.cast(mask, tf.float32)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point
  # (tf may have been rebound to tf.compat.v1 above).
  tf.test.main()
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import subprocess
|
|
7
|
+
|
|
8
|
+
from easy_rec.python.test.odps_test_util import get_oss_bucket
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OdpsCommand:
  """Runs odpscmd SQL scripts for tests and keeps the related OSS settings."""

  def __init__(self, odps_oss_config):
    """Wrapper for running odps command.

    Args:
      odps_oss_config: an OdpsOSSConfig instance (presumably
        easy_rec.python.test.odps_test_util.OdpsOSSConfig, the module that
        get_oss_bucket is imported from); its oss_key, oss_secret, endpoint,
        bucket_name, temp_dir, log_dir, odpscmd_path, odps_config_path,
        algo_project, algo_res_project and algo_version attributes are read.
    """
    self.bucket = get_oss_bucket(odps_oss_config.oss_key,
                                 odps_oss_config.oss_secret,
                                 odps_oss_config.endpoint,
                                 odps_oss_config.bucket_name)
    self.bucket_name = odps_oss_config.bucket_name
    self.temp_dir = odps_oss_config.temp_dir
    self.log_path = odps_oss_config.log_dir
    self.odpscmd = odps_oss_config.odpscmd_path
    self.odps_config_path = odps_oss_config.odps_config_path
    self.algo_project = odps_oss_config.algo_project
    self.algo_res_project = odps_oss_config.algo_res_project
    self.algo_version = odps_oss_config.algo_version

  def run_odps_cmd(self, script_file):
    """Run a sql script with odpscmd.

    stdout and stderr of the run go to <log_path>/<basename(script_file)>.log.

    Args:
      script_file: xxx.sql file, resolved against self.temp_dir, to be run
        by odpscmd.

    Raises:
      ValueError: if the log file cannot be opened or odpscmd exits non-zero.
    """
    exec_file_path = os.path.join(self.temp_dir, script_file)
    file_name = os.path.split(script_file)[1]
    log_file = os.path.join(self.log_path, file_name)
    failure_msg = '%s run FAILED: please check log file:%s.log' % (
        exec_file_path, log_file)

    # Pass an argv list instead of a shell string so paths containing spaces
    # or shell metacharacters reach odpscmd verbatim.
    args = ['nohup', self.odpscmd]
    if self.odps_config_path is not None:
      args.append('--config=%s' % self.odps_config_path)
    args.extend(['-f', exec_file_path])
    logging.info('will run cmd: %s', ' '.join(args))

    # Same effect as the shell redirection "> <log_file>.log 2>&1".
    try:
      log_fout = open(log_file + '.log', 'w')
    except (IOError, OSError):
      raise ValueError(failure_msg)
    with log_fout:
      returncode = subprocess.call(
          args, stdout=log_fout, stderr=subprocess.STDOUT)

    if returncode == 0:
      logging.info('%s run succeed', script_file)
    else:
      raise ValueError(failure_msg)

  def run_list(self, files):
    """Run each sql script in `files` in order; stop at the first failure."""
    for f in files:
      self.run_odps_cmd(f)
|