easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,648 @@
|
|
|
1
|
+
# -*- encoding:utf-8 -*-
|
|
2
|
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
"""This API defines FeatureColumn for sequential input.
|
|
17
|
+
|
|
18
|
+
NOTE: This API is a work in progress and will likely be changing frequently.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import absolute_import
|
|
22
|
+
from __future__ import division
|
|
23
|
+
from __future__ import print_function
|
|
24
|
+
|
|
25
|
+
import collections
|
|
26
|
+
|
|
27
|
+
from tensorflow.python.framework import dtypes
|
|
28
|
+
from tensorflow.python.framework import ops
|
|
29
|
+
from tensorflow.python.framework import tensor_shape
|
|
30
|
+
from tensorflow.python.ops import array_ops
|
|
31
|
+
from tensorflow.python.ops import check_ops
|
|
32
|
+
from tensorflow.python.ops import math_ops
|
|
33
|
+
from tensorflow.python.ops import parsing_ops
|
|
34
|
+
from tensorflow.python.ops import sparse_ops
|
|
35
|
+
|
|
36
|
+
from easy_rec.python.compat.feature_column import feature_column as fc_v1
|
|
37
|
+
from easy_rec.python.compat.feature_column import feature_column_v2 as fc
|
|
38
|
+
from easy_rec.python.compat.feature_column import utils as fc_utils
|
|
39
|
+
|
|
40
|
+
# pylint: disable=protected-access
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SequenceFeatures(fc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same
  `sequence_length`. The output of this method can be fed into sequence
  networks, such as RNN.

  The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
  `T` is the maximum sequence length for this batch, which could differ from
  batch to batch.

  If multiple `feature_columns` are given with `Di` `num_elements` each, their
  outputs are concatenated. So, the final `Tensor` has shape
  `[batch_size, T, D0 + D1 + ... + Dn]`.

  Example:

  ```python
  rating = sequence_numeric_column('rating')
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [rating, watches_embedding]

  sequence_input_layer = SequenceFeatures(columns)
  features = tf.io.parse_example(...,
                                 features=make_parse_example_spec(columns))
  sequence_input, sequence_length = sequence_input_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```
  """

  def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
    """Constructs a SequenceFeatures layer.

    Args:
      feature_columns: An iterable of dense sequence columns. Valid columns are
        - `embedding_column` that wraps a `sequence_categorical_column_with_*`
        - `sequence_numeric_column`.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the SequenceFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: If any of the `feature_columns` is not a
        `SequenceDenseColumn`.
    """
    super(SequenceFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=fc.SequenceDenseColumn,
        **kwargs)

  def _target_shape(self, input_shape, total_elements):
    """Returns `(input_shape[0], input_shape[1], total_elements)`.

    The first two entries of `input_shape` are kept as-is and
    `total_elements` becomes the last entry.
    """
    return (input_shape[0], input_shape[1], total_elements)

  def call(self, features):
    """Returns sequence input corresponding to the `feature_columns`.

    Args:
      features: A dict mapping keys to tensors.

    Returns:
      An `(input_layer, sequence_length)` tuple where:
      - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
          `T` is the maximum sequence length for this batch, which could differ
          from batch to batch. `D` is the sum of `num_elements` for all
          `feature_columns`.
      - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
          length for each example.

    Raises:
      ValueError: If features are not a dictionary.
    """
    if not isinstance(features, dict):
      # Format the offending value into the message itself; passing it as a
      # second positional argument makes the error render as a tuple repr.
      raise ValueError(
          'We expected a dictionary here. Instead we got: {}'.format(features))
    transformation_cache = fc.FeatureTransformationCache(features)
    output_tensors = []
    sequence_lengths = []

    for column in self._feature_columns:
      with ops.name_scope(column.name):
        dense_tensor, sequence_length = column.get_sequence_dense_tensor(
            transformation_cache, self._state_manager)
        # Flattens the final dimension to produce a 3D Tensor.
        output_tensors.append(self._process_dense_tensor(column, dense_tensor))
        sequence_lengths.append(sequence_length)

    # Check and process sequence lengths.
    fc._verify_static_batch_size_equality(sequence_lengths,
                                          self._feature_columns)
    sequence_length = _assert_all_equal_and_return(sequence_lengths)

    return self._verify_and_concat_tensors(output_tensors), sequence_length
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def concatenate_context_input(context_input, sequence_input):
  """Appends `context_input` to every timestep of `sequence_input`.

  `context_input` gets a new axis at position 1, is tiled along that axis as
  many times as `sequence_input` has entries on its dimension 1, and is then
  concatenated onto `sequence_input` along dimension 2.

  Args:
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  # The four checks are created in the same order as before (sequence rank,
  # sequence dtype, context rank, context dtype).
  input_checks = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  with ops.control_dependencies(input_checks):
    padded_length = array_ops.shape(sequence_input)[1]
    # [batch_size, d1] -> [batch_size, 1, d1]
    expanded_context = array_ops.expand_dims(context_input, 1)
    # Repeat only along the new axis: multiples == [1, padded_length, 1].
    tile_multiples = array_ops.concat([[1], [padded_length], [1]], 0)
    tiled_context_input = array_ops.tile(expanded_context, tile_multiples)
  return array_ops.concat([sequence_input, tiled_context_input], 2)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def sequence_categorical_column_with_identity(key,
                                              num_buckets,
                                              default_value=None,
                                              feature_name=None):
  """Builds a sequence categorical column over integer inputs.

  Pass the result to `embedding_column` or `indicator_column` to turn sequence
  categorical data into a dense representation suitable for a sequence NN,
  such as an RNN.

  Example:

  ```python
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [watches_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    num_buckets: Range of inputs. Namely, inputs are expected to be in the
      range `[0, num_buckets)`.
    default_value: If `None`, this column's graph operations will fail for
      out-of-range inputs. Otherwise, this value must be in the range
      `[0, num_buckets)`, and will replace out-of-range inputs.
    feature_name: Forwarded unchanged to
      `fc.categorical_column_with_identity`; its meaning is defined there.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  # Build the plain identity column first, then wrap it as a sequence column.
  identity_column = fc.categorical_column_with_identity(
      feature_name=feature_name,
      key=key,
      num_buckets=num_buckets,
      default_value=default_value)
  return fc.SequenceCategoricalColumn(identity_column)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def sequence_numeric_column_with_bucketized_column(source_column, boundaries):
  """Wraps a sequence numeric column in a `fc.SequenceBucketizedColumn`.

  Args:
    source_column: A `SequenceNumericColumn` whose `shape` has at most one
      dimension.
    boundaries: A non-empty list or tuple whose values are strictly
      increasing.

  Returns:
    `fc.SequenceBucketizedColumn(source_column, tuple(boundaries))`.

  Raises:
    ValueError: If `source_column` is not a `SequenceNumericColumn`, if its
      shape has more than one dimension, or if `boundaries` is empty, is not
      a list/tuple, or is not strictly increasing.
  """
  if not isinstance(source_column, SequenceNumericColumn):
    raise ValueError(
        'source_column must be a column generated with sequence_numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError('source_column must be one-dimensional column. '
                     'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')

  not_sorted_message = 'boundaries must be a sorted list.'
  if not isinstance(boundaries, (list, tuple)):
    raise ValueError(not_sorted_message)
  # Each boundary must be strictly smaller than its successor.
  if any(left >= right for left, right in zip(boundaries, boundaries[1:])):
    raise ValueError(not_sorted_message)
  return fc.SequenceBucketizedColumn(source_column, tuple(boundaries))
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def sequence_numeric_column_with_raw_column(source_column, sequence_length):
  """Wraps a sequence numeric column in a `fc.SequenceNumericColumn`.

  Args:
    source_column: A `SequenceNumericColumn` whose `shape` has at most one
      dimension.
    sequence_length: Passed through as the second positional argument of
      `fc.SequenceNumericColumn`; its meaning is defined there.

  Returns:
    `fc.SequenceNumericColumn(source_column, sequence_length)`.

  Raises:
    ValueError: If `source_column` is not a `SequenceNumericColumn` or its
      shape has more than one dimension.
  """
  if not isinstance(source_column, SequenceNumericColumn):
    wrong_type_message = (
        'source_column must be a column generated with sequence_numeric_column(). '
        'Given: {}'.format(source_column))
    raise ValueError(wrong_type_message)

  if len(source_column.shape) > 1:
    wrong_rank_message = ('source_column must be one-dimensional column. '
                          'Given: {}'.format(source_column))
    raise ValueError(wrong_rank_message)

  return fc.SequenceNumericColumn(source_column, sequence_length)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def sequence_weighted_categorical_column(categorical_column,
                                         weight_feature_key,
                                         dtype=dtypes.float32):
  """Builds a `fc.SequenceWeightedCategoricalColumn`.

  Args:
    categorical_column: Forwarded as `categorical_column` to
      `fc.SequenceWeightedCategoricalColumn`.
    weight_feature_key: Forwarded as `weight_feature_key` to
      `fc.SequenceWeightedCategoricalColumn`.
    dtype: Type of the weights. Must be an integer or floating dtype;
      defaults to `dtypes.float32`.

  Returns:
    A `fc.SequenceWeightedCategoricalColumn` built from the three arguments.

  Raises:
    ValueError: If `dtype` is `None` or is neither integer nor floating.
  """
  # Only integer and floating dtypes are accepted for the weights.
  dtype_is_numeric = dtype is not None and (dtype.is_integer or
                                            dtype.is_floating)
  if not dtype_is_numeric:
    raise ValueError('dtype {} is not convertible to float.'.format(dtype))

  column_kwargs = dict(
      categorical_column=categorical_column,
      weight_feature_key=weight_feature_key,
      dtype=dtype)
  return fc.SequenceWeightedCategoricalColumn(**column_kwargs)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def sequence_categorical_column_with_hash_bucket(key,
                                                 hash_bucket_size,
                                                 dtype=dtypes.string,
                                                 feature_name=None):
  """Builds a sequence categorical column whose ids come from hashing.

  Pass the result to `embedding_column` or `indicator_column` to turn sequence
  categorical data into a dense representation suitable for a sequence NN,
  such as an RNN.

  Example:

  ```python
  tokens = sequence_categorical_column_with_hash_bucket(
      'tokens', hash_bucket_size=1000)
  tokens_embedding = embedding_column(tokens, dimension=10)
  columns = [tokens_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.
    feature_name: Forwarded unchanged to
      `fc.categorical_column_with_hash_bucket`; its meaning is defined there.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not greater than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  # Build the plain hashed column first, then wrap it as a sequence column.
  hashed_column = fc.categorical_column_with_hash_bucket(
      feature_name=feature_name,
      key=key,
      hash_bucket_size=hash_bucket_size,
      dtype=dtype)
  return fc.SequenceCategoricalColumn(hashed_column)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def sequence_categorical_column_with_vocabulary_file(key,
                                                     vocabulary_file,
                                                     vocabulary_size=None,
                                                     num_oov_buckets=0,
                                                     default_value=None,
                                                     dtype=dtypes.string,
                                                     feature_name=None):
  """Builds a sequence categorical column whose ids come from a vocabulary file.

  The returned column is meant to be handed to `embedding_column` or
  `indicator_column`, which turn sequence categorical data into a dense
  representation suitable for a sequence NN such as an RNN.

  Example:

  ```python
  states = sequence_categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  states_embedding = embedding_column(states, dimension=10)
  columns = [states_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.
    feature_name: Optional value forwarded unchanged, together with the other
      arguments, to `fc.categorical_column_with_vocabulary_file`. Defaults to
      `None`.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  # Build the plain (non-sequence) categorical column first, then wrap it so
  # that it is treated as a sequence column.
  vocab_file_column = fc.categorical_column_with_vocabulary_file(
      feature_name=feature_name,
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      num_oov_buckets=num_oov_buckets,
      default_value=default_value,
      dtype=dtype)
  return fc.SequenceCategoricalColumn(vocab_file_column)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def sequence_categorical_column_with_vocabulary_list(key,
                                                     vocabulary_list,
                                                     dtype=None,
                                                     default_value=-1,
                                                     num_oov_buckets=0,
                                                     feature_name=None):
  """Builds a sequence categorical column whose ids come from an in-memory list.

  The returned column is meant to be handed to `embedding_column` or
  `indicator_column`, which turn sequence categorical data into a dense
  representation suitable for a sequence NN such as an RNN.

  Example:

  ```python
  colors = sequence_categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
      num_oov_buckets=2)
  colors_embedding = embedding_column(colors, dimension=3)
  columns = [colors_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
      is mapped to the index of its value (if present) in `vocabulary_list`.
      Must be castable to `dtype`.
    dtype: The type of features. Only string and integer types are supported.
      If `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` can not be specified
      with `default_value`.
    feature_name: Optional value forwarded unchanged, together with the other
      arguments, to `fc.categorical_column_with_vocabulary_list`. Defaults to
      `None`.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: if `dtype` is not integer or string.
  """
  # Build the plain (non-sequence) categorical column first, then wrap it so
  # that it is treated as a sequence column.
  vocab_list_column = fc.categorical_column_with_vocabulary_list(
      feature_name=feature_name,
      key=key,
      vocabulary_list=vocabulary_list,
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets)
  return fc.SequenceCategoricalColumn(vocab_list_column)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def sequence_numeric_column(key,
                            shape=(1,),
                            default_value=0.,
                            dtype=dtypes.float32,
                            normalizer_fn=None,
                            feature_name=None):
  """Creates a feature column describing sequences of numeric data.

  Example:

  ```python
  temperature = sequence_numeric_column('temperature')
  columns = [temperature]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input features.
    shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`,
      each example must contain `2 * sequence_length` values.
    default_value: A single value compatible with `dtype` that is used for
      padding the sparse data into a dense `Tensor`.
    dtype: The type of values. Must report `is_integer` or `is_floating`.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization, it
      can be used for any kind of Tensorflow transformations.
    feature_name: Optional name stored on the returned column; when it is
      truthy the column's `name` property reports it instead of `key`.

  Returns:
    A `SequenceNumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    TypeError: if `normalizer_fn` is given but is not callable.
    ValueError: if any dimension in shape is not a positive integer.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  # Shape validation is delegated to the shared feature-column helper; its
  # return value replaces the caller-supplied shape.
  checked_shape = fc._check_shape(shape=shape, key=key)

  # Only integer and floating dtypes are accepted.
  if not dtype.is_integer and not dtype.is_floating:
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))

  # A normalizer is optional, but when supplied it has to be callable.
  if not (normalizer_fn is None or callable(normalizer_fn)):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  return SequenceNumericColumn(
      feature_name=feature_name,
      key=key,
      shape=checked_shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def _assert_all_equal_and_return(tensors, name=None):
  """Returns the first tensor, guarded by equality checks against the rest.

  Args:
    tensors: A sequence of tensors that are expected to be equal.
    name: Optional name passed to the enclosing name scope; the default scope
      name is 'assert_all_equal'.

  Returns:
    `tensors[0]` itself when it is the only element; otherwise an identity of
    `tensors[0]` created under control dependencies on one `assert_equal` op
    per remaining tensor.
  """
  with ops.name_scope(name, 'assert_all_equal', values=tensors):
    # A single tensor has nothing to be compared with.
    if len(tensors) == 1:
      return tensors[0]
    # Compare every later tensor with the first one.
    equality_checks = [
        check_ops.assert_equal(tensors[0], other) for other in tensors[1:]
    ]
    with ops.control_dependencies(equality_checks):
      return array_ops.identity(tensors[0])
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
class SequenceNumericColumn(
    fc.SequenceDenseColumn, fc_v1._FeatureColumn,
    collections.namedtuple('SequenceNumericColumn',
                           ('feature_name', 'key', 'shape', 'default_value',
                            'dtype', 'normalizer_fn'))):
  """Represents sequences of numeric data.

  Fields (namedtuple):
    feature_name: Optional display name; used by `name` when truthy.
    key: Key of the input feature; also returned by `raw_name`.
    shape: Shape of the data per sequence step (see `variable_shape`).
    default_value: Value used to pad the sparse input when densifying.
    dtype: Dtype used in the parsing spec.
    normalizer_fn: Optional callable applied to the input tensor before it is
      cast to float32.
  """

  @property
  def _is_v2_column(self):
    # Reported unconditionally; `parents` is not implemented for this column,
    # so it cannot be consulted here.
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.feature_name if self.feature_name else self.key

  @property
  def raw_name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  def _transform_feature(self, inputs):
    """Fetches this column's input, normalizes it and casts it to float32.

    The `normalizer_fn` is applied here exactly as in `transform_feature`, so
    that both transformation entry points produce the same values.

    Args:
      inputs: An object exposing `get(key)` that returns the raw input tensor
        for `self.key` (presumably the v1 lazy builder -- confirm against
        callers).

    Returns:
      The (optionally normalized) input tensor cast to `float32`.
    """
    input_tensor = inputs.get(self.key)
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return self._transform_input_tensor(input_tensor)

  def _transform_input_tensor(self, input_tensor):
    """Casts `input_tensor` to `float32`; no normalization happens here."""
    return math_ops.cast(input_tensor, dtypes.float32)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of sequence input."""
    return tensor_shape.TensorShape(self.shape)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `TensorSequenceLengthPair`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      A `TensorSequenceLengthPair` whose `dense_tensor` is the densified input
      reshaped to `[batch_size, T, variable_shape]` and whose
      `sequence_length` is derived from the sparse input.
    """
    sp_tensor = transformation_cache.get(self, state_manager)
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        sp_tensor, default_value=self.default_value)
    # Reshape into [batch_size, T, variable_shape].
    dense_shape = array_ops.concat(
        [array_ops.shape(dense_tensor)[:1], [-1], self.variable_shape], axis=0)
    dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)

    # Get the number of timesteps per example
    # For the 2D case, the raw values are grouped according to num_elements;
    # for the 3D case, the grouping happens in the third dimension, and
    # sequence length is not affected.
    if sp_tensor.shape.ndims == 2:
      num_elements = self.variable_shape.num_elements()
    else:
      num_elements = 1
    seq_length = fc_utils.sequence_length_from_sparse_tensor(
        sp_tensor, num_elements=num_elements)

    return fc.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=seq_length)

  # TODO(b/119409767): Implement parents, _{get,from}_config.
  @property
  def parents(self):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()

  def _get_config(self):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
# pylint: enable=protected-access
|