easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/layers/backbone.py
@@ -0,0 +1,571 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import six
+import tensorflow as tf
+from google.protobuf import struct_pb2
+
+from easy_rec.python.layers.common_layers import EnhancedInputLayer
+from easy_rec.python.layers.keras import MLP
+from easy_rec.python.layers.keras import EmbeddingLayer
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.protos import backbone_pb2
+from easy_rec.python.utils.dag import DAG
+from easy_rec.python.utils.load_class import load_keras_layer
+from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class Package(object):
+  """A sub DAG of tf ops for reuse."""
+  __packages = {}
+
+  @staticmethod
+  def has_backbone_block(name):
+    if 'backbone' not in Package.__packages:
+      return False
+    backbone = Package.__packages['backbone']
+    return backbone.has_block(name)
+
+  @staticmethod
+  def backbone_block_outputs(name):
+    if 'backbone' not in Package.__packages:
+      return None
+    backbone = Package.__packages['backbone']
+    return backbone.block_outputs(name)
+
+  def __init__(self, config, features, input_layer, l2_reg=None):
+    self._config = config
+    self._features = features
+    self._input_layer = input_layer
+    self._l2_reg = l2_reg
+    self._dag = DAG()
+    self._name_to_blocks = {}
+    self._name_to_layer = {}
+    self.reset_input_config(None)
+    self._block_outputs = {}
+    self._package_input = None
+    self._feature_group_inputs = {}
+    reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
+    input_feature_groups = self._feature_group_inputs
+
+    for block in config.blocks:
+      if len(block.inputs) == 0:
+        raise ValueError('block takes at least one input: %s' % block.name)
+      self._dag.add_node(block.name)
+      self._name_to_blocks[block.name] = block
+      layer = block.WhichOneof('layer')
+      if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
+        if len(block.inputs) != 1:
+          raise ValueError('input layer `%s` takes only one input' % block.name)
+        one_input = block.inputs[0]
+        name = one_input.WhichOneof('name')
+        if name != 'feature_group_name':
+          raise KeyError(
+              '`feature_group_name` should be set for input layer: ' +
+              block.name)
+        group = one_input.feature_group_name
+        if not input_layer.has_group(group):
+          raise KeyError('invalid feature group name: ' + group)
+        if group in input_feature_groups:
+          if layer == input_layer:
+            logging.warning('input `%s` already exists in other block' % group)
+          elif layer == 'raw_input':
+            input_fn = input_feature_groups[group]
+            self._name_to_layer[block.name] = input_fn
+          elif layer == 'embedding_layer':
+            inputs, vocab, weights = input_feature_groups[group]
+            block.embedding_layer.vocab_size = vocab
+            params = Parameter.make_from_pb(block.embedding_layer)
+            input_fn = EmbeddingLayer(params, block.name)
+            self._name_to_layer[block.name] = input_fn
+        else:
+          if layer == 'input_layer':
+            input_fn = EnhancedInputLayer(self._input_layer, self._features,
+                                          group, reuse)
+            input_feature_groups[group] = input_fn
+          elif layer == 'raw_input':
+            input_fn = self._input_layer.get_raw_features(self._features, group)
+            input_feature_groups[group] = input_fn
+          else:  # embedding_layer
+            inputs, vocab, weights = self._input_layer.get_bucketized_features(
+                self._features, group)
+            block.embedding_layer.vocab_size = vocab
+            params = Parameter.make_from_pb(block.embedding_layer)
+            input_fn = EmbeddingLayer(params, block.name)
+            input_feature_groups[group] = (inputs, vocab, weights)
+            logging.info('add an embedding layer %s with vocab size %d',
+                         block.name, vocab)
+          self._name_to_layer[block.name] = input_fn
+      else:
+        self.define_layers(layer, block, block.name, reuse)
+
+      # sequential layers
+      for i, layer_cnf in enumerate(block.layers):
+        layer = layer_cnf.WhichOneof('layer')
+        name_i = '%s_l%d' % (block.name, i)
+        self.define_layers(layer, layer_cnf, name_i, reuse)
+
+    num_groups = len(input_feature_groups)
+    num_blocks = len(self._name_to_blocks) - num_groups
+    assert num_blocks > 0, 'there must be at least one block in backbone'
+
+    num_pkg_input = 0
+    for block in config.blocks:
+      layer = block.WhichOneof('layer')
+      if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
+        continue
+      name = block.name
+      if name in input_feature_groups:
+        raise KeyError('block name can not be one of feature groups:' + name)
+      for input_node in block.inputs:
+        input_type = input_node.WhichOneof('name')
+        input_name = getattr(input_node, input_type)
+        if input_type == 'use_package_input':
+          assert input_name, 'use_package_input can not set false'
+          num_pkg_input += 1
+          continue
+        if input_type == 'package_name':
+          num_pkg_input += 1
+          self._dag.add_node_if_not_exists(input_name)
+          self._dag.add_edge(input_name, name)
+          if input_node.HasField('package_input'):
+            pkg_input_name = input_node.package_input
+            self._dag.add_node_if_not_exists(pkg_input_name)
+            self._dag.add_edge(pkg_input_name, input_name)
+          continue
+        iname = input_name
+        if iname in self._name_to_blocks:
+          assert iname != name, 'input name can not equal to block name:' + iname
+          self._dag.add_edge(iname, name)
+        else:
+          is_fea_group = input_type == 'feature_group_name'
+          if is_fea_group and input_layer.has_group(iname):
+            logging.info('adding an input_layer block: ' + iname)
+            new_block = backbone_pb2.Block()
+            new_block.name = iname
+            input_cfg = backbone_pb2.Input()
+            input_cfg.feature_group_name = iname
+            new_block.inputs.append(input_cfg)
+            new_block.input_layer.CopyFrom(backbone_pb2.InputLayer())
+            self._name_to_blocks[iname] = new_block
+            self._dag.add_node(iname)
+            self._dag.add_edge(iname, name)
+            if iname in input_feature_groups:
+              fn = input_feature_groups[iname]
+            else:
+              fn = EnhancedInputLayer(self._input_layer, self._features, iname)
+              input_feature_groups[iname] = fn
+            self._name_to_layer[iname] = fn
+          elif Package.has_backbone_block(iname):
+            backbone = Package.__packages['backbone']
+            backbone._dag.add_node_if_not_exists(self._config.name)
+            backbone._dag.add_edge(iname, self._config.name)
+            num_pkg_input += 1
+          else:
+            raise KeyError(
+                'invalid input name `%s`, must be the name of either a feature group or an another block'
+                % iname)
+    num_groups = len(input_feature_groups)
+    assert num_pkg_input > 0 or num_groups > 0, 'there must be at least one input layer/feature group'
+
+    if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
+      leaf = self._dag.all_leaves()
+      logging.warning(
+          '%s has no `concat_blocks` or `output_blocks`, try to concat all leaf blocks: %s'
+          % (config.name, ','.join(leaf)))
+      self._config.concat_blocks.extend(leaf)
+
+    Package.__packages[self._config.name] = self
+    logging.info('%s layers: %s' %
+                 (config.name, ','.join(self._name_to_layer.keys())))
+
+  def define_layers(self, layer, layer_cnf, name, reuse):
+    if layer == 'keras_layer':
+      layer_obj = self.load_keras_layer(layer_cnf.keras_layer, name, reuse)
+      self._name_to_layer[name] = layer_obj
+    elif layer == 'recurrent':
+      keras_layer = layer_cnf.recurrent.keras_layer
+      for i in range(layer_cnf.recurrent.num_steps):
+        name_i = '%s_%d' % (name, i)
+        layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
+        self._name_to_layer[name_i] = layer_obj
+    elif layer == 'repeat':
+      keras_layer = layer_cnf.repeat.keras_layer
+      for i in range(layer_cnf.repeat.num_repeat):
+        name_i = '%s_%d' % (name, i)
+        layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
+        self._name_to_layer[name_i] = layer_obj
+
+  def reset_input_config(self, config):
+    self.input_config = config
+
+  def set_package_input(self, pkg_input):
+    self._package_input = pkg_input
+
+  def has_block(self, name):
+    return name in self._name_to_blocks
+
+  def block_outputs(self, name):
+    return self._block_outputs.get(name, None)
+
+  def block_input(self, config, block_outputs, training=None, **kwargs):
+    inputs = []
+    for input_node in config.inputs:
+      input_type = input_node.WhichOneof('name')
+      input_name = getattr(input_node, input_type)
+      if input_type == 'use_package_input':
+        input_feature = self._package_input
+        input_name = 'package_input'
+      elif input_type == 'package_name':
+        if input_name not in Package.__packages:
+          raise KeyError('package name `%s` does not exists' % input_name)
+        package = Package.__packages[input_name]
+        if input_node.HasField('reset_input'):
+          package.reset_input_config(input_node.reset_input)
+        if input_node.HasField('package_input'):
+          pkg_input_name = input_node.package_input
+          if pkg_input_name in block_outputs:
+            pkg_input = block_outputs[pkg_input_name]
+          else:
+            if pkg_input_name not in Package.__packages:
+              raise KeyError('package name `%s` does not exists' %
+                             pkg_input_name)
+            inner_package = Package.__packages[pkg_input_name]
+            pkg_input = inner_package(training)
+          if input_node.HasField('package_input_fn'):
+            fn = eval(input_node.package_input_fn)
+            pkg_input = fn(pkg_input)
+          package.set_package_input(pkg_input)
+        input_feature = package(training, **kwargs)
+      elif input_name in block_outputs:
+        input_feature = block_outputs[input_name]
+      else:
+        input_feature = Package.backbone_block_outputs(input_name)
+
+      if input_feature is None:
+        raise KeyError('input name `%s` does not exists' % input_name)
+
+      if input_node.ignore_input:
+        continue
+      if input_node.HasField('input_slice'):
+        fn = eval('lambda x: x' + input_node.input_slice.strip())
+        input_feature = fn(input_feature)
+      if input_node.HasField('input_fn'):
+        with tf.name_scope(config.name):
+          fn = eval(input_node.input_fn)
+          input_feature = fn(input_feature)
+      inputs.append(input_feature)
+
+    if config.merge_inputs_into_list:
+      output = inputs
+    else:
+      try:
+        output = merge_inputs(inputs, config.input_concat_axis, config.name)
+      except ValueError as e:
+        msg = getattr(e, 'message', str(e))
+        logging.error('merge inputs of block %s failed: %s', config.name, msg)
+        raise e
+
+    if config.HasField('extra_input_fn'):
+      fn = eval(config.extra_input_fn)
+      output = fn(output)
+    return output
+
+  def __call__(self, is_training, **kwargs):
+    with tf.name_scope(self._config.name):
+      return self.call(is_training, **kwargs)
+
+  def call(self, is_training, **kwargs):
+    block_outputs = {}
+    self._block_outputs = block_outputs  # reset
+    blocks = self._dag.topological_sort()
+    logging.info(self._config.name + ' topological order: ' + ','.join(blocks))
+    for block in blocks:
+      if block not in self._name_to_blocks:
+        assert block in Package.__packages, 'invalid block: ' + block
+        continue
+      config = self._name_to_blocks[block]
+      if config.layers:  # sequential layers
+        logging.info('call sequential %d layers' % len(config.layers))
+        output = self.block_input(config, block_outputs, is_training, **kwargs)
+        for i, layer in enumerate(config.layers):
+          name_i = '%s_l%d' % (block, i)
+          output = self.call_layer(output, layer, name_i, is_training, **kwargs)
+        block_outputs[block] = output
+        continue
+      # just one of layer
+      layer = config.WhichOneof('layer')
+      if layer is None:  # identity layer
+        output = self.block_input(config, block_outputs, is_training, **kwargs)
+        block_outputs[block] = output
+      elif layer == 'raw_input':
+        block_outputs[block] = self._name_to_layer[block]
+      elif layer == 'input_layer':
+        input_fn = self._name_to_layer[block]
+        input_config = config.input_layer
+        if self.input_config is not None:
+          input_config = self.input_config
+          input_fn.reset(input_config, is_training)
+        block_outputs[block] = input_fn(input_config, is_training)
+      elif layer == 'embedding_layer':
+        input_fn = self._name_to_layer[block]
+        feature_group = config.inputs[0].feature_group_name
+        inputs, _, weights = self._feature_group_inputs[feature_group]
+        block_outputs[block] = input_fn([inputs, weights], is_training)
+      else:
+        with tf.name_scope(block + '_input'):
+          inputs = self.block_input(config, block_outputs, is_training,
+                                    **kwargs)
+        output = self.call_layer(inputs, config, block, is_training, **kwargs)
+        block_outputs[block] = output
+
+    outputs = []
+    for output in self._config.output_blocks:
+      if output in block_outputs:
+        temp = block_outputs[output]
+        outputs.append(temp)
+      else:
+        raise ValueError('No output `%s` of backbone to be concat' % output)
+    if outputs:
+      return outputs
+
+    for output in self._config.concat_blocks:
+      if output in block_outputs:
+        temp = block_outputs[output]
+        outputs.append(temp)
+      else:
+        raise ValueError('No output `%s` of backbone to be concat' % output)
+    try:
+      output = merge_inputs(outputs, msg='backbone')
+    except ValueError as e:
+      msg = getattr(e, 'message', str(e))
+      logging.error("merge backbone's output failed: %s", msg)
+      raise e
+    return output
+
+  def load_keras_layer(self, layer_conf, name, reuse=None):
+    layer_cls, customize = load_keras_layer(layer_conf.class_name)
+    if layer_cls is None:
+      raise ValueError('Invalid keras layer class name: ' +
+                       layer_conf.class_name)
+
+    param_type = layer_conf.WhichOneof('params')
+    if customize:
+      if param_type is None or param_type == 'st_params':
+        params = Parameter(layer_conf.st_params, True, l2_reg=self._l2_reg)
+      else:
+        pb_params = getattr(layer_conf, param_type)
+        params = Parameter(pb_params, False, l2_reg=self._l2_reg)
+
+      has_reuse = True
+      try:
+        from funcsigs import signature
+        sig = signature(layer_cls.__init__)
+        has_reuse = 'reuse' in sig.parameters.keys()
+      except ImportError:
+        try:
+          from sklearn.externals.funcsigs import signature
+          sig = signature(layer_cls.__init__)
+          has_reuse = 'reuse' in sig.parameters.keys()
+        except ImportError:
+          logging.warning('import funcsigs failed')
+
+      if has_reuse:
+        layer = layer_cls(params, name=name, reuse=reuse)
+      else:
+        layer = layer_cls(params, name=name)
+      return layer, customize
+    elif param_type is None:  # internal keras layer
+      layer = layer_cls(name=name)
+      return layer, customize
+    else:
+      assert param_type == 'st_params', 'internal keras layer only support st_params'
+      try:
+        kwargs = convert_to_dict(layer_conf.st_params)
+        logging.info('call %s layer with params %r' %
+                     (layer_conf.class_name, kwargs))
+        layer = layer_cls(name=name, **kwargs)
+      except TypeError as e:
+        logging.warning(e)
+        args = map(format_value, layer_conf.st_params.values())
+        logging.info('try to call %s layer with params %r' %
+                     (layer_conf.class_name, args))
+        layer = layer_cls(*args, name=name)
+      return layer, customize
+
+  def call_keras_layer(self, inputs, name, training, **kwargs):
+    """Call predefined Keras Layer, which can be reused."""
+    layer, customize = self._name_to_layer[name]
+    cls = layer.__class__.__name__
+    if customize:
+      try:
+        output = layer(inputs, training=training, **kwargs)
+      except Exception as e:
+        msg = getattr(e, 'message', str(e))
+        logging.error('call keras layer %s (%s) failed: %s' % (name, cls, msg))
+        raise e
+    else:
+      try:
+        output = layer(inputs, training=training)
+        if cls == 'BatchNormalization':
+          add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+      except TypeError:
+        output = layer(inputs)
+    return output
+
+  def call_layer(self, inputs, config, name, training, **kwargs):
+    layer_name = config.WhichOneof('layer')
+    if layer_name == 'keras_layer':
+      return self.call_keras_layer(inputs, name, training, **kwargs)
+    if layer_name == 'lambda':
+      conf = getattr(config, 'lambda')
+      fn = eval(conf.expression)
+      return fn(inputs)
+    if layer_name == 'repeat':
+      conf = config.repeat
+      n_loop = conf.num_repeat
+      outputs = []
+      for i in range(n_loop):
+        name_i = '%s_%d' % (name, i)
+        ly_inputs = inputs
+        if conf.HasField('input_slice'):
+          fn = eval('lambda x, i: x' + conf.input_slice.strip())
+          ly_inputs = fn(ly_inputs, i)
+        if conf.HasField('input_fn'):
+          with tf.name_scope(config.name):
+            fn = eval(conf.input_fn)
+            ly_inputs = fn(ly_inputs, i)
+        output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs)
+        outputs.append(output)
+      if len(outputs) == 1:
+        return outputs[0]
+      if conf.HasField('output_concat_axis'):
+        return tf.concat(outputs, conf.output_concat_axis)
+      return outputs
+    if layer_name == 'recurrent':
+      conf = config.recurrent
+      fixed_input_index = -1
+      if conf.HasField('fixed_input_index'):
+        fixed_input_index = conf.fixed_input_index
+      if fixed_input_index >= 0:
+        assert type(inputs) in (tuple, list), '%s inputs must be a list'
+      output = inputs
+      for i in range(conf.num_steps):
+        name_i = '%s_%d' % (name, i)
+        output_i = self.call_keras_layer(output, name_i, training, **kwargs)
+        if fixed_input_index >= 0:
+          j = 0
+          for idx in range(len(output)):
+            if idx == fixed_input_index:
+              continue
+            if type(output_i) in (tuple, list):
+              output[idx] = output_i[j]
+            else:
+              output[idx] = output_i
+            j += 1
+        else:
+          output = output_i
+      if fixed_input_index >= 0:
+        del output[fixed_input_index]
+        if len(output) == 1:
+          return output[0]
+        return output
+      return output
+
+    raise NotImplementedError('Unsupported backbone layer:' + layer_name)
+
+
+class Backbone(object):
+  """Configurable Backbone Network."""
+
+  def __init__(self, config, features, input_layer, l2_reg=None):
+    self._config = config
+    self._l2_reg = l2_reg
+    main_pkg = backbone_pb2.BlockPackage()
+    main_pkg.name = 'backbone'
+    main_pkg.blocks.MergeFrom(config.blocks)
+    if config.concat_blocks:
+      main_pkg.concat_blocks.extend(config.concat_blocks)
+    if config.output_blocks:
+      main_pkg.output_blocks.extend(config.output_blocks)
+    self._main_pkg = Package(main_pkg, features, input_layer, l2_reg)
+    for pkg in config.packages:
+      Package(pkg, features, input_layer, l2_reg)
+
+  def __call__(self, is_training, **kwargs):
+    output = self._main_pkg(is_training, **kwargs)
+
+    if self._config.HasField('top_mlp'):
+      params = Parameter.make_from_pb(self._config.top_mlp)
+      params.l2_regularizer = self._l2_reg
+      final_mlp = MLP(params, name='backbone_top_mlp')
+      if type(output) in (list, tuple):
+        output = tf.concat(output, axis=-1)
+      output = final_mlp(output, training=is_training, **kwargs)
+    return output
+
+  @classmethod
+  def wide_embed_dim(cls, config):
+    wide_embed_dim = None
+    for pkg in config.packages:
+      wide_embed_dim = get_wide_embed_dim(pkg.blocks, wide_embed_dim)
+    return get_wide_embed_dim(config.blocks, wide_embed_dim)
+
+
+def get_wide_embed_dim(blocks, wide_embed_dim=None):
+  for block in blocks:
+    layer = block.WhichOneof('layer')
+    if layer == 'input_layer':
+      if block.input_layer.HasField('wide_output_dim'):
+        wide_dim = block.input_layer.wide_output_dim
+        if wide_embed_dim:
+          assert wide_embed_dim == wide_dim, 'wide_output_dim must be consistent'
+        else:
+          wide_embed_dim = wide_dim
+  return wide_embed_dim


def merge_inputs(inputs, axis=-1, msg=''):
+  if len(inputs) == 0:
+    raise ValueError('no inputs to be concat:' + msg)
+  if len(inputs) == 1:
+    return inputs[0]
+
+  from functools import reduce
+  if all(map(lambda x: type(x) == list, inputs)):
+    # merge multiple lists into a list
+    return reduce(lambda x, y: x + y, inputs)
+
+  if any(map(lambda x: type(x) == list, inputs)):
+    logging.warning('%s: try to merge inputs into list' % msg)
+    return reduce(lambda x, y: x + y,
+                  [e if type(e) == list else [e] for e in inputs])
+
+  if axis != -1:
+    logging.info('concat inputs %s axis=%d' % (msg, axis))
+  return tf.concat(inputs, axis=axis)
+
+
+def format_value(value):
+  value_type = type(value)
+  if value_type == six.text_type:
+    return str(value)
+  if value_type == float:
+    int_v = int(value)
+    return int_v if int_v == value else value
+  if value_type == struct_pb2.ListValue:
+    return map(format_value, value)
+  if value_type == struct_pb2.Struct:
+    return convert_to_dict(value)
+  return value
+
+
+def convert_to_dict(struct):
+  kwargs = {}
+  for key, value in struct.items():
+    kwargs[str(key)] = format_value(value)
+  return kwargs
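Based on the constructor and __call__ signatures in the diff above, the Backbone class appears to be wired as sketched below. This is a minimal, hypothetical usage sketch, assuming the easy_rec package from this wheel is installed; model_config, features, input_layer and l2_reg are placeholders that the surrounding EasyRec model code would normally supply, not objects defined in this file.

# Hypothetical sketch; `model_config`, `features`, `input_layer` and `l2_reg`
# are caller-supplied placeholders, not part of backbone.py itself.
from easy_rec.python.layers.backbone import Backbone

# Build the configurable backbone from the parsed backbone config proto.
backbone = Backbone(model_config.backbone, features, input_layer, l2_reg=l2_reg)
# Merges the outputs of the configured blocks (and applies top_mlp when it is set).
output = backbone(is_training=True)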