easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,648 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """This API defines FeatureColumn for sequential input.
17
+
18
+ NOTE: This API is a work in progress and will likely be changing frequently.
19
+ """
20
+
21
+ from __future__ import absolute_import
22
+ from __future__ import division
23
+ from __future__ import print_function
24
+
25
+ import collections
26
+
27
+ from tensorflow.python.framework import dtypes
28
+ from tensorflow.python.framework import ops
29
+ from tensorflow.python.framework import tensor_shape
30
+ from tensorflow.python.ops import array_ops
31
+ from tensorflow.python.ops import check_ops
32
+ from tensorflow.python.ops import math_ops
33
+ from tensorflow.python.ops import parsing_ops
34
+ from tensorflow.python.ops import sparse_ops
35
+
36
+ from easy_rec.python.compat.feature_column import feature_column as fc_v1
37
+ from easy_rec.python.compat.feature_column import feature_column_v2 as fc
38
+ from easy_rec.python.compat.feature_column import utils as fc_utils
39
+
40
+ # pylint: disable=protected-access
41
+
42
+
43
class SequenceFeatures(fc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same
  `sequence_length`. The output of this method can be fed into sequence
  networks, such as RNN.

  The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
  `T` is the maximum sequence length for this batch, which could differ from
  batch to batch.

  If multiple `feature_columns` are given with `Di` `num_elements` each, their
  outputs are concatenated. So, the final `Tensor` has shape
  `[batch_size, T, D0 + D1 + ... + Dn]`.

  Example:
  ```python
  rating = sequence_numeric_column('rating')
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [rating, watches_embedding]

  sequence_input_layer = SequenceFeatures(columns)
  features = tf.io.parse_example(...,
                                 features=make_parse_example_spec(columns))
  sequence_input, sequence_length = sequence_input_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```
  """

  def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
    """Constructs a SequenceFeatures layer.

    Args:
      feature_columns: An iterable of dense sequence columns. Valid columns are
        - `embedding_column` that wraps a `sequence_categorical_column_with_*`
        - `sequence_numeric_column`.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the SequenceFeatures.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: If any of the `feature_columns` is not a
        `SequenceDenseColumn`.
    """
    super(SequenceFeatures, self).__init__(
        feature_columns=feature_columns,
        trainable=trainable,
        name=name,
        expected_column_type=fc.SequenceDenseColumn,
        **kwargs)

  def _target_shape(self, input_shape, total_elements):
    """Keeps the first two dims of `input_shape`; the last becomes `total_elements`.

    NOTE(review): presumably invoked by the base layer when flattening each
    column's output to 3D — confirm against `fc._BaseFeaturesLayer`.
    """
    return (input_shape[0], input_shape[1], total_elements)

  def call(self, features):
    """Returns sequence input corresponding to the `feature_columns`.

    Args:
      features: A dict mapping keys to tensors.

    Returns:
      An `(input_layer, sequence_length)` tuple where:
      - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
          `T` is the maximum sequence length for this batch, which could differ
          from batch to batch. `D` is the sum of `num_elements` for all
          `feature_columns`.
      - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
          length for each example.

    Raises:
      ValueError: If features are not a dictionary.
    """
    if not isinstance(features, dict):
      # Format the offending value into the message; passing it as a second
      # positional argument made str(error) render as a tuple.
      raise ValueError(
          'We expected a dictionary here. Instead we got: {}'.format(features))
    transformation_cache = fc.FeatureTransformationCache(features)
    output_tensors = []
    sequence_lengths = []

    for column in self._feature_columns:
      with ops.name_scope(column.name):
        dense_tensor, sequence_length = column.get_sequence_dense_tensor(
            transformation_cache, self._state_manager)
        # Flattens the final dimension to produce a 3D Tensor.
        output_tensors.append(self._process_dense_tensor(column, dense_tensor))
        sequence_lengths.append(sequence_length)

    # All columns must agree on batch size and per-example sequence length;
    # the single shared sequence_length is returned alongside the features.
    fc._verify_static_batch_size_equality(sequence_lengths,
                                          self._feature_columns)
    sequence_length = _assert_all_equal_and_return(sequence_lengths)

    return self._verify_and_concat_tensors(output_tensors), sequence_length
143
+
144
+
145
def concatenate_context_input(context_input, sequence_input):
  """Tiles `context_input` over every timestep and appends it to `sequence_input`.

  `context_input` gets a new axis at position 1 and is repeated once per padded
  timestep of `sequence_input`. The tiled tensor is then concatenated onto
  `sequence_input` along axis 2.

  Args:
    context_input: A `float32` `Tensor` of shape `[batch_size, d1]`.
    sequence_input: A `float32` `Tensor` of shape
      `[batch_size, padded_length, d0]`.

  Returns:
    A `float32` `Tensor` of shape `[batch_size, padded_length, d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  # Validation ops, created in order: sequence rank, sequence dtype,
  # context rank, context dtype.
  validations = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  # Everything below only executes after all validations have passed.
  with ops.control_dependencies(validations):
    num_steps = array_ops.shape(sequence_input)[1]
    expanded_context = array_ops.expand_dims(context_input, 1)
    tile_multiples = array_ops.concat([[1], [num_steps], [1]], 0)
    repeated_context = array_ops.tile(expanded_context, tile_multiples)
    return array_ops.concat([sequence_input, repeated_context], 2)
192
+
193
+
194
def sequence_categorical_column_with_identity(key,
                                              num_buckets,
                                              default_value=None,
                                              feature_name=None):
  """Builds a sequence feature column over integer ids used as-is.

  Feed the result to `embedding_column` or `indicator_column` to turn the
  sequence of categorical ids into a dense representation for a sequence
  model such as an RNN.

  Example:

  ```python
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [watches_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    num_buckets: Range of inputs; ids are expected to lie in
      `[0, num_buckets)`.
    default_value: If `None`, this column's graph operations will fail for
      out-of-range inputs. Otherwise, this value must be in the range
      `[0, num_buckets)`, and will replace out-of-range inputs.
    feature_name: Optional; forwarded unchanged to
      `fc.categorical_column_with_identity`.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  id_column = fc.categorical_column_with_identity(
      key=key,
      num_buckets=num_buckets,
      default_value=default_value,
      feature_name=feature_name)
  return fc.SequenceCategoricalColumn(id_column)
243
+
244
+
245
def sequence_numeric_column_with_bucketized_column(source_column, boundaries):
  """Bucketizes a one-dimensional sequence numeric column.

  Args:
    source_column: A column produced by `sequence_numeric_column()`; its
      `shape` must have at most one dimension.
    boundaries: A non-empty list or tuple of strictly increasing bucket
      boundaries.

  Returns:
    A `fc.SequenceBucketizedColumn` built from `source_column` and
    `tuple(boundaries)`.

  Raises:
    ValueError: If `source_column` is not a `SequenceNumericColumn` or has more
      than one dimension, or if `boundaries` is empty, is not a list/tuple, or
      is not strictly increasing.
  """
  # NOTE(review): `SequenceNumericColumn` is not defined in the visible part of
  # this module; presumably it is defined further down the file.
  if not isinstance(source_column, SequenceNumericColumn):
    raise ValueError(
        'source_column must be a column generated with sequence_numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError('source_column must be one-dimensional column. '
                     'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not isinstance(boundaries, (list, tuple)):
    raise ValueError('boundaries must be a sorted list.')
  # Equal neighbours are rejected as well: boundaries must strictly increase.
  if any(lower >= upper for lower, upper in zip(boundaries, boundaries[1:])):
    raise ValueError('boundaries must be a sorted list.')
  return fc.SequenceBucketizedColumn(source_column, tuple(boundaries))
261
+
262
+
263
def sequence_numeric_column_with_raw_column(source_column, sequence_length):
  """Wraps a one-dimensional sequence numeric column together with a length.

  Args:
    source_column: A column produced by `sequence_numeric_column()`; its
      `shape` must have at most one dimension.
    sequence_length: Passed unchanged as the second positional argument of
      `fc.SequenceNumericColumn`.

  Returns:
    A `fc.SequenceNumericColumn` built from `source_column` and
    `sequence_length`.

  Raises:
    ValueError: If `source_column` is not a `SequenceNumericColumn`, or its
      `shape` has more than one dimension.
  """
  # NOTE(review): the type check uses this module's `SequenceNumericColumn`
  # (not visible here; presumably defined further down), while the result is
  # `fc.SequenceNumericColumn` — confirm the two are meant to differ.
  is_sequence_numeric = isinstance(source_column, (SequenceNumericColumn,))
  if not is_sequence_numeric:
    raise ValueError(
        'source_column must be a column generated with sequence_numeric_column(). '
        'Given: {}'.format(source_column))
  is_multi_dimensional = len(source_column.shape) > 1
  if is_multi_dimensional:
    raise ValueError('source_column must be one-dimensional column. '
                     'Given: {}'.format(source_column))
  return fc.SequenceNumericColumn(source_column, sequence_length)
273
+
274
+
275
def sequence_weighted_categorical_column(categorical_column,
                                         weight_feature_key,
                                         dtype=dtypes.float32):
  """Builds a `fc.SequenceWeightedCategoricalColumn`.

  Args:
    categorical_column: Forwarded as `categorical_column` to
      `fc.SequenceWeightedCategoricalColumn`.
    weight_feature_key: Forwarded as `weight_feature_key` to
      `fc.SequenceWeightedCategoricalColumn`.
    dtype: An integer or floating dtype (checked via its `is_integer` /
      `is_floating` attributes). Defaults to `dtypes.float32`.

  Returns:
    A `fc.SequenceWeightedCategoricalColumn`.

  Raises:
    ValueError: If `dtype` is `None` or is neither an integer nor a floating
      dtype.
  """
  numeric_dtype = dtype is not None and (dtype.is_integer or dtype.is_floating)
  if not numeric_dtype:
    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
  return fc.SequenceWeightedCategoricalColumn(
      categorical_column=categorical_column,
      weight_feature_key=weight_feature_key,
      dtype=dtype)
284
+
285
+
286
def sequence_categorical_column_with_hash_bucket(key,
                                                 hash_bucket_size,
                                                 dtype=dtypes.string,
                                                 feature_name=None):
  """Builds a sequence feature column whose ids are assigned by hashing.

  Feed the result to `embedding_column` or `indicator_column` to turn the
  sequence of categorical terms into a dense representation for a sequence
  model such as an RNN.

  Example:

  ```python
  tokens = sequence_categorical_column_with_hash_bucket(
      'tokens', hash_bucket_size=1000)
  tokens_embedding = embedding_column(tokens, dimension=10)
  columns = [tokens_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.
      Defaults to `dtypes.string`.
    feature_name: Optional; forwarded unchanged to
      `fc.categorical_column_with_hash_bucket`.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not greater than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  hashed_column = fc.categorical_column_with_hash_bucket(
      key=key,
      hash_bucket_size=hash_bucket_size,
      dtype=dtype,
      feature_name=feature_name)
  return fc.SequenceCategoricalColumn(hashed_column)
332
+
333
+
334
def sequence_categorical_column_with_vocabulary_file(key,
                                                     vocabulary_file,
                                                     vocabulary_size=None,
                                                     num_oov_buckets=0,
                                                     default_value=None,
                                                     dtype=dtypes.string,
                                                     feature_name=None):
  """A sequence of categorical terms where ids use a vocabulary file.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  states = sequence_categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  states_embedding = embedding_column(states, dimension=10)
  columns = [states_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values. `None` (the default here) defers to
      `fc.categorical_column_with_vocabulary_file`, which upstream TF
      documents as `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.
    feature_name: Optional name forwarded unchanged to
      `fc.categorical_column_with_vocabulary_file`. NOTE(review): presumably
      used as the column's `name` in place of `key` -- confirm in `fc`.

  Returns:
    A `SequenceCategoricalColumn` wrapping the vocabulary-file column.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  # All validation happens inside the underlying factory; this wrapper only
  # turns the per-element column into a sequence column.
  file_column = fc.categorical_column_with_vocabulary_file(
      key=key,
      vocabulary_file=vocabulary_file,
      vocabulary_size=vocabulary_size,
      num_oov_buckets=num_oov_buckets,
      default_value=default_value,
      dtype=dtype,
      feature_name=feature_name)
  return fc.SequenceCategoricalColumn(file_column)
401
+
402
+
403
def sequence_categorical_column_with_vocabulary_list(key,
                                                     vocabulary_list,
                                                     dtype=None,
                                                     default_value=-1,
                                                     num_oov_buckets=0,
                                                     feature_name=None):
  """A sequence of categorical terms where ids use an in-memory list.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  colors = sequence_categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
      num_oov_buckets=2)
  colors_embedding = embedding_column(colors, dimension=3)
  columns = [colors_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
      is mapped to the index of its value (if present) in `vocabulary_list`.
      Must be castable to `dtype`.
    dtype: The type of features. Only string and integer types are supported.
      If `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` can not be specified
      with `default_value`.
    feature_name: Optional name forwarded unchanged to
      `fc.categorical_column_with_vocabulary_list`. NOTE(review): presumably
      used as the column's `name` in place of `key` -- confirm in `fc`.

  Returns:
    A `SequenceCategoricalColumn` wrapping the vocabulary-list column.

  Raises:
    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: if `dtype` is not integer or string.
  """
  # Construct the per-element lookup column, then wrap it as a sequence.
  list_column = fc.categorical_column_with_vocabulary_list(
      key=key,
      vocabulary_list=vocabulary_list,
      dtype=dtype,
      default_value=default_value,
      num_oov_buckets=num_oov_buckets,
      feature_name=feature_name)
  return fc.SequenceCategoricalColumn(list_column)
467
+
468
+
469
def sequence_numeric_column(key,
                            shape=(1,),
                            default_value=0.,
                            dtype=dtypes.float32,
                            normalizer_fn=None,
                            feature_name=None):
  """Returns a feature column that represents sequences of numeric data.

  Example:

  ```python
  temperature = sequence_numeric_column('temperature')
  columns = [temperature]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input features.
    shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`,
      each example must contain `2 * sequence_length` values.
    default_value: A single value compatible with `dtype` that is used for
      padding the sparse data into a dense `Tensor`.
    dtype: The type of values. Must expose `is_integer` / `is_floating`
      (i.e. a TensorFlow `DType`).
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization, it
      can be used for any kind of Tensorflow transformations.
    feature_name: Optional. When truthy, the returned column reports it as its
      `name` instead of `key`; `key` is still used to look up the input.

  Returns:
    A `SequenceNumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    ValueError: if any dimension in shape is not a positive integer.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  # Shape is validated first so that shape errors take precedence.
  checked_shape = fc._check_shape(shape=shape, key=key)

  is_numeric = dtype.is_integer or dtype.is_floating
  if not is_numeric:
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))

  has_bad_normalizer = (normalizer_fn is not None and
                        not callable(normalizer_fn))
  if has_bad_normalizer:
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  return SequenceNumericColumn(
      key=key,
      feature_name=feature_name,
      shape=checked_shape,
      dtype=dtype,
      default_value=default_value,
      normalizer_fn=normalizer_fn)
530
+
531
+
532
def _assert_all_equal_and_return(tensors, name=None):
  """Returns the first of `tensors`, guarded by runtime equality assertions.

  Args:
    tensors: A non-empty indexable sequence of tensors to compare.
    name: Optional name scope; defaults to 'assert_all_equal'.

  Returns:
    `tensors[0]` unchanged when there is only one tensor. Otherwise an
    identity of `tensors[0]` that carries control dependencies on
    `assert_equal(tensors[0], t)` for every other `t`.
  """
  with ops.name_scope(name, 'assert_all_equal', values=tensors):
    if len(tensors) == 1:
      return tensors[0]
    # One assertion per remaining tensor, each compared against the first.
    equality_checks = [
        check_ops.assert_equal(tensors[0], other) for other in tensors[1:]
    ]
    with ops.control_dependencies(equality_checks):
      return array_ops.identity(tensors[0])
542
+
543
+
544
class SequenceNumericColumn(
    fc.SequenceDenseColumn, fc_v1._FeatureColumn,
    collections.namedtuple('SequenceNumericColumn',
                           ('feature_name', 'key', 'shape', 'default_value',
                            'dtype', 'normalizer_fn'))):
  """Represents sequences of numeric data.

  Namedtuple fields:
    feature_name: Optional display name. When truthy, `name` returns it
      instead of `key`.
    key: The input feature key used to look up (and parse) the raw values.
    shape: Per-timestep value shape; `variable_shape` wraps it in a
      `TensorShape`.
    default_value: Fill value used when densifying the sparse input in
      `get_sequence_dense_tensor`.
    dtype: Dtype used in the `VarLenFeature` parse spec.
    normalizer_fn: Optional callable applied to the raw input before it is
      cast to `float32`.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    # `feature_name` overrides `key` as the column's public name when set.
    return self.feature_name if self.feature_name else self.key

  @property
  def raw_name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    # Variable-length feature: each example may have a different number of
    # values, so it is parsed as a sparse tensor.
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  def _transform_feature(self, inputs):
    """Legacy (v1 `_FeatureColumn`) transform.

    Reads the raw input for `key`, applies `normalizer_fn` if set, and casts
    the result to `float32`. These are the same steps as `transform_feature`.
    Previously the normalizer was skipped on this path, so the v1 and v2
    paths produced different values for the same column.

    Args:
      inputs: Object with a `get(key)` method returning the raw input tensor
        (e.g. a v1 `_LazyBuilder`).

    Returns:
      The normalized input tensor cast to `float32`.
    """
    input_tensor = inputs.get(self.key)
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return self._transform_input_tensor(input_tensor)

  def _transform_input_tensor(self, input_tensor):
    """Casts `input_tensor` to `float32`."""
    return math_ops.cast(input_tensor, dtypes.float32)

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    In this case, we apply the `normalizer_fn` to the input tensor and then
    cast it to `float32`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Normalized input tensor cast to `float32`.
    """
    input_tensor = transformation_cache.get(self.key, state_manager)
    if self.normalizer_fn is not None:
      input_tensor = self.normalizer_fn(input_tensor)
    return self._transform_input_tensor(input_tensor)

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of sequence input."""
    return tensor_shape.TensorShape(self.shape)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Returns a `TensorSequenceLengthPair`.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      A `TensorSequenceLengthPair`. `dense_tensor` is the densified
      (`default_value`-filled) input reshaped to
      `[batch_size, T] + variable_shape`. `sequence_length` is the per-example
      number of timesteps.
    """
    # NOTE(review): the transformed feature is expected to be a SparseTensor
    # (it is parsed as a VarLenFeature); `sparse_tensor_to_dense` and
    # `.shape.ndims` below rely on that.
    sp_tensor = transformation_cache.get(self, state_manager)
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        sp_tensor, default_value=self.default_value)
    # Reshape into [batch_size, T, variable_shape].
    dense_shape = array_ops.concat(
        [array_ops.shape(dense_tensor)[:1], [-1], self.variable_shape], axis=0)
    dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)

    # Get the number of timesteps per example
    # For the 2D case, the raw values are grouped according to num_elements;
    # for the 3D case, the grouping happens in the third dimension, and
    # sequence length is not affected.
    if sp_tensor.shape.ndims == 2:
      num_elements = self.variable_shape.num_elements()
    else:
      num_elements = 1
    seq_length = fc_utils.sequence_length_from_sparse_tensor(
        sp_tensor, num_elements=num_elements)

    return fc.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=seq_length)

  # TODO(b/119409767): Implement parents, _{get,from}_config.
  @property
  def parents(self):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()

  def _get_config(self):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class. Not implemented for this column."""
    raise NotImplementedError()
646
+
647
+
648
+ # pylint: enable=protected-access