easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,229 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.python.framework import constant_op
4
+ from tensorflow.python.framework import ops
5
+ from tensorflow.python.framework import sparse_tensor
6
+ from tensorflow.python.ops import gen_math_ops
7
+
8
+
9
def convert_to_int_tensor(tensor, name, dtype=tf.int32):
  """Converts `tensor` to a Tensor and casts it to the integer type `dtype`.

  Args:
    tensor: A Tensor or a value convertible to one. `dtype` is passed as the
      preferred dtype for the conversion.
    name: Name used for the conversion and in the error message.
    dtype: The integer dtype of the returned tensor (default `tf.int32`).

  Returns:
    A Tensor of dtype `dtype`.

  Raises:
    TypeError: If the converted tensor does not have an integer dtype.
  """
  converted = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
  if not converted.dtype.is_integer:
    raise TypeError('%s must be an integer tensor; dtype=%s' %
                    (name, converted.dtype))
  return gen_math_ops.cast(converted, dtype)
18
+
19
+
20
def _with_nonzero_rank(data):
  """Returns `data` with at least one dimension.

  A scalar is wrapped into a length-1 vector; any tensor of rank >= 1 is
  returned as-is. When the rank is not known at graph-construction time the
  decision is made at run time through a reshape.
  """
  static_rank = data.shape.ndims
  if static_rank is None:
    dynamic_shape = tf.shape(data)
    dynamic_rank = tf.rank(data)
    # Prepend a 1 to the run-time shape and keep the trailing `rank` entries.
    # For a scalar the slice `[-0:]` keeps the whole `[1]`; for rank >= 1 it
    # drops the prepended 1, leaving the shape unchanged.
    padded_shape = tf.concat([[1], dynamic_shape], axis=0)
    return tf.reshape(data, padded_shape[-dynamic_rank:])
  if static_rank == 0:
    return tf.stack([data])
  return data
31
+
32
+
33
def get_positive_axis(axis, ndims):
  """Validate an `axis` parameter, and normalize it to be positive.

  If `ndims` is known (i.e., not `None`), then check that `axis` is in the
  range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
  `axis + ndims` (otherwise).
  If `ndims` is not known, and `axis` is non-negative, then return it as-is.
  If `ndims` is not known, and `axis` is negative, then report an error.

  Args:
    axis: An integer constant. Any `numbers.Integral` is accepted (Python
      `int`/`long`, numpy integer scalars such as `np.int64`); the result is
      always a plain Python `int`.
    ndims: An integer constant, or `None`.

  Returns:
    The normalized `axis` value, as a Python `int`.

  Raises:
    TypeError: If `axis` is not an integer.
    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
      `ndims is None`.
  """
  import numbers  # Local import keeps the module-level import block unchanged.

  # `isinstance(axis, int)` alone would reject numpy integer scalars (and
  # Python 2 `long`), which callers commonly pass; numpy registers its integer
  # types with `numbers.Integral`.
  if not isinstance(axis, numbers.Integral):
    raise TypeError('axis must be an int; got %s' % type(axis).__name__)
  axis = int(axis)

  if ndims is None:
    if axis < 0:
      raise ValueError('axis may only be negative if ndims is statically known.')
    return axis

  if 0 <= axis < ndims:
    return axis
  if -ndims <= axis < 0:
    return axis + ndims
  raise ValueError('axis=%s out of bounds: expected %s<=axis<%s' %
                   (axis, -ndims, ndims))
66
+
67
+
68
def tile_one_dimension(data, axis, multiple):
  """Tiles `data` `multiple` times along the single dimension `axis`.

  Every other dimension is tiled with a multiple of 1. The caller is expected
  to pass `axis` as a non-negative Python int. When the rank of `data` is
  unknown at graph-construction time, the multiples vector is assembled with
  run-time ops instead of a Python list.
  """
  static_rank = data.shape.ndims
  if static_rank is None:
    ones = tf.ones(tf.rank(data), tf.int32)
    tile_multiples = tf.concat(
        [ones[:axis], [multiple], ones[axis + 1:]], axis=0)
  else:
    tile_multiples = [1] * static_rank
    tile_multiples[axis] = multiple
  return tf.tile(data, tile_multiples)
