easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of easy-cs-rec-custommodel might be problematic. Click here for more details.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import tensorflow as tf
|
|
3
|
+
from tensorflow.python.framework import constant_op
|
|
4
|
+
from tensorflow.python.framework import ops
|
|
5
|
+
from tensorflow.python.framework import sparse_tensor
|
|
6
|
+
from tensorflow.python.ops import gen_math_ops
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def convert_to_int_tensor(tensor, name, dtype=tf.int32):
  """Converts `tensor` to a Tensor and casts it to the integer type `dtype`.

  Args:
    tensor: A Tensor or any value accepted by `ops.convert_to_tensor`.
    name: Name used for the conversion op and in the error message.
    dtype: Integer dtype of the result; also the preferred dtype when
      converting non-Tensor values.

  Returns:
    A Tensor of type `dtype`.

  Raises:
    TypeError: If the converted value does not have an integer dtype.
  """
  converted = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
  if not converted.dtype.is_integer:
    raise TypeError('%s must be an integer tensor; dtype=%s' %
                    (name, converted.dtype))
  return gen_math_ops.cast(converted, dtype)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _with_nonzero_rank(data):
  """Returns `data` with rank >= 1: a scalar becomes shape [1], else as-is."""
  static_ndims = data.shape.ndims
  if static_ndims is None:
    # Rank is unknown while building the graph, so compute the target shape at
    # run time: prepend a 1 to the shape and keep the last `rank` entries.
    # For a scalar the slice `[-0:]` keeps everything, giving [1]; otherwise
    # the original shape is recovered.
    dynamic_shape = tf.shape(data)
    dynamic_ndims = tf.rank(data)
    padded_shape = tf.concat([[1], dynamic_shape], axis=0)
    return tf.reshape(data, padded_shape[-dynamic_ndims:])
  if static_ndims == 0:
    return tf.stack([data])
  return data
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_positive_axis(axis, ndims):
  """Validates an `axis` argument and normalizes it to a non-negative index.

  With a known `ndims`, `axis` must satisfy `-ndims <= axis < ndims`; negative
  values are shifted by `ndims`. With `ndims` unknown (`None`), only a
  non-negative `axis` is accepted and it is returned unchanged.

  Args:
    axis: An integer constant.
    ndims: An integer constant, or `None`.

  Returns:
    The normalized (non-negative) `axis` value.

  Raises:
    TypeError: If `axis` is not an `int`.
    ValueError: If `axis` is out of bounds, or if `axis` is negative and
      `ndims is None`.
  """
  if not isinstance(axis, int):
    raise TypeError('axis must be an int; got %s' % type(axis).__name__)
  if ndims is None:
    if axis < 0:
      raise ValueError('axis may only be negative if ndims is statically known.')
    return axis
  if not -ndims <= axis < ndims:
    raise ValueError('axis=%s out of bounds: expected %s<=axis<%s' %
                     (axis, -ndims, ndims))
  return axis + ndims if axis < 0 else axis
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def tile_one_dimension(data, axis, multiple):
  """Tiles `data` `multiple` times along dimension `axis` only.

  Every other dimension is tiled with a multiple of 1.

  Args:
    data: The tensor to tile.
    axis: Index of the dimension to tile. Assumed to be a non-negative int
      (callers normalize it beforehand).
    multiple: Tiling factor for `axis`. `repeat_with_axis` passes a scalar
      int32 tensor here. When the rank of `data` is unknown it is concatenated
      with an int32 vector, so it must be int32-compatible.

  Returns:
    The result of `tf.tile(data, multiples)`.
  """
  rank = data.shape.ndims
  if rank is None:
    # Static rank unknown: assemble the int32 `multiples` vector at run time.
    unit = tf.ones(tf.rank(data), tf.int32)
    before = unit[:axis]
    after = unit[axis + 1:]
    multiples = tf.concat([before, [multiple], after], axis=0)
  else:
    multiples = [1] * rank
    multiples[axis] = multiple
  return tf.tile(data, multiples)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _all_dimensions(x):
  """Returns an int32 1-D tensor `[0, 1, ..., rank(x) - 1]` for `x`.

  When the rank is statically known, a constant is returned and no Rank/Range
  ops are created. This covers a dense `ops.Tensor` with known rank, and a
  `SparseTensor` whose `dense_shape` has a fully defined shape. Otherwise
  `range(0, rank(x))` is computed at run time.
  """
  if isinstance(x, ops.Tensor):
    static_rank = x.get_shape().ndims
    if static_rank is not None:
      return constant_op.constant(np.arange(static_rank), dtype=tf.int32)
  if isinstance(x, sparse_tensor.SparseTensor):
    dense_shape_shape = x.dense_shape.get_shape()
    if dense_shape_shape.is_fully_defined():
      # `dense_shape` is a 1-D vector whose length is the sparse tensor's rank.
      static_rank = dense_shape_shape.dims[0].value
      return constant_op.constant(np.arange(static_rank), dtype=tf.int32)
  # Rank only known at run time.
  return gen_math_ops._range(0, tf.rank(x), 1)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# This op is intended to exactly match the semantics of numpy.repeat, with
|
|
96
|
+
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
|
|
97
|
+
# when axis is not specified. Rather than implement that special behavior, we
|
|
98
|
+
# simply make `axis` be a required argument.
|
|
99
|
+
#
|
|
100
|
+
# External (OSS) `tf.repeat` feature request:
|
|
101
|
+
# https://github.com/tensorflow/tensorflow/issues/8246
|
|
102
|
+
def repeat_with_axis(data, repeats, axis, name=None):
  """Repeats elements of `data`.

  Args:
    data: An `N`-dimensional tensor.
    repeats: A 1-D integer tensor specifying how many times each element in
      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
      Supports broadcasting from a scalar value or a length-1 vector, in which
      case every element along `axis` is repeated that many times.
    axis: `int`.  The axis along which to repeat values.  Must be less than
      `max(N, 1)`.
    name: A name for the operation.

  Returns:
    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
    except that dimension `axis` has size `sum(repeats)`.

  Raises:
    TypeError: If `axis` is not an int, or `repeats` is not an integer tensor.
    ValueError: If `axis` is out of bounds, if `repeats` has rank > 1, or if
      the static length of `repeats` is incompatible with `data.shape[axis]`.

  #### Examples:
  ```python
  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
  ['a', 'a', 'a', 'c', 'c']
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
  [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
  [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
  ```
  """
  if not isinstance(axis, int):
    raise TypeError('axis must be an int; got %s' % type(axis).__name__)

  with ops.name_scope(name, 'Repeat', [data, repeats]):
    data = ops.convert_to_tensor(data, name='data')
    repeats = convert_to_int_tensor(repeats, name='repeats')
    repeats.shape.with_rank_at_most(1)

    # If `data` is a scalar, then upgrade it to a vector.
    data = _with_nonzero_rank(data)
    data_shape = tf.shape(data)

    # If `axis` is negative, then convert it to a positive value.
    axis = get_positive_axis(axis, data.shape.ndims)

    # A single repeat count (scalar or length-1 vector) applies to every
    # element along `axis`, so we can just tile & reshape.  The size of the
    # repeated axis is given explicitly rather than as -1, because
    # `tf.reshape` cannot infer a -1 dimension for an empty tensor.
    if repeats.shape.num_elements() == 1:
      repeats = tf.reshape(repeats, [])
      expanded = tf.expand_dims(data, axis + 1)
      tiled = tile_one_dimension(expanded, axis + 1, repeats)
      result_shape = tf.concat(
          [data_shape[:axis], [repeats * data_shape[axis]],
           data_shape[axis + 1:]],
          axis=0)
      return tf.reshape(tiled, result_shape)

    # Check data Tensor shapes (only possible when both ranks are static).
    if repeats.shape.ndims == 1 and data.shape.ndims is not None:
      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

    # Remember the static length of `repeats` (if known) before broadcasting;
    # it is used to refine the static shape of the result below.
    static_repeats_len = (
        repeats.shape.as_list()[0] if repeats.shape.ndims == 1 else None)

    # Broadcast `repeats` to one entry per element along `axis`, and keep that
    # 1-D vector to compute the size of the repeated axis.
    repeats = tf.broadcast_to(repeats, [data_shape[axis]])
    repeats_1d = repeats

    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
    if axis > 0:
      repeats = tf.broadcast_to(repeats, data_shape[:axis + 1])
      repeats.set_shape([None] * (axis + 1))

    # Create a "sequence mask" based on `repeats`, where slices across `axis`
    # contain one `True` value for each repetition.  E.g., if
    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
    # The maximum is clamped at 0 so an empty `repeats` yields max_repeat=0.
    max_repeat = gen_math_ops.maximum(
        0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
    mask = tf.sequence_mask(repeats, max_repeat)

    # Add a new dimension around each value that needs to be repeated, and
    # then tile that new dimension to match the maximum number of repetitions.
    expanded = tf.expand_dims(data, axis + 1)
    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

    # Use `boolean_mask` to discard the extra repeated values.  This also
    # flattens all dimensions up through `axis`.
    masked = tf.boolean_mask(tiled, mask)

    # Reshape the output tensor to add the outer dimensions back.
    if axis == 0:
      result = masked
    else:
      # Size of the repeated axis.  Negative counts select nothing in the
      # sequence mask, so they are clamped to zero here as well.  Computing it
      # explicitly (instead of -1) keeps the reshape valid for empty tensors.
      repeated_dim_size = tf.reduce_sum(tf.maximum(repeats_1d, 0))
      result_shape = tf.concat(
          [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
          axis=0)
      result = tf.reshape(masked, result_shape)

    # Preserve shape information.
    if data.shape.ndims is not None:
      new_axis_size = 0 if static_repeats_len == 0 else None
      result.set_shape(data.shape[:axis].concatenate(
          [new_axis_size]).concatenate(data.shape[axis + 1:]))

    return result
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
  """Repeats elements of `input`, following `numpy.repeat` semantics.

  Args:
    input: An `N`-dimensional Tensor.
    repeats: An 1-D `int` Tensor. The number of repetitions for each element.
      repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
      must equal `input.shape[axis]` if axis is not None.
    axis: An int. The axis along which to repeat values. By default (axis=None),
      use the flattened input array, and return a flat output array.
    name: A name for the operation.

  Returns:
    A Tensor which has the same shape as `input`, except along the given axis.
    If axis is None then the output array is flattened to match the flattened
    input array.

  #### Examples:
  ```python
  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
  ['a', 'a', 'a', 'c', 'c']
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
  [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
  [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
  >>> repeat(3, repeats=4)
  [3, 3, 3, 3]
  >>> repeat([[1,2], [3,4]], repeats=2)
  [1, 1, 2, 2, 3, 3, 4, 4]
  ```
  """
  if axis is not None:
    return repeat_with_axis(input, repeats, axis, name)
  # No axis given: operate on the flattened input along its only axis.
  return repeat_with_axis(tf.reshape(input, [-1]), repeats, 0, name)
|