easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
# Date: 2018-09-13
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
from tensorflow.python.training import optimizer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MultiOptimizer(optimizer.Optimizer):
  """Optimizer that routes each group of variables to its own optimizer."""

  def __init__(self, opts, grouped_vars, use_locking=False):
    """Combine multiple optimizers for optimization, such as WideAndDeep.

    Args:
      opts: list of optimizer instance.
      grouped_vars: list of list of vars, each list of vars are
        optimized by each of the optimizers (grouped_vars[i] by opts[i]).
      use_locking: be compatible, currently not used.
    """
    super(MultiOptimizer, self).__init__(use_locking, 'MultiOptimizer')
    self._opts = opts
    self._grouped_vars = grouped_vars

  def compute_gradients(self, loss, variables, **kwargs):
    """Compute gradients of loss, one variable group at a time.

    Args:
      loss: the loss to differentiate.
      variables: ignored; opts[i] computes gradients for grouped_vars[i].
      **kwargs: forwarded to every optimizer's compute_gradients.

    Returns:
      the concatenation of the (gradient, variable) lists of all optimizers.
    """
    grad_and_vars = []
    for gid, opt in enumerate(self._opts):
      grad_and_vars.extend(
          opt.compute_gradients(loss, self._grouped_vars[gid], **kwargs))
    return grad_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply each optimizer to the gradients of its own variable group.

    Args:
      grads_and_vars: iterable of (gradient, variable) pairs; pairs whose
        variable is in no group are not applied.
      global_step: passed only to the first optimizer, the others get None,
        so it is advanced once per call rather than once per optimizer.
      name: optional name of the returned grouped op.

    Returns:
      an op grouping the update ops of all optimizers.
    """
    # materialize once: the pairs are scanned once per optimizer
    grads_and_vars = list(grads_and_vars)
    update_ops = []
    for gid, opt in enumerate(self._opts):
      # Match variables by object identity. `var in list_of_vars` goes
      # through Variable.__eq__, which can compare element-wise instead of by
      # identity when TF2 tensor equality is active (eager mode), and costs
      # O(len(group)) per lookup.
      group_ids = set(id(var) for var in self._grouped_vars[gid])
      group_grads = [x for x in grads_and_vars if id(x[1]) in group_ids]
      step = global_step if gid == 0 else None
      update_ops.append(opt.apply_gradients(group_grads, step))
    return tf.group(update_ops, name=name)

  def open_auto_record(self, flag=True):
    """Forward to the base optimizer's open_auto_record.

    NOTE(review): open_auto_record is not part of the stock tf.train.Optimizer
    API; presumably provided by the PAI-TF build -- confirm.
    """
    super(MultiOptimizer, self).open_auto_record(flag)

  def get_slot(self, var, name):
    """Slot lookup is not supported for combined optimizers.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError('not implemented')

  def variables(self):
    """Return the variables of all wrapped optimizers, in optimizer order."""
    all_vars = []
    for opt in self._opts:
      all_vars.extend(opt.variables())
    return all_vars

  def get_slot_names(self):
    """Return the slot names of all wrapped optimizers (duplicates kept)."""
    slot_names = []
    for opt in self._opts:
      slot_names.extend(opt.get_slot_names())
    return slot_names
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NumpyEncoder(json.JSONEncoder):
  """JSON encoder that also serializes numpy scalars and arrays.

  Usage: json.dumps(obj, cls=NumpyEncoder)
  """

  def default(self, obj):
    """Convert numpy objects into json-serializable python objects.

    np.integer -> int, np.floating -> float, np.bool_ -> bool and
    np.ndarray -> (nested) list. Anything else is delegated to the base
    class, which raises TypeError.
    """
    if isinstance(obj, np.integer):
      return int(obj)
    elif isinstance(obj, np.floating):
      return float(obj)
    elif isinstance(obj, np.bool_):
      # np.bool_ is not a subclass of python bool, so the base encoder
      # rejects it; numpy comparisons produce it routinely.
      return bool(obj)
    elif isinstance(obj, np.ndarray):
      return obj.tolist()
    return json.JSONEncoder.default(self, obj)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Common functions used for odps input."""
|
|
4
|
+
from tensorflow.python.framework import dtypes
|
|
5
|
+
|
|
6
|
+
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def is_type_compatiable(odps_type, input_type):
  """Check that odps_type is compatible with input_type.

  Args:
    odps_type: odps column type name; 'bigint', 'string' and 'double' are
      recognized.
    input_type: a DatasetConfig field type enum value.

  Returns:
    True if the mapped odps type equals input_type, or both are float types
    (FLOAT/DOUBLE), or both are int types (INT32/INT64); False otherwise,
    including for unrecognized odps types.
  """
  type_map = {
      'bigint': DatasetConfig.INT64,
      'string': DatasetConfig.STRING,
      'double': DatasetConfig.DOUBLE
  }
  tmp_type = type_map.get(odps_type)
  if tmp_type is None:
    # Unknown odps type: report incompatibility instead of raising a bare
    # KeyError, so callers' assertions can name the offending column.
    return False
  if tmp_type == input_type:
    return True
  float_types = [DatasetConfig.FLOAT, DatasetConfig.DOUBLE]
  int_types = [DatasetConfig.INT32, DatasetConfig.INT64]
  return (tmp_type in float_types and input_type in float_types) or \
      (tmp_type in int_types and input_type in int_types)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def odps_type_to_input_type(odps_type):
  """Convert an odps column type name into the matching DatasetConfig type.

  Args:
    odps_type: odps column type name, one of 'bigint', 'string', 'double'.

  Returns:
    DatasetConfig.INT64, DatasetConfig.STRING or DatasetConfig.DOUBLE.

  Raises:
    AssertionError: if odps_type is not supported (a KeyError instead when
      python runs with -O and assertions are stripped).
  """
  odps_type_map = {
      'bigint': DatasetConfig.INT64,
      'string': DatasetConfig.STRING,
      'double': DatasetConfig.DOUBLE
  }
  assert odps_type in odps_type_map, \
      'only support [bigint, string, double], got: %s' % odps_type
  return odps_type_map[odps_type]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def check_input_field_and_types(data_config):
  """Validate data_config.input_fields against the declared table columns.

  Every input field name must appear in the comma separated
  data_config.selected_cols. When data_config.selected_col_types is also set,
  each field's input_type must be compatible (see is_type_compatiable) with
  the odps type declared for the same column. Nothing is checked when
  selected_cols is empty.

  Args:
    data_config: instance of DatasetConfig

  Raises:
    AssertionError: on a missing column or an incompatible type.
  """
  if not data_config.selected_cols:
    return

  column_names = data_config.selected_cols.split(',')
  for field in data_config.input_fields:
    assert field.input_name in column_names, \
        'column %s is not in table' % field.input_name

  if not data_config.selected_col_types:
    return

  # column name -> odps type, paired positionally
  column_types = dict(
      zip(column_names, data_config.selected_col_types.split(',')))
  for field in data_config.input_fields:
    odps_type = column_types[field.input_name]
    assert is_type_compatiable(odps_type, field.input_type), \
        'feature[%s] type error: odps %s is not compatible with input_type %s' % (
            field.input_name, odps_type,
            DatasetConfig.FieldType.Name(field.input_type))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def odps_type_2_tf_type(odps_type):
  """Map an odps column type name to the tensorflow dtype used to read it.

  'bigint' -> int64; 'double' and 'float' -> float32 (double is read as
  float32, not float64); 'string' and any other type -> string.
  """
  if odps_type == 'bigint':
    return dtypes.int64
  if odps_type in ('double', 'float'):
    return dtypes.float32
  # 'string' and every unrecognized type fall back to string
  return dtypes.string
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import traceback
|
|
8
|
+
|
|
9
|
+
import tensorflow as tf
|
|
10
|
+
|
|
11
|
+
if sys.version_info.major == 2:
|
|
12
|
+
from urllib2 import urlopen, Request, HTTPError
|
|
13
|
+
else:
|
|
14
|
+
from urllib.request import urlopen, Request
|
|
15
|
+
from urllib.error import HTTPError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def is_on_pai():
  """Return True when the IS_ON_PAI environment variable is set.

  pai_jobs/run.py, the entry script on the pai platform, sets IS_ON_PAI
  (see set_on_pai). Only the presence of the variable matters, not its value.
  """
  return os.environ.get('IS_ON_PAI') is not None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def set_on_pai():
  """Mark the current process as running on pai.

  Logs the change and exports IS_ON_PAI=1, so is_on_pai() returns True
  afterwards.
  """
  flag_name = 'IS_ON_PAI'
  logging.info('set on pai environment variable: IS_ON_PAI')
  os.environ[flag_name] = '1'
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def download(url, timeout=10):
  """Download url into a file in the current working directory.

  The local file name is the last path component of url.

  Args:
    url: url of the file to fetch (anything urllib's urlopen accepts).
    timeout: connection timeout in seconds, defaults to 10.

  Returns:
    the local file name on success; None on failure. Errors are logged,
    not raised.
  """
  import contextlib

  _, fname = os.path.split(url)
  request = Request(url=url)
  try:
    # closing() rather than `with urlopen(...)`: python2 responses are not
    # context managers, but both versions provide close().
    with contextlib.closing(urlopen(request, timeout=timeout)) as response:
      content = response.read()
    # read() returns bytes on python3, so the file must be opened in binary
    # mode; text mode raised TypeError and the download always failed.
    with open(fname, 'wb') as ofile:
      ofile.write(content)
    return fname
  except HTTPError as e:
    logging.error('http error: %s', e.code)
    logging.error('body: %s', e.read())
    return None
  except Exception as e:
    logging.error('failed to download %s: %s', url, e)
    logging.error(traceback.format_exc())
    return None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def process_config(configs, task_index=0, worker_num=1):
  """Pick this worker's config path and resolve it for loading.

  Args:
    configs: comma separated config paths; either a single path shared by all
      workers or exactly one path per worker.
    task_index: worker index
    worker_num: total number of workers

  Returns:
    for 'http...' paths, the result of download() (local file name, or None
    on failure); for 'oss...' paths, the path with '/##/' replaced by the
    control char \\x02 and '/#/' by \\x01; any other path unchanged.
  """
  config_list = configs.split(',')
  if len(config_list) <= 1:
    config = config_list[0]
  else:
    assert len(config_list) == worker_num, \
        'number of configs must be equal to number of workers,' + \
        ' when number of configs > 1'
    config = config_list[task_index]

  if config.startswith('http'):
    return download(config)
  if config.startswith('oss'):
    # '/##/' must be rewritten before '/#/', which it contains
    return config.replace('/##/', '\x02').replace('/#/', '\x01')
  # local path: allows running experiments from a local env without
  # uploading the sample file
  return config
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test():
  """Manual smoke test: fetch a public sample config (needs network access)."""
  local_name = download(
      'https://easy-rec.oss-cn-hangzhou.aliyuncs.com/config/MultiTower/dwd_avazu_ctr_deepmodel.config'
  )
  expected_name = 'dwd_avazu_ctr_deepmodel.config'
  assert local_name == expected_name


if __name__ == '__main__':
  # run the smoke test only when this file is executed directly
  test()
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def copy_obj(proto_obj):
  """Return an independent copy of a protobuf message.

  A fresh message of the same class is created and filled via CopyFrom, so
  later modifications of the copy have no impact on proto_obj.

  Args:
    proto_obj: a protobuf message
  Return:
    a copy of proto_obj
  """
  message_cls = type(proto_obj)
  duplicate = message_cls()
  duplicate.CopyFrom(proto_obj)
  return duplicate
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_norm_embed_name(name, verbose=False):
  """Normalize an embedding variable/op name, for embedding export to redis.

  name is split on '/'. Patterns 1 and 2 are scanned together, left to right;
  patterns 3 and 4 are tried (in that order) only if neither matched:
    1. '<scope>/embedding_weights:<k>' -> ('<scope>', 0), or ('<scope>_<k>', 0)
       when k != '0'.
    2. '<scope>/embedding_weights/part_<p>[:<k>]' where <scope> has at least
       two tokens -> ('<scope>', int(p)), with '_<k>' appended when k is
       present and != '0'.
    3. a token ending with '_embedding_weights' or containing
       '_embedding_weights_' -> (the tokens before it, 0), e.g.
       input_layer/app_category_embedding/app_category_embedding_weights/SparseReshape
       => input_layer/app_category_embedding
    4. a token equal to 'embedding_weights' -> (the tokens before it, 0), e.g.
       input_layer/app_category_embedding/embedding_weights
       => input_layer/app_category_embedding

  Args:
    name: variable name
    verbose: whether to log each normalization
  Return:
    embedding_name: normalized embedding_name
    embedding_part_id: normalized embedding part_id
    (None, None) when no pattern matches; a warning is logged.
  """

  def _log_norm(norm_name):
    if verbose:
      logging.info('norm %s to %s' % (name, norm_name))

  tokens = name.split('/')
  # (index of the left token, left token, right token) for adjacent tokens
  neighbours = [(i, tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]

  for i, prev_tok, next_tok in neighbours:
    if next_tok.startswith('embedding_weights:'):
      var_id = next_tok.replace('embedding_weights:', '')
      scope = '/'.join(tokens[:i + 1])
      if var_id != '0':
        scope = scope + '_' + var_id
      _log_norm(scope)
      return scope, 0
    # NOTE(review): `i > 1` requires two or more tokens before
    # 'embedding_weights', so 'a/embedding_weights/part_1:0' is not handled
    # here and ends up as ('a', 0) via pattern 4 -- confirm this is intended.
    if i > 1 and next_tok.startswith('part_') and \
        prev_tok == 'embedding_weights':
      scope = '/'.join(tokens[:i])
      part_toks = next_tok.replace('part_', '').split(':')
      if len(part_toks) >= 2 and part_toks[1] != '0':
        scope = scope + '_' + part_toks[1]
      _log_norm(scope)
      return scope, int(part_toks[0])

  for i, _, next_tok in neighbours:
    if next_tok.endswith('_embedding_weights') or \
        '_embedding_weights_' in next_tok:
      scope = '/'.join(tokens[:i + 1])
      _log_norm(scope)
      return scope, 0

  for i, _, next_tok in neighbours:
    if next_tok == 'embedding_weights':
      scope = '/'.join(tokens[:i + 1])
      _log_norm(scope)
      return scope, 0

  logging.warning('Failed to norm: %s' % name)
  return None, None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def is_cache_from_redis(name, redis_cache_names):
  """Check whether name should be cached.

  name is split on '/', and a leading token starting with 'input_layer' is
  ignored. The name is cached when any remaining token starts with one of
  redis_cache_names; the first matching entry is logged.

  Args:
    name: string, the variable name to be checked
    redis_cache_names: list of string, names which should be cached.

  Return:
    True if need to be cached
  """
  name_toks = name.split('/')
  if name_toks[0].startswith('input_layer'):
    name_toks = name_toks[1:]
  for prefix in redis_cache_names:
    if any(tok.startswith(prefix) for tok in name_toks):
      logging.info('embedding %s will be cached[specified by %s]' %
                   (name, prefix))
      return True
  return False
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
"""Define filters for restore."""
|
|
4
|
+
|
|
5
|
+
from abc import ABCMeta
|
|
6
|
+
from abc import abstractmethod
|
|
7
|
+
from enum import Enum
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Logical(Enum):
  """How CombineFilter merges the decisions of its filters."""

  # keep a variable only if every filter keeps it
  AND = 1
  # keep a variable if at least one filter keeps it
  OR = 2
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Filter:
  """Base class of the variable-name filters used on restore.

  Subclasses override keep(var_name).
  NOTE: the `__metaclass__` attribute only takes effect on python 2. On
  python 3 this is a plain class, so @abstractmethod is not enforced and
  Filter() itself can be instantiated; its keep() keeps every variable.
  """
  __metaclass__ = ABCMeta

  def __init__(self):
    """Nothing to initialize."""

  @abstractmethod
  def keep(self, var_name):
    """Keep the var or not.

    Args:
      var_name: input name of the var

    Returns:
      True if the var will be kept, else False; the base implementation
      always keeps it.
    """
    return True
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class KeywordFilter(Filter):
  """Keep or drop variables depending on whether their name has a keyword."""

  def __init__(self, pattern, exclusive=False):
    """Init KeywordFilter.

    Args:
      pattern: keyword to be matched (plain substring match)
      exclusive: if False, a var is kept only when var_name includes the
        pattern; if True, a var is kept only when var_name does not include
        the pattern.
    """
    super(KeywordFilter, self).__init__()
    self._pattern = pattern
    self._exclusive = exclusive

  def keep(self, var_name):
    """Return True if the var should be kept (see exclusive in __init__)."""
    contains = self._pattern in var_name
    return not contains if self._exclusive else contains
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class CombineFilter(Filter):
  """Combine several filters with a logical AND or OR."""

  def __init__(self, filters, logical=Logical.AND):
    """Init CombineFilter.

    Args:
      filters: a set of filters to be combined
      logical: Logical.AND keeps a var only if every filter keeps it;
        Logical.OR keeps it if any filter does
    """
    self._filters = filters
    self._logical = logical

  def keep(self, var_name):
    """Combine the filters' keep() decisions for var_name.

    Filters are evaluated in order and evaluation stops as soon as the result
    is decided. Returns None when logical is neither Logical.AND nor
    Logical.OR.
    """
    if self._logical == Logical.AND:
      return all(f.keep(var_name) for f in self._filters)
    if self._logical == Logical.OR:
      return any(f.keep(var_name) for f in self._filters)
    return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class ScopeDrop:
  """For drop out scope prefix when restore variables from checkpoint."""

  def __init__(self, scope_name):
    """Init ScopeDrop.

    Args:
      scope_name: scope to drop; a trailing '/' is appended when missing.
        An empty scope_name makes update() a no-op.
    """
    self._scope_name = scope_name
    # The original guard was `len(...) >= 0`, which is always true, so an
    # empty scope_name raised IndexError on [-1].
    if len(self._scope_name) > 0 and self._scope_name[-1] != '/':
      self._scope_name += '/'

  def update(self, var_name):
    """Return var_name with '<scope>/' removed.

    NOTE(review): str.replace removes every occurrence, not only a leading
    prefix (scope 'a' turns 'xa/b' into 'xb'); kept for compatibility --
    confirm whether prefix-only stripping is intended.
    """
    return var_name.replace(self._scope_name, '')
|