easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/core/sampler.py
@@ -0,0 +1,844 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import division
from __future__ import print_function

import json
import logging
import math
import os
import sys
import threading

import numpy as np
import six
import tensorflow as tf

from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import ds_util
from easy_rec.python.utils.config_util import process_multi_file_input_path
from easy_rec.python.utils.tf_utils import get_tf_type

if tf.__version__.startswith('1.'):
  from tensorflow.python.platform import gfile
else:
  import tensorflow.io.gfile as gfile


# patch graph-learn string_attrs for utf-8
@property
def string_attrs(self):  # NOQA
  self._init()
  return self._string_attrs


# pyre-ignore [56]
@string_attrs.setter
# pyre-ignore [2, 3]
def string_attrs(self, string_attrs):  # NOQA
  self._string_attrs = self._reshape(string_attrs, expand_shape=True)
  self._inited = True


try:
  import graphlearn as gl
  from graphlearn.python.data.values import Values
  Values.string_attrs = string_attrs
except Exception:
  logging.info(
      'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
  )

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def _get_gl_type(field_type):
  type_map = {
      DatasetConfig.INT32: 'int',
      DatasetConfig.INT64: 'int',
      DatasetConfig.STRING: 'string',
      DatasetConfig.BOOL: 'int',
      DatasetConfig.FLOAT: 'float',
      DatasetConfig.DOUBLE: 'float'
  }
  assert field_type in type_map, 'invalid type: %s' % field_type
  return type_map[field_type]


def _get_np_type(field_type):
  type_map = {
      DatasetConfig.INT32: np.int32,
      DatasetConfig.INT64: np.int64,
      DatasetConfig.STRING: str,
      DatasetConfig.BOOL: bool,
      DatasetConfig.FLOAT: np.float32,
      DatasetConfig.DOUBLE: np.double
  }
  assert field_type in type_map, 'invalid type: %s' % field_type
  return type_map[field_type]


class BaseSampler(object):
  _instance_lock = threading.Lock()

  def __init__(self, fields, num_sample, num_eval_sample=None):
    self._g = None
    self._sampler = None
    self._num_sample = num_sample
    self._num_eval_sample = num_eval_sample if num_eval_sample is not None else num_sample
    self._build_field_types(fields)
    self._log_first_n = 5
    self._is_on_ds = ds_util.is_on_ds()

  def set_eval_num_sample(self):
    print('set_eval_num_sample: %d %d' %
          (self._num_sample, self._num_eval_sample))
    self._num_sample = self._num_eval_sample

  def _init_graph(self):
    if 'TF_CONFIG' in os.environ:
      tf_config = json.loads(os.environ['TF_CONFIG'])
      if 'ps' in tf_config['cluster']:
        # ps mode
        tf_config = json.loads(os.environ['TF_CONFIG'])
        if 'worker' in tf_config['cluster']:
          task_count = len(tf_config['cluster']['worker']) + 2
        else:
          task_count = 2
        if self._is_on_ds:
          gl.set_tracker_mode(0)
          server_hosts = [
              host.split(':')[0] + ':888' + str(i)
              for i, host in enumerate(tf_config['cluster']['ps'])
          ]
          cluster = {
              'server': ','.join(server_hosts),
              'client_count': task_count
          }
        else:
          ps_count = len(tf_config['cluster']['ps'])
          cluster = {'server_count': ps_count, 'client_count': task_count}
        if tf_config['task']['type'] in ['chief', 'master']:
          self._g.init(cluster=cluster, job_name='client', task_index=0)
        elif tf_config['task']['type'] == 'worker':
          self._g.init(
              cluster=cluster,
              job_name='client',
              task_index=tf_config['task']['index'] + 2)
        # TODO(hongsheng.jhs): check cluster has evaluator or not?
        elif tf_config['task']['type'] == 'evaluator':
          self._g.init(
              cluster=cluster,
              job_name='client',
              task_index=tf_config['task']['index'] + 1)
          if self._num_eval_sample is not None and self._num_eval_sample > 0:
            self._num_sample = self._num_eval_sample
        elif tf_config['task']['type'] == 'ps':
          self._g.init(
              cluster=cluster,
              job_name='server',
              task_index=tf_config['task']['index'])
      else:
        # worker mode
        task_count = len(tf_config['cluster']['worker']) + 1
        if not self._is_on_ds:
          if tf_config['task']['type'] in ['chief', 'master']:
            self._g.init(task_index=0, task_count=task_count)
          elif tf_config['task']['type'] == 'worker':
            self._g.init(
                task_index=tf_config['task']['index'] + 1,
                task_count=task_count)
        else:
          gl.set_tracker_mode(0)
          if tf_config['cluster'].get('chief', ''):
            chief_host = tf_config['cluster']['chief'][0].split(
                ':')[0] + ':8880'
          else:
            chief_host = tf_config['cluster']['master'][0].split(
                ':')[0] + ':8880'
          # chief_host must be wrapped in a list before concatenation,
          # otherwise str + list raises TypeError
          worker_hosts = [chief_host] + [
              host.split(':')[0] + ':888' + str(i)
              for i, host in enumerate(tf_config['cluster']['worker'])
          ]

          if tf_config['task']['type'] in ['chief', 'master']:
            self._g.init(
                task_index=0,
                task_count=task_count,
                hosts=','.join(worker_hosts))
          elif tf_config['task']['type'] == 'worker':
            self._g.init(
                task_index=tf_config['task']['index'] + 1,
                task_count=task_count,
                hosts=worker_hosts)

        # TODO(hongsheng.jhs): check cluster has evaluator or not?
    else:
      # local mode
      self._g.init()

  def _build_field_types(self, fields):
    self._attr_names = []
    self._attr_types = []
    self._attr_gl_types = []
    self._attr_np_types = []
    self._attr_tf_types = []
    for i, field in enumerate(fields):
      self._attr_names.append(field.input_name)
      self._attr_types.append(field.input_type)
      self._attr_gl_types.append(_get_gl_type(field.input_type))
      self._attr_np_types.append(_get_np_type(field.input_type))
      self._attr_tf_types.append(get_tf_type(field.input_type))

  @classmethod
  def instance(cls, *args, **kwargs):
    with cls._instance_lock:
      if not hasattr(cls, '_instance'):
        cls._instance = cls(*args, **kwargs)
    return cls._instance

  def __del__(self):
    if self._g is not None:
      self._g.close()

  def _parse_nodes(self, nodes):
    if self._log_first_n > 0:
      logging.info('num_example=%d num_eval_example=%d node_num=%d' %
                   (self._num_sample, self._num_eval_sample, len(nodes.ids)))
      self._log_first_n -= 1
    features = []
    int_idx = 0
    float_idx = 0
    string_idx = 0
    for attr_gl_type, attr_np_type in zip(self._attr_gl_types,
                                          self._attr_np_types):
      if attr_gl_type == 'int':
        feature = nodes.int_attrs[:, :, int_idx]
        int_idx += 1
      elif attr_gl_type == 'float':
        feature = nodes.float_attrs[:, :, float_idx]
        float_idx += 1
      elif attr_gl_type == 'string':
        feature = nodes.string_attrs[:, :, string_idx]
        if int(sys.version_info[0]) == 3:
          feature = np.char.decode(feature.astype(np.string_), 'utf-8')
        string_idx += 1
      else:
        raise ValueError('Unknown attr type %s' % attr_gl_type)
      feature = np.reshape(feature,
                           [-1])[:self._num_sample].astype(attr_np_type)
      if attr_gl_type == 'string':
        feature = feature.tolist()
      features.append(feature)
    return features

  def _parse_sparse_nodes(self, nodes):
    features = []
    int_idx = 0
    float_idx = 0
    string_idx = 0
    for attr_gl_type, attr_np_type in zip(self._attr_gl_types,
                                          self._attr_np_types):
      if attr_gl_type == 'int':
        feature = nodes.int_attrs[:, int_idx]
        int_idx += 1
      elif attr_gl_type == 'float':
        feature = nodes.float_attrs[:, float_idx]
        float_idx += 1
      elif attr_gl_type == 'string':
        feature = nodes.string_attrs[:, string_idx]
        string_idx += 1
      else:
        raise ValueError('Unknown attr type %s' % attr_gl_type)
      feature = feature.astype(attr_np_type)
      if attr_gl_type == 'string':
        feature = feature.tolist()
      features.append(feature)
    return features, nodes.indices
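
The `_init_graph` method above derives a GraphLearn cluster layout from the standard TF_CONFIG environment variable. A minimal sketch of a ps-mode TF_CONFIG and the cluster dict the non-DataScience branch would build from it (hosts and indices are hypothetical):

# --- illustrative sketch, not part of sampler.py ---
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['10.0.0.1:2222'],
        'worker': ['10.0.0.2:2222', '10.0.0.3:2222'],
        'ps': ['10.0.0.4:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})

# ps mode, not on DataScience: task_count = len(worker) + 2 = 4, so
# cluster = {'server_count': 1, 'client_count': 4}; this worker joins
# the GraphLearn cluster as job_name='client', task_index=0 + 2 = 2.
# --- end sketch ---
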
class NegativeSampler(BaseSampler):
  """Negative Sampler.

  Weighted random sampling items not in batch.

  Args:
    data_path: item feature data path. id:int64 | weight:float | attrs:string.
    fields: item input fields.
    num_sample: number of negative samples.
    batch_size: mini-batch size.
    attr_delimiter: delimiter of feature string.
    num_eval_sample: number of negative samples for evaluator.
  """

  def __init__(self,
               data_path,
               fields,
               num_sample,
               batch_size,
               attr_delimiter=':',
               num_eval_sample=None):
    super(NegativeSampler, self).__init__(fields, num_sample, num_eval_sample)
    self._batch_size = batch_size
    self._g = gl.Graph().node(
        tf.compat.as_str(data_path),
        node_type='item',
        decoder=gl.Decoder(
            attr_types=self._attr_gl_types,
            weighted=True,
            attr_delimiter=attr_delimiter))
    self._init_graph()

    expand_factor = int(math.ceil(self._num_sample / batch_size))
    self._sampler = self._g.negative_sampler(
        'item', expand_factor, strategy='node_weight')

  def _get_impl(self, ids):
    ids = np.array(ids, dtype=np.int64)
    ids = np.pad(ids, (0, self._batch_size - len(ids)), 'edge')
    nodes = self._sampler.get(ids)
    features = self._parse_nodes(nodes)
    return features

  def get(self, ids):
    """Sampling method.

    Args:
      ids: item id tensor.

    Returns:
      Negative sampled feature dict.
    """
    sampled_values = tf.py_func(self._get_impl, [ids], self._attr_tf_types)
    result_dict = {}
    for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
      v.set_shape([self._num_sample])
      result_dict[k] = v
    return result_dict
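
NegativeSampler bridges the GraphLearn sampler into the TF graph through tf.py_func, and `instance()` caches one process-wide singleton per sampler class. A minimal usage sketch (paths and the field schema are hypothetical; the sampler only reads input_name and input_type from each field, so a namedtuple stands in for the proto field config):

# --- illustrative sketch, not part of sampler.py ---
import collections

import tensorflow as tf

from easy_rec.python.core.sampler import NegativeSampler
from easy_rec.python.protos.dataset_pb2 import DatasetConfig

Field = collections.namedtuple('Field', ['input_name', 'input_type'])
fields = [
    Field('item_id', DatasetConfig.INT64),
    Field('title', DatasetConfig.STRING),
]

sampler = NegativeSampler.instance(
    data_path='item_feature.txt',  # id:int64 | weight:float | feature:string
    fields=fields,
    num_sample=4,
    batch_size=2)

ids = tf.constant([1001, 1002], dtype=tf.int64)
neg = sampler.get(ids)  # {'item_id': int64 [4], 'title': string [4]}
# --- end sketch ---
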
class NegativeSamplerInMemory(BaseSampler):
  """Negative Sampler.

  Weighted random sampling items not in batch.

  Args:
    data_path: item feature data path. id:int64 | weight:float | attrs:string.
    fields: item input fields.
    num_sample: number of negative samples.
    batch_size: mini-batch size.
    attr_delimiter: delimiter of feature string.
    num_eval_sample: number of negative samples for evaluator.
  """

  def __init__(self,
               data_path,
               fields,
               num_sample,
               batch_size,
               attr_delimiter=':',
               num_eval_sample=None):
    super(NegativeSamplerInMemory, self).__init__(fields, num_sample,
                                                  num_eval_sample)
    self._batch_size = batch_size

    self._item_ids = []
    self._cols = [[] for x in fields]

    if six.PY2 and isinstance(attr_delimiter, type(u'')):
      attr_delimiter = attr_delimiter.encode('utf-8')
    if data_path.startswith('odps://'):
      self._load_table(data_path, attr_delimiter)
    else:
      self._load_data(data_path, attr_delimiter)

    print('NegativeSamplerInMemory: total_row_num = %d' % len(self._cols[0]))
    for col_id in range(len(self._attr_np_types)):
      np_type = self._attr_np_types[col_id]
      print('\tcol_id[%d], dtype=%s' % (col_id, self._attr_gl_types[col_id]))
      if np_type != str:
        self._cols[col_id] = np.array(self._cols[col_id], dtype=np_type)
      else:
        self._cols[col_id] = np.asarray(
            self._cols[col_id], order='C', dtype=object)

  def _load_table(self, data_path, attr_delimiter):
    import common_io
    reader = common_io.table.TableReader(data_path)
    schema = reader.get_schema()
    item_id_col = 0
    fea_id_col = 2
    for tid in range(len(schema)):
      if schema[tid][0].startswith('feature'):
        fea_id_col = tid
        break
    for tid in range(len(schema)):
      if schema[tid][0].startswith('id'):
        item_id_col = tid
        break
    print('NegativeSamplerInMemory: feature_id_col = %d, item_id_col = %d' %
          (fea_id_col, item_id_col))
    while True:
      try:
        row_arr = reader.read(num_records=1024, allow_smaller_final_batch=True)
        for row in row_arr:
          # item_id, weight, feature
          self._item_ids.append(int(row[item_id_col]))
          col_vals = row[fea_id_col].split(attr_delimiter)
          assert len(col_vals) == len(
              self._cols), 'invalid row[%d %d]: %s %s' % (len(
                  col_vals), len(self._cols), row[item_id_col], row[fea_id_col])
          for col_id in range(len(col_vals)):
            self._cols[col_id].append(col_vals[col_id])
      except common_io.exception.OutOfRangeException:
        reader.close()
        break

  def _load_data(self, data_path, attr_delimiter):
    item_id_col = 0
    fea_id_col = 2
    print('NegativeSamplerInMemory: load sample feature from %s' % data_path)
    with gfile.GFile(data_path, 'r') as fin:
      for line_id, line_str in enumerate(fin):
        line_str = line_str.strip()
        cols = line_str.split('\t')
        if line_id == 0:
          schema = [x.split(':') for x in cols]
          for tid in range(len(schema)):
            if schema[tid][0].startswith('id'):
              item_id_col = tid
            if schema[tid][0].startswith('feature'):
              fea_id_col = tid
          print('feature_id_col = %d, item_id_col = %d' %
                (fea_id_col, item_id_col))
        else:
          self._item_ids.append(int(cols[item_id_col]))
          fea_vals = cols[fea_id_col].split(attr_delimiter)
          assert len(fea_vals) == len(
              self._cols), 'invalid row[%d][%d %d]:%s %s' % (
                  line_id, len(fea_vals), len(
                      self._cols), cols[item_id_col], cols[fea_id_col])
          for col_id in range(len(fea_vals)):
            self._cols[col_id].append(fea_vals[col_id])

  def _get_impl(self, ids):
    features = []
    if type(ids[0]) != int:
      ids = [int(x) for x in ids]
    assert self._num_sample > 0, 'invalid num_sample: %d' % self._num_sample

    indices = np.random.choice(
        len(self._item_ids),
        size=self._num_sample + self._batch_size,
        replace=False)

    sel_ids = []
    for tid in indices:
      rid = self._item_ids[tid]
      if rid not in ids:
        sel_ids.append(tid)
      if len(sel_ids) >= self._num_sample and self._num_sample > 0:
        break

    features = []
    for col_id in range(len(self._cols)):
      tmp_col = self._cols[col_id]
      np_type = self._attr_np_types[col_id]
      if np_type != str:
        sel_feas = tmp_col[sel_ids]
        features.append(sel_feas)
      else:
        features.append(
            np.asarray([tmp_col[x] for x in sel_ids], order='C', dtype=object))
    return features

  def get(self, ids):
    """Sampling method.

    Args:
      ids: item id tensor.

    Returns:
      Negative sampled feature dict.
    """
    all_attr_types = list(self._attr_tf_types)
    if self._num_sample <= 0:
      all_attr_types.append(tf.float32)
    sampled_values = tf.py_func(self._get_impl, [ids], all_attr_types)
    result_dict = {}
    for k, v in zip(self._attr_names, sampled_values):
      result_dict[k] = v
    return result_dict
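
`_load_data` above expects a tab-separated file whose first line is a name:type header; the column whose name starts with "id" supplies item ids and the one starting with "feature" holds all field values joined by attr_delimiter. A tiny hypothetical file matching the two-field schema from the earlier sketch:

# --- illustrative sketch, not part of sampler.py ---
# Header names columns as name:type; each feature cell carries one value
# per configured field, joined by attr_delimiter (':' by default).
sample = (
    'id:int64\tweight:float\tfeature:string\n'
    '1001\t1.0\t1001:electronics\n'
    '1002\t0.5\t1002:books\n'
)
with open('item_feature.txt', 'w') as f:
  f.write(sample)
# --- end sketch ---
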
class NegativeSamplerV2(BaseSampler):
  """Negative Sampler V2.

  Weighted random sampling items which do not have positive edge with the user.

  Args:
    user_data_path: user node data path. id:int64 | weight:float.
    item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
    edge_data_path: positive edge data path. userid:int64 | itemid:int64 | weight:float
    fields: item input fields.
    num_sample: number of negative samples.
    batch_size: mini-batch size.
    attr_delimiter: delimiter of feature string.
    num_eval_sample: number of negative samples for evaluator.
  """

  def __init__(self,
               user_data_path,
               item_data_path,
               edge_data_path,
               fields,
               num_sample,
               batch_size,
               attr_delimiter=':',
               num_eval_sample=None):
    super(NegativeSamplerV2, self).__init__(fields, num_sample, num_eval_sample)
    self._batch_size = batch_size
    self._g = gl.Graph() \
        .node(tf.compat.as_str(user_data_path),
              node_type='user',
              decoder=gl.Decoder(weighted=True)) \
        .node(tf.compat.as_str(item_data_path),
              node_type='item',
              decoder=gl.Decoder(
                  attr_types=self._attr_gl_types,
                  weighted=True,
                  attr_delimiter=attr_delimiter)) \
        .edge(tf.compat.as_str(edge_data_path),
              edge_type=('user', 'item', 'edge'),
              decoder=gl.Decoder(weighted=True))
    self._init_graph()

    expand_factor = int(math.ceil(self._num_sample / batch_size))
    self._sampler = self._g.negative_sampler(
        'edge', expand_factor, strategy='random', conditional=True)

  def _get_impl(self, src_ids, dst_ids):
    src_ids = np.array(src_ids, dtype=np.int64)
    src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), 'edge')
    dst_ids = np.array(dst_ids, dtype=np.int64)
    dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
    nodes = self._sampler.get(src_ids, dst_ids)
    features = self._parse_nodes(nodes)
    return features

  def get(self, src_ids, dst_ids):
    """Sampling method.

    Args:
      src_ids: user id tensor.
      dst_ids: item id tensor.

    Returns:
      Negative sampled feature dict.
    """
    sampled_values = tf.py_func(self._get_impl, [src_ids, dst_ids],
                                self._attr_tf_types)
    result_dict = {}
    for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
      v.set_shape([self._num_sample])
      result_dict[k] = v
    return result_dict
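
Compared with NegativeSampler, V2 samples over the positive 'edge' relation with conditional=True, so negatives are drawn away from each user's own positive items; it therefore needs both sides of the pair. A hypothetical call (file paths invented; Field as defined in the earlier sketch):

# --- illustrative sketch, not part of sampler.py ---
import tensorflow as tf

from easy_rec.python.core.sampler import NegativeSamplerV2

sampler = NegativeSamplerV2.instance(
    user_data_path='user_table.txt',    # id:int64 | weight:float
    item_data_path='item_feature.txt',  # id:int64 | weight:float | feature
    edge_data_path='click_edges.txt',   # userid | itemid | weight
    fields=fields,                      # as in the NegativeSampler sketch
    num_sample=4,
    batch_size=2)

user_ids = tf.constant([7, 8], dtype=tf.int64)
item_ids = tf.constant([1001, 1002], dtype=tf.int64)
neg = sampler.get(src_ids=user_ids, dst_ids=item_ids)  # per-field [4] tensors
# --- end sketch ---
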
class HardNegativeSampler(BaseSampler):
  """HardNegativeSampler.

  Weighted random sampling items not in batch as negative samples, and sampling
  destination nodes in hard_neg_edge as hard negative samples.

  Args:
    user_data_path: user node data path. id:int64 | weight:float.
    item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
    hard_neg_edge_data_path: hard negative edge data path. userid:int64 | itemid:int64 | weight:float
    fields: item input fields.
    num_sample: number of negative samples.
    num_hard_sample: maximum number of hard negative samples.
    batch_size: mini-batch size.
    attr_delimiter: delimiter of feature string.
    num_eval_sample: number of negative samples for evaluator.
  """

  def __init__(self,
               user_data_path,
               item_data_path,
               hard_neg_edge_data_path,
               fields,
               num_sample,
               num_hard_sample,
               batch_size,
               attr_delimiter=':',
               num_eval_sample=None):
    super(HardNegativeSampler, self).__init__(fields, num_sample,
                                              num_eval_sample)
    self._batch_size = batch_size
    self._g = gl.Graph() \
        .node(tf.compat.as_str(user_data_path),
              node_type='user',
              decoder=gl.Decoder(weighted=True)) \
        .node(tf.compat.as_str(item_data_path),
              node_type='item',
              decoder=gl.Decoder(
                  attr_types=self._attr_gl_types,
                  weighted=True,
                  attr_delimiter=attr_delimiter)) \
        .edge(tf.compat.as_str(hard_neg_edge_data_path),
              edge_type=('user', 'item', 'hard_neg_edge'),
              decoder=gl.Decoder(weighted=True))
    self._init_graph()

    expand_factor = int(math.ceil(self._num_sample / batch_size))
    self._neg_sampler = self._g.negative_sampler(
        'item', expand_factor, strategy='node_weight')
    self._hard_neg_sampler = self._g.neighbor_sampler(['hard_neg_edge'],
                                                      num_hard_sample,
                                                      strategy='full')

  def _get_impl(self, src_ids, dst_ids):
    src_ids = np.array(src_ids, dtype=np.int64)
    dst_ids = np.array(dst_ids, dtype=np.int64)
    dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
    nodes = self._neg_sampler.get(dst_ids)
    neg_features = self._parse_nodes(nodes)
    sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
    hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)

    results = []
    for i, v in enumerate(hard_neg_features):
      if type(v) == list:
        results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
      else:
        results.append(np.concatenate([neg_features[i], v], axis=0))
    results.append(hard_neg_indices)
    return results

  def get(self, src_ids, dst_ids):
    """Sampling method.

    Args:
      src_ids: user id tensor.
      dst_ids: item id tensor.

    Returns:
      Sampled feature dict. The first num_sample entries are negative samples,
      the remainder are hard negative samples.
    """
    output_types = self._attr_tf_types + [tf.int64]
    output_values = tf.py_func(self._get_impl, [src_ids, dst_ids], output_types)
    result_dict = {}
    for k, t, v in zip(self._attr_names, self._attr_tf_types,
                       output_values[:-1]):
      v.set_shape([None])
      result_dict[k] = v

    hard_neg_indices = output_values[-1]
    hard_neg_indices.set_shape([None, 2])
    result_dict['hard_neg_indices'] = hard_neg_indices
    return result_dict
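
Each per-field output therefore concatenates the num_sample weighted negatives with a variable number of hard negatives, and hard_neg_indices records (row-in-batch, k) positions for the hard part. A hedged sketch of one plausible downstream use, scattering the hard negatives into a per-user sparse layout (construction args continue the earlier sketches; shapes hypothetical):

# --- illustrative sketch, not part of sampler.py ---
from easy_rec.python.core.sampler import HardNegativeSampler

sampler = HardNegativeSampler.instance(
    user_data_path='user_table.txt',
    item_data_path='item_feature.txt',
    hard_neg_edge_data_path='hard_neg_edges.txt',  # userid | itemid | weight
    fields=fields,                                 # as in the earlier sketch
    num_sample=4,
    num_hard_sample=10,
    batch_size=2)

res = sampler.get(src_ids=user_ids, dst_ids=item_ids)
# With num_sample=4, res['item_id'][:4] are the weighted negatives; the rest
# came from hard_neg_edge neighbors. hard_neg_indices pairs each hard
# negative with its batch row, so it can back a sparse layout:
hard_sp = tf.SparseTensor(
    indices=res['hard_neg_indices'],
    values=res['item_id'][4:],
    dense_shape=[2, 10])  # [batch_size, num_hard_sample]
# --- end sketch ---
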
class HardNegativeSamplerV2(BaseSampler):
  """HardNegativeSamplerV2.

  Weighted random sampling items which do not have positive edge with the user,
  and sampling destination nodes in hard_neg_edge as hard negative samples.

  Args:
    user_data_path: user node data path. id:int64 | weight:float.
    item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
    edge_data_path: positive edge data path. userid:int64 | itemid:int64 | weight:float
    hard_neg_edge_data_path: hard negative edge data path. userid:int64 | itemid:int64 | weight:float
    fields: item input fields.
    num_sample: number of negative samples.
    num_hard_sample: maximum number of hard negative samples.
    batch_size: mini-batch size.
    attr_delimiter: delimiter of feature string.
    num_eval_sample: number of negative samples for evaluator.
  """

  def __init__(self,
               user_data_path,
               item_data_path,
               edge_data_path,
               hard_neg_edge_data_path,
               fields,
               num_sample,
               num_hard_sample,
               batch_size,
               attr_delimiter=':',
               num_eval_sample=None):
    super(HardNegativeSamplerV2, self).__init__(fields, num_sample,
                                                num_eval_sample)
    self._batch_size = batch_size
    self._g = gl.Graph() \
        .node(tf.compat.as_str(user_data_path),
              node_type='user',
              decoder=gl.Decoder(weighted=True)) \
        .node(tf.compat.as_str(item_data_path),
              node_type='item',
              decoder=gl.Decoder(
                  attr_types=self._attr_gl_types,
                  weighted=True,
                  attr_delimiter=attr_delimiter)) \
        .edge(tf.compat.as_str(edge_data_path),
              edge_type=('user', 'item', 'edge'),
              decoder=gl.Decoder(weighted=True)) \
        .edge(tf.compat.as_str(hard_neg_edge_data_path),
              edge_type=('user', 'item', 'hard_neg_edge'),
              decoder=gl.Decoder(weighted=True))
    self._init_graph()

    expand_factor = int(math.ceil(self._num_sample / batch_size))
    self._neg_sampler = self._g.negative_sampler(
        'edge', expand_factor, strategy='random', conditional=True)
    self._hard_neg_sampler = self._g.neighbor_sampler(['hard_neg_edge'],
                                                      num_hard_sample,
                                                      strategy='full')

  def _get_impl(self, src_ids, dst_ids):
    src_ids = np.array(src_ids, dtype=np.int64)
    src_ids_padded = np.pad(src_ids, (0, self._batch_size - len(src_ids)),
                            'edge')
    dst_ids = np.array(dst_ids, dtype=np.int64)
    dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
    nodes = self._neg_sampler.get(src_ids_padded, dst_ids)
    neg_features = self._parse_nodes(nodes)
    sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
    hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)

    results = []
    for i, v in enumerate(hard_neg_features):
      if type(v) == list:
        results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
      else:
        results.append(np.concatenate([neg_features[i], v], axis=0))
    results.append(hard_neg_indices)
    return results

  def get(self, src_ids, dst_ids):
    """Sampling method.

    Args:
      src_ids: user id tensor.
      dst_ids: item id tensor.

    Returns:
      Sampled feature dict. The first num_sample entries are negative samples,
      the remainder are hard negative samples.
    """
    output_types = self._attr_tf_types + [tf.int64]
    output_values = tf.py_func(self._get_impl, [src_ids, dst_ids], output_types)
    result_dict = {}
    for k, t, v in zip(self._attr_names, self._attr_tf_types,
                       output_values[:-1]):
      v.set_shape([None])
      result_dict[k] = v

    hard_neg_indices = output_values[-1]
    hard_neg_indices.set_shape([None, 2])
    result_dict['hard_neg_indices'] = hard_neg_indices
    return result_dict
def build(data_config):
  if not data_config.HasField('sampler'):
    return None
  sampler_type = data_config.WhichOneof('sampler')
  print('sampler_type = %s' % sampler_type)
  sampler_config = getattr(data_config, sampler_type)

  if ds_util.is_on_ds():
    gl.set_field_delimiter(sampler_config.field_delimiter)

  if sampler_type == 'negative_sampler':
    input_fields = {f.input_name: f for f in data_config.input_fields}
    attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

    input_path = process_multi_file_input_path(sampler_config.input_path)
    return NegativeSampler.instance(
        data_path=input_path,
        fields=attr_fields,
        num_sample=sampler_config.num_sample,
        batch_size=data_config.batch_size,
        attr_delimiter=sampler_config.attr_delimiter,
        num_eval_sample=sampler_config.num_eval_sample)
  elif sampler_type == 'negative_sampler_in_memory':
    input_fields = {f.input_name: f for f in data_config.input_fields}
    attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

    input_path = process_multi_file_input_path(sampler_config.input_path)
    return NegativeSamplerInMemory.instance(
        data_path=input_path,
        fields=attr_fields,
        num_sample=sampler_config.num_sample,
        batch_size=data_config.batch_size,
        attr_delimiter=sampler_config.attr_delimiter,
        num_eval_sample=sampler_config.num_eval_sample)
  elif sampler_type == 'negative_sampler_v2':
    input_fields = {f.input_name: f for f in data_config.input_fields}
    attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

    user_input_path = process_multi_file_input_path(
        sampler_config.user_input_path)
    item_input_path = process_multi_file_input_path(
        sampler_config.item_input_path)
    pos_edge_input_path = process_multi_file_input_path(
        sampler_config.pos_edge_input_path)
    return NegativeSamplerV2.instance(
        user_data_path=user_input_path,
        item_data_path=item_input_path,
        edge_data_path=pos_edge_input_path,
        fields=attr_fields,
        num_sample=sampler_config.num_sample,
        batch_size=data_config.batch_size,
        attr_delimiter=sampler_config.attr_delimiter,
        num_eval_sample=sampler_config.num_eval_sample)
  elif sampler_type == 'hard_negative_sampler':
    input_fields = {f.input_name: f for f in data_config.input_fields}
    attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

    user_input_path = process_multi_file_input_path(
        sampler_config.user_input_path)
    item_input_path = process_multi_file_input_path(
        sampler_config.item_input_path)
    hard_neg_edge_input_path = process_multi_file_input_path(
        sampler_config.hard_neg_edge_input_path)
    return HardNegativeSampler.instance(
        user_data_path=user_input_path,
        item_data_path=item_input_path,
        hard_neg_edge_data_path=hard_neg_edge_input_path,
        fields=attr_fields,
        num_sample=sampler_config.num_sample,
        num_hard_sample=sampler_config.num_hard_sample,
        batch_size=data_config.batch_size,
        attr_delimiter=sampler_config.attr_delimiter,
        num_eval_sample=sampler_config.num_eval_sample)
  elif sampler_type == 'hard_negative_sampler_v2':
    input_fields = {f.input_name: f for f in data_config.input_fields}
    attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

    user_input_path = process_multi_file_input_path(
        sampler_config.user_input_path)
    item_input_path = process_multi_file_input_path(
        sampler_config.item_input_path)
    pos_edge_input_path = process_multi_file_input_path(
        sampler_config.pos_edge_input_path)
    hard_neg_edge_input_path = process_multi_file_input_path(
        sampler_config.hard_neg_edge_input_path)
    return HardNegativeSamplerV2.instance(
        user_data_path=user_input_path,
        item_data_path=item_input_path,
        edge_data_path=pos_edge_input_path,
        hard_neg_edge_data_path=hard_neg_edge_input_path,
        fields=attr_fields,
        num_sample=sampler_config.num_sample,
        num_hard_sample=sampler_config.num_hard_sample,
        batch_size=data_config.batch_size,
        attr_delimiter=sampler_config.attr_delimiter,
        num_eval_sample=sampler_config.num_eval_sample)
  else:
    raise ValueError('Unknown sampler %s' % sampler_type)
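
build() is the factory the input pipeline calls: it inspects the 'sampler' oneof on DatasetConfig and returns the matching singleton. A sketch of configuring the simplest case programmatically (field names follow the dispatch code above; values are hypothetical, and the proto layout is assumed from the HasField/WhichOneof calls):

# --- illustrative sketch, not part of sampler.py ---
from easy_rec.python.core import sampler as sampler_lib
from easy_rec.python.protos.dataset_pb2 import DatasetConfig

data_config = DatasetConfig()
data_config.batch_size = 2
field = data_config.input_fields.add()
field.input_name = 'item_id'
field.input_type = DatasetConfig.INT64

# Touching a member of the 'sampler' oneof makes HasField('sampler') true.
neg = data_config.negative_sampler
neg.input_path = 'item_feature.txt'
neg.attr_fields.append('item_id')
neg.num_sample = 4
neg.attr_delimiter = ':'

sampler = sampler_lib.build(data_config)  # -> NegativeSampler singleton
# --- end sketch ---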