79
+
80
+
81
def _all_dimensions(x):
  """Returns a 1-D int32 tensor `[0, 1, ..., rank(x) - 1]`.

  When the rank is statically known (a dense `Tensor` with known ndims, or a
  `SparseTensor` whose 1-D `dense_shape` has a fully defined length), the
  result is a constant, which avoids creating Rank and Range ops. Otherwise the
  range is computed at run time from `rank(x)`.
  """
  static_rank = None
  if isinstance(x, ops.Tensor):
    static_rank = x.get_shape().ndims
  elif isinstance(x, sparse_tensor.SparseTensor):
    dense_shape_shape = x.dense_shape.get_shape()
    if dense_shape_shape.is_fully_defined():
      # `dense_shape` is 1-D, so its only dimension is the rank of `x`.
      static_rank = dense_shape_shape.dims[0].value

  if static_rank is not None:
    return constant_op.constant(np.arange(static_rank), dtype=tf.int32)
  return gen_math_ops._range(0, tf.rank(x), 1)
93
+
94
+
95
# This op is intended to exactly match the semantics of numpy.repeat, with
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
# when axis is not specified. Rather than implement that special behavior, we
# simply make `axis` be a required argument.
#
# External (OSS) `tf.repeat` feature request:
# https://github.com/tensorflow/tensorflow/issues/8246
def repeat_with_axis(data, repeats, axis, name=None):
  """Repeats elements of `data` along `axis`.

  Args:
    data: An `N`-dimensional tensor, or a value convertible to one. A scalar
      is treated as a 1-D tensor of length 1.
    repeats: A scalar or 1-D integer tensor specifying how many times each
      element along `axis` should be repeated. A 1-D `repeats` must have
      length `data.shape[axis]`. A scalar or length-1 `repeats` is broadcast
      to every element along `axis`.
    axis: An integer (Python int or numpy integer scalar). The axis along
      which to repeat values. Must satisfy `-max(N, 1) <= axis < max(N, 1)`;
      a negative `axis` requires the rank of `data` to be statically known.
    name: A name for the operation.

  Returns:
    A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
    except that dimension `axis` has size `sum(repeats)` (for a single-valued
    `repeats`, `repeats * data.shape[axis]`).

  Raises:
    TypeError: If `axis` is not an integer, or `repeats` is not an integer
      tensor.
    ValueError: If `repeats` has rank greater than 1, if `axis` is out of
      bounds, or if the static length of `repeats` is incompatible with
      `data.shape[axis]`.

  #### Examples:
  ```python
  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
  ['a', 'a', 'a', 'c', 'c']
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
  [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
  [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
  ```
  """
  import numbers  # Local import keeps the module-level import block unchanged.

  # Accept numpy integer scalars as well as Python ints. The rest of the
  # function uses `axis` for Python slicing/indexing, so pass a plain int on.
  if not isinstance(axis, numbers.Integral):
    raise TypeError('axis must be an int; got %s' % type(axis).__name__)
  axis = int(axis)

  with ops.name_scope(name, 'Repeat', [data, repeats]):
    data = ops.convert_to_tensor(data, name='data')
    repeats = convert_to_int_tensor(repeats, name='repeats')
    repeats.shape.with_rank_at_most(1)

    # If `data` is a scalar, then upgrade it to a vector.
    data = _with_nonzero_rank(data)
    data_shape = tf.shape(data)

    # If `axis` is negative, then convert it to a positive value.
    axis = get_positive_axis(axis, data.shape.ndims)

    # Fast path: a single repeat count (a scalar, or statically length 1)
    # repeats every element the same number of times, so tile & reshape.
    # The new axis length is computed explicitly instead of passing -1,
    # because `reshape` cannot infer a -1 dimension when another dimension
    # is 0 (zero-sized inputs).
    if repeats.shape.num_elements() == 1:
      repeats = tf.reshape(repeats, [])
      expanded = tf.expand_dims(data, axis + 1)
      tiled = tile_one_dimension(expanded, axis + 1, repeats)
      result_shape = tf.concat(
          [data_shape[:axis], [repeats * data_shape[axis]],
           data_shape[axis + 1:]],
          axis=0)
      return tf.reshape(tiled, result_shape)

    # Check data Tensor shapes. This is only possible when the rank of `data`
    # is known; `data.shape.dims` is None otherwise.
    if repeats.shape.ndims == 1 and data.shape.ndims is not None:
      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

    # Record whether `repeats` is statically empty before broadcasting can
    # blur its static shape. This feeds the static shape inference at the end.
    repeats_is_empty = (
        repeats.shape.ndims == 1 and repeats.shape.as_list()[0] == 0)

    # Give `repeats` exactly one count per element along `axis`. This is a
    # no-op for a correctly sized 1-D `repeats`. It also broadcasts single
    # values whose size is only known at run time (e.g. unknown-rank scalars).
    repeats = tf.broadcast_to(repeats, [data_shape[axis]])
    repeats_along_axis = repeats

    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
    if repeats.shape.ndims != axis + 1:
      repeats_shape = tf.shape(repeats)
      repeats_ndims = tf.rank(repeats)
      broadcast_shape = tf.concat(
          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
      repeats = tf.broadcast_to(repeats, broadcast_shape)
      repeats.set_shape([None] * (axis + 1))

    # Create a "sequence mask" based on `repeats`, where slices across `axis`
    # contain one `True` value for each repetition. E.g., if
    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
    max_repeat = gen_math_ops.maximum(
        0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
    mask = tf.sequence_mask(repeats, max_repeat)

    # Add a new dimension around each value that needs to be repeated, and
    # then tile that new dimension to match the maximum number of repetitions.
    expanded = tf.expand_dims(data, axis + 1)
    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

    # Use `boolean_mask` to discard the extra repeated values. This also
    # flattens all dimensions up through `axis`.
    masked = tf.boolean_mask(tiled, mask)

    # Reshape the output tensor to add the outer dimensions back.
    if axis == 0:
      result = masked
    else:
      # An explicit length (not -1) keeps this reshape valid for zero-sized
      # tensors. A negative count yields no True entries in the sequence mask,
      # so it contributes 0 to the new axis length.
      new_axis_len = tf.reduce_sum(tf.maximum(repeats_along_axis, 0))
      result_shape = tf.concat(
          [data_shape[:axis], [new_axis_len], data_shape[axis + 1:]], axis=0)
      result = tf.reshape(masked, result_shape)

    # Preserve shape information. The repeated axis is statically known to
    # be empty only when `repeats` itself is statically empty.
    if data.shape.ndims is not None:
      new_axis_size = 0 if repeats_is_empty else None
      result.set_shape(data.shape[:axis].concatenate(
          [new_axis_size]).concatenate(data.shape[axis + 1:]))

    return result
194
+
195
+
196
def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
  """Repeats elements of `input`, following `numpy.repeat`.

  Args:
    input: An `N`-dimensional Tensor.
    repeats: A scalar or 1-D `int` Tensor giving the number of repetitions of
      each element. It is broadcast to fit the shape of the given axis; when
      `axis` is not None, a 1-D `repeats` must have length `input.shape[axis]`.
    axis: An int. The axis along which to repeat values. When None (the
      default), `input` is flattened first and a flat output is returned.
    name: A name for the operation.

  Returns:
    A Tensor with the same shape as `input`, except along the given axis.
    If `axis` is None, the output is 1-D and matches the flattened input.

  #### Examples:
  ```python
  >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
  ['a', 'a', 'a', 'c', 'c']
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
  [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
  >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
  [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
  >>> repeat(3, repeats=4)
  [3, 3, 3, 3]
  >>> repeat([[1,2], [3,4]], repeats=2)
  [1, 1, 2, 2, 3, 3, 4, 4]
  ```
  """
  if axis is not None:
    return repeat_with_axis(input, repeats, axis, name)
  # No axis given: emulate numpy by repeating along axis 0 of the flattened
  # input.
  flat_input = tf.reshape(input, [-1])
  return repeat_with_axis(flat_input, repeats, 0, name)