easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel has been flagged as potentially problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
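
The single hunk below is new-file content (@@ -0,0 +1,1064 @@); its line count matches entry 74 in the list above, so it presumably shows easy_rec/python/input/input.py.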
@@ -0,0 +1,1064 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+from abc import abstractmethod
+from collections import OrderedDict
+
+import six
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import gfile
+
+from easy_rec.python.core import sampler as sampler_lib
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import conditional
+from easy_rec.python.utils import config_util
+from easy_rec.python.utils import constant
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.check_utils import check_string_to_number
+from easy_rec.python.utils.expr_util import get_expression
+from easy_rec.python.utils.input_utils import get_type_defaults
+from easy_rec.python.utils.load_class import get_register_class_meta
+from easy_rec.python.utils.load_class import load_by_path
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+_INPUT_CLASS_MAP = {}
+_meta_type = get_register_class_meta(_INPUT_CLASS_MAP, have_abstract_class=True)
+
+
+class Input(six.with_metaclass(_meta_type, object)):
+
+  DATA_OFFSET = 'DATA_OFFSET'
+
+  def __init__(self,
+               data_config,
+               feature_configs,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None,
+               **kwargs):
+    self._pipeline_config = pipeline_config
+    self._data_config = data_config
+    self._check_mode = check_mode
+    logging.info('check_mode: %s ' % self._check_mode)
+    # tf.estimator.ModeKeys.*, only available before
+    # calling self._build
+    self._mode = None
+    if pipeline_config is not None and pipeline_config.model_config.HasField(
+        'ev_params'):
+      self._has_ev = True
+    else:
+      self._has_ev = False
+
+    if self._data_config.auto_expand_input_fields:
+      input_fields = [x for x in self._data_config.input_fields]
+      while len(self._data_config.input_fields) > 0:
+        self._data_config.input_fields.pop()
+      for field in input_fields:
+        tmp_names = config_util.auto_expand_names(field.input_name)
+        for tmp_name in tmp_names:
+          one_field = DatasetConfig.Field()
+          one_field.CopyFrom(field)
+          one_field.input_name = tmp_name
+          self._data_config.input_fields.append(one_field)
+
+    self._input_fields = [x.input_name for x in data_config.input_fields]
+    self._input_dims = [x.input_dim for x in data_config.input_fields]
+    self._input_field_types = [x.input_type for x in data_config.input_fields]
+    self._input_field_defaults = [
+        x.default_val for x in data_config.input_fields
+    ]
+    self._label_fields = list(data_config.label_fields)
+    self._feature_fields = list(data_config.feature_fields)
+    self._label_sep = list(data_config.label_sep)
+    self._label_dim = list(data_config.label_dim)
+    if len(self._label_dim) < len(self._label_fields):
+      for x in range(len(self._label_fields) - len(self._label_dim)):
+        self._label_dim.append(1)
+
+    self._label_udf_map = {}
+    for config in self._data_config.input_fields:
+      if config.HasField('user_define_fn'):
+        self._label_udf_map[config.input_name] = self._load_label_fn(config)
+
+    self._batch_size = data_config.batch_size
+    self._prefetch_size = data_config.prefetch_size
+    self._feature_configs = list(feature_configs)
+    self._task_index = task_index
+    self._task_num = task_num
+
+    self._input_path = input_path
+
+    # find out effective fields
+    self._effective_fields = []
+
+    # for multi-value inputs, the types may be different
+    # from the types defined in input_fields;
+    # used in create_multi_placeholders
+    self._multi_value_types = {}
+    self._multi_value_fields = set()
+
+    self._normalizer_fn = {}
+    for fc in self._feature_configs:
+      for input_name in fc.input_names:
+        assert input_name in self._input_fields, 'invalid input_name in %s' % str(
+            fc)
+        if input_name not in self._effective_fields:
+          self._effective_fields.append(input_name)
+
+      if fc.feature_type in [fc.TagFeature, fc.SequenceFeature]:
+        if fc.hash_bucket_size > 0 or len(
+            fc.vocab_list) > 0 or fc.HasField('vocab_file'):
+          self._multi_value_types[fc.input_names[0]] = tf.string
+          self._multi_value_fields.add(fc.input_names[0])
+        else:
+          self._multi_value_types[fc.input_names[0]] = tf.int64
+          self._multi_value_fields.add(fc.input_names[0])
+        if len(fc.input_names) > 1:
+          self._multi_value_types[fc.input_names[1]] = tf.float32
+          self._multi_value_fields.add(fc.input_names[1])
+
+      if fc.feature_type == fc.RawFeature and fc.raw_input_dim > 1:
+        self._multi_value_types[fc.input_names[0]] = tf.float32
+        self._multi_value_fields.add(fc.input_names[0])
+
+      if fc.HasField('normalizer_fn'):
+        feature_name = fc.feature_name if fc.HasField(
+            'feature_name') else fc.input_names[0]
+        self._normalizer_fn[feature_name] = load_by_path(fc.normalizer_fn)
+
+    # add sample weight to effective fields
+    if self._data_config.HasField('sample_weight'):
+      self._effective_fields.append(self._data_config.sample_weight)
+
+    # add uid_field of GAUC and session_id_field of SessionAUC
+    if self._pipeline_config is not None:
+      metrics = self._pipeline_config.eval_config.metrics_set
+      for metric in metrics:
+        metric_name = metric.WhichOneof('metric')
+        if metric_name == 'gauc':
+          uid = metric.gauc.uid_field
+          if uid not in self._effective_fields:
+            self._effective_fields.append(uid)
+        elif metric_name == 'session_auc':
+          sid = metric.session_auc.session_id_field
+          if sid not in self._effective_fields:
+            self._effective_fields.append(sid)
+
+      # check multi-task models' metrics
+      model_config = self._pipeline_config.model_config
+      model_name = model_config.WhichOneof('model')
+      if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
+        model = getattr(model_config, model_name)
+        towers = [model.ctr_tower, model.cvr_tower
+                 ] if model_name == 'esmm' else model.task_towers
+        for tower in towers:
+          metrics = tower.metrics_set
+          for metric in metrics:
+            metric_name = metric.WhichOneof('metric')
+            if metric_name == 'gauc':
+              uid = metric.gauc.uid_field
+              if uid not in self._effective_fields:
+                self._effective_fields.append(uid)
+            elif metric_name == 'session_auc':
+              sid = metric.session_auc.session_id_field
+              if sid not in self._effective_fields:
+                self._effective_fields.append(sid)
+
+    self._effective_fids = [
+        self._input_fields.index(x) for x in self._effective_fields
+    ]
+    # sort fids from small to large
+    self._effective_fids = list(set(self._effective_fids))
+    self._effective_fields = [
+        self._input_fields[x] for x in self._effective_fids
+    ]
+
+    self._label_fids = [self._input_fields.index(x) for x in self._label_fields]
+
+    # virtual fields generated by self._preprocess,
+    # which will be inputs to feature columns
+    self._appended_fields = []
+
+    # sampler
+    self._sampler = None
+    if input_path is not None:
+      # build sampler only for train and eval
+      self._sampler = sampler_lib.build(data_config)
+
+    self.get_type_defaults = get_type_defaults
+
+  def _load_label_fn(self, config):
+    udf_class = config.user_define_fn
+    udf_path = config.user_define_fn_path if config.HasField(
+        'user_define_fn_path') else None
+    dtype = config.user_define_fn_res_type if config.HasField(
+        'user_define_fn_res_type') else None
+
+    if udf_path:
+      if udf_path.startswith('oss://') or udf_path.startswith('hdfs://'):
+        with gfile.GFile(udf_path, 'r') as fin:
+          udf_content = fin.read()
+        final_udf_tmp_path = '/udf/'
+        final_udf_path = final_udf_tmp_path + udf_path.split('/')[-1]
+        logging.info('final udf path %s' % final_udf_path)
+        logging.info('udf content: %s' % udf_content)
+        if not gfile.Exists(final_udf_tmp_path):
+          gfile.MkDir(final_udf_tmp_path)
+        with gfile.GFile(final_udf_path, 'w') as fin:
+          fin.write(udf_content)
+      else:
+        final_udf_path = udf_path
+      final_udf_path = final_udf_path[:-3].replace('/', '.')
+      udf_class = final_udf_path + '.' + udf_class
+    logging.info('apply udf %s' % udf_class)
+    return load_by_path(udf_class), udf_class, dtype
+
+  @property
+  def num_epochs(self):
+    if self._data_config.num_epochs > 0:
+      return self._data_config.num_epochs
+    else:
+      return None
+
+  def get_feature_input_fields(self):
+    return [
+        x for x in self._input_fields
+        if x not in self._label_fields and x != self._data_config.sample_weight
+    ]
+
+  def should_stop(self, curr_epoch):
+    """Check whether enough epochs have been run."""
+    total_epoch = self.num_epochs
+    if self._mode != tf.estimator.ModeKeys.TRAIN:
+      total_epoch = 1
+    return total_epoch is not None and curr_epoch >= total_epoch
+
+  def create_multi_placeholders(self, export_config):
+    """Create multiple placeholders on export, one for each feature.
+
+    Args:
+      export_config: ExportConfig instance.
+    """
+    self._mode = tf.estimator.ModeKeys.PREDICT
+
+    if export_config.auto_multi_value:
+      export_fields_name = self._multi_value_fields
+    elif export_config.multi_value_fields:
+      export_fields_name = export_config.multi_value_fields.input_name
+    else:
+      export_fields_name = None
+    placeholder_named_by_input = export_config.placeholder_named_by_input
+
+    sample_weight_field = ''
+    if self._data_config.HasField('sample_weight'):
+      sample_weight_field = self._data_config.sample_weight
+
+    if export_config.filter_inputs:
+      effective_fids = list(self._effective_fids)
+    else:
+      effective_fids = [
+          fid for fid in range(len(self._input_fields))
+          if self._input_fields[fid] not in self._label_fields and
+          self._input_fields[fid] != sample_weight_field
+      ]
+
+    inputs = {}
+    for fid in effective_fids:
+      input_name = self._input_fields[fid]
+      if input_name == sample_weight_field:
+        continue
+      if placeholder_named_by_input:
+        placeholder_name = input_name
+      else:
+        placeholder_name = 'input_%d' % fid
+      if input_name in export_fields_name:
+        tf_type = self._multi_value_types[input_name] if input_name in self._multi_value_types \
+            else get_tf_type(self._input_field_types[fid])
+        logging.info('multi value input_name: %s, dtype: %s' %
+                     (input_name, tf_type))
+        finput = array_ops.placeholder(
+            tf_type, [None, None], name=placeholder_name)
+      else:
+        ftype = self._input_field_types[fid]
+        tf_type = get_tf_type(ftype)
+        logging.info('input_name: %s, dtype: %s' % (input_name, tf_type))
+        finput = array_ops.placeholder(tf_type, [None], name=placeholder_name)
+      inputs[input_name] = finput
+    features = {x: inputs[x] for x in inputs}
+    features = self._preprocess(features)
+    return inputs, features['feature']
+
+  def create_placeholders(self, export_config):
+    self._mode = tf.estimator.ModeKeys.PREDICT
+    inputs_placeholder = array_ops.placeholder(
+        tf.string, [None], name='features')
+    input_vals = tf.string_split(
+        inputs_placeholder, self._data_config.separator,
+        skip_empty=False).values
+
+    sample_weight_field = ''
+    if self._data_config.HasField('sample_weight'):
+      sample_weight_field = self._data_config.sample_weight
+
+    if export_config.filter_inputs:
+      effective_fids = list(self._effective_fids)
+      logging.info('number of effective inputs:%d, total number inputs: %d' %
+                   (len(effective_fids), len(self._input_fields)))
+    else:
+      effective_fids = [
+          fid for fid in range(len(self._input_fields))
+          if self._input_fields[fid] not in self._label_fields and
+          self._input_fields[fid] != sample_weight_field
+      ]
+      logging.info(
+          'will not filter any input[except labels], total number inputs:%d' %
+          len(effective_fids))
+    input_vals = tf.reshape(
+        input_vals, [-1, len(effective_fids)], name='input_reshape')
+    features = {}
+    for tmp_id, fid in enumerate(effective_fids):
+      ftype = self._input_field_types[fid]
+      tf_type = get_tf_type(ftype)
+      input_name = self._input_fields[fid]
+      if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]:
+        features[input_name] = tf.string_to_number(
+            input_vals[:, tmp_id],
+            tf_type,
+            name='input_str_to_%s' % tf_type.name)
+      else:
+        if ftype not in [DatasetConfig.STRING]:
+          logging.warning('unexpected field type: ftype=%s tf_type=%s' %
+                          (ftype, tf_type))
+        features[input_name] = input_vals[:, tmp_id]
+    features = self._preprocess(features)
+    return {'features': inputs_placeholder}, features['feature']
+
+  def _get_features(self, fields):
+    return fields['feature']
+
+  def _get_labels(self, fields):
+    labels = fields['label']
+    return OrderedDict([
+        (x, tf.squeeze(labels[x], axis=1) if len(labels[x].get_shape()) == 2 and
+         labels[x].get_shape()[1] == 1 else labels[x]) for x in labels
+    ])
+
+  def _as_string(self, field, fc):
+    if field.dtype == tf.string:
+      return field
+    if field.dtype in [tf.float32, tf.double]:
+      feature_name = fc.feature_name if fc.HasField(
+          'feature_name') else fc.input_names[0]
+      assert fc.precision > 0, 'fc.precision not set for feature[%s], it is dangerous to convert ' \
+          'float or double to string due to precision problem, it is suggested ' \
+          ' to convert them into string format before using EasyRec; ' \
+          'if you really need to do so, please set precision (the number of ' \
+          'decimal digits) carefully.' % feature_name
+    precision = None
+    if field.dtype in [tf.float32, tf.double]:
+      if fc.precision > 0:
+        precision = fc.precision
+
+    # convert to string
+    if 'as_string' in dir(tf.strings):
+      return tf.strings.as_string(field, precision=precision)
+    else:
+      return tf.as_string(field, precision=precision)
+
+  def _parse_combo_feature(self, fc, parsed_dict, field_dict):
+    # for compatibility with existing implementations
+    feature_name = fc.feature_name if fc.HasField(
+        'feature_name') else fc.input_names[0]
+
+    if len(fc.combo_input_seps) > 0:
+      assert len(fc.combo_input_seps) == len(fc.input_names), \
+          'len(combo_separator)[%d] != len(fc.input_names)[%d]' % (
+              len(fc.combo_input_seps), len(fc.input_names))
+
+    def _get_input_sep(input_id):
+      if input_id < len(fc.combo_input_seps):
+        return fc.combo_input_seps[input_id]
+      else:
+        return ''
+
+    if len(fc.combo_join_sep) == 0:
+      for input_id, input_name in enumerate(fc.input_names):
+        if input_id > 0:
+          key = feature_name + '_' + str(input_id)
+        else:
+          key = feature_name
+        input_sep = _get_input_sep(input_id)
+        if input_sep != '':
+          assert field_dict[
+              input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
+                  input_name, field_dict[input_name].dtype)
+          parsed_dict[key] = tf.string_split(field_dict[input_name], input_sep)
+        else:
+          parsed_dict[key] = self._as_string(field_dict[input_name], fc)
+    else:
+      if len(fc.combo_input_seps) > 0:
+        split_inputs = []
+        for input_id, input_name in enumerate(fc.input_names):
+          input_sep = fc.combo_input_seps[input_id]
+          if len(input_sep) > 0:
+            assert field_dict[
+                input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
+                    input_name, field_dict[input_name].dtype)
+            split_inputs.append(
+                tf.string_split(field_dict[input_name],
+                                fc.combo_input_seps[input_id]))
+          else:
+            split_inputs.append(tf.reshape(field_dict[input_name], [-1, 1]))
+        parsed_dict[feature_name] = sparse_ops.sparse_cross(
+            split_inputs, fc.combo_join_sep)
+      else:
+        inputs = [
+            self._as_string(field_dict[input_name], fc)
+            for input_name in fc.input_names
+        ]
+        parsed_dict[feature_name] = string_ops.string_join(
+            inputs, fc.combo_join_sep)
+
+  def _parse_tag_feature(self, fc, parsed_dict, field_dict):
+    input_0 = fc.input_names[0]
+    feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+    field = field_dict[input_0]
+    # Construct the output of TagFeature according to the dimension of field_dict.
+    # When the input field exceeds 2 dimensions, convert TagFeature to 2D output.
+    if len(field.get_shape()) < 2 or field.get_shape()[-1] == 1:
+      if len(field.get_shape()) == 0:
+        field = tf.expand_dims(field, axis=0)
+      elif len(field.get_shape()) == 2:
+        field = tf.squeeze(field, axis=-1)
+      if fc.HasField('kv_separator') and len(fc.input_names) > 1:
+        assert False, 'Tag Feature Error, ' \
+            'Cannot set kv_separator and multi input_names in one feature config. Feature: %s.' % input_0
+      parsed_dict[feature_name] = tf.string_split(field, fc.separator)
+      if fc.HasField('kv_separator'):
+        indices = parsed_dict[feature_name].indices
+        tmp_kvs = parsed_dict[feature_name].values
+        tmp_kvs = tf.string_split(tmp_kvs, fc.kv_separator, skip_empty=False)
+        tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
+        tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]
+
+        check_list = [
+            tf.py_func(check_string_to_number, [tmp_vs, input_0], Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          tmp_vs = tf.string_to_number(
+              tmp_vs, tf.float32, name='kv_tag_wgt_str_2_flt_%s' % input_0)
+        parsed_dict[feature_name] = tf.sparse.SparseTensor(
+            indices, tmp_ks, parsed_dict[feature_name].dense_shape)
+        parsed_dict[feature_name + '_w'] = tf.sparse.SparseTensor(
+            indices, tmp_vs, parsed_dict[feature_name].dense_shape)
+      if not fc.HasField('hash_bucket_size') and fc.num_buckets > 0:
+        check_list = [
+            tf.py_func(
+                check_string_to_number,
+                [parsed_dict[feature_name].values, input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          vals = tf.string_to_number(
+              parsed_dict[feature_name].values,
+              tf.int32,
+              name='tag_fea_%s' % input_0)
+        parsed_dict[feature_name] = tf.sparse.SparseTensor(
+            parsed_dict[feature_name].indices, vals,
+            parsed_dict[feature_name].dense_shape)
+      if len(fc.input_names) > 1:
+        input_1 = fc.input_names[1]
+        field = field_dict[input_1]
+        if len(field.get_shape()) == 0:
+          field = tf.expand_dims(field, axis=0)
+        field = tf.string_split(field, fc.separator)
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [field.values, input_1], Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          field_vals = tf.string_to_number(
+              field.values, tf.float32, name='tag_wgt_str_2_flt_%s' % input_1)
+        assert_op = tf.assert_equal(
+            tf.shape(field_vals)[0],
+            tf.shape(parsed_dict[feature_name].values)[0],
+            message='TagFeature Error: The size of %s not equal to the size of %s. Please check input: %s and %s.'
+            % (input_0, input_1, input_0, input_1))
+        with tf.control_dependencies([assert_op]):
+          field = tf.sparse.SparseTensor(field.indices, tf.identity(field_vals),
+                                         field.dense_shape)
+        parsed_dict[feature_name + '_w'] = field
+    else:
+      parsed_dict[feature_name] = field_dict[input_0]
+      if len(fc.input_names) > 1:
+        input_1 = fc.input_names[1]
+        parsed_dict[feature_name + '_w'] = field_dict[input_1]
+
+  def _parse_expr_feature(self, fc, parsed_dict, field_dict):
+    fea_name = fc.feature_name
+    prefix = 'expr_'
+    for input_name in fc.input_names:
+      new_input_name = prefix + input_name
+      if field_dict[input_name].dtype == tf.string:
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [field_dict[input_name], input_name],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[new_input_name] = tf.string_to_number(
+              field_dict[input_name],
+              tf.float64,
+              name='%s_str_2_int_for_expr' % new_input_name)
+      elif field_dict[input_name].dtype in [
+          tf.int32, tf.int64, tf.double, tf.float32
+      ]:
+        parsed_dict[new_input_name] = tf.cast(field_dict[input_name],
+                                              tf.float64)
+      else:
+        assert False, 'invalid input dtype[%s] for expr feature' % str(
+            field_dict[input_name].dtype)
+
+    expression = get_expression(fc.expression, fc.input_names, prefix=prefix)
+    logging.info('expression: %s' % expression)
+    parsed_dict[fea_name] = eval(expression)
+    self._appended_fields.append(fea_name)
+
+  def _parse_id_feature(self, fc, parsed_dict, field_dict):
+    input_0 = fc.input_names[0]
+    feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+    parsed_dict[feature_name] = field_dict[input_0]
+    if fc.HasField('hash_bucket_size'):
+      if field_dict[input_0].dtype != tf.string:
+        parsed_dict[feature_name] = self._as_string(field_dict[input_0], fc)
+    elif fc.num_buckets > 0:
+      if parsed_dict[feature_name].dtype == tf.string:
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [parsed_dict[feature_name], input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[feature_name] = tf.string_to_number(
+              parsed_dict[feature_name],
+              tf.int32,
+              name='%s_str_2_int' % input_0)
+
+  def _parse_raw_feature(self, fc, parsed_dict, field_dict):
+    input_0 = fc.input_names[0]
+    feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+    if field_dict[input_0].dtype == tf.string:
+      if fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+        fea = tf.string_split(field_dict[input_0], fc.seq_multi_sep)
+        segment_ids = fea.indices[:, 0]
+        vals = fea.values
+      else:
+        vals = field_dict[input_0]
+        segment_ids = tf.range(0, tf.shape(vals)[0])
+      if fc.raw_input_dim > 1:
+        check_list = [
+            tf.py_func(
+                check_split, [vals, fc.separator, fc.raw_input_dim, input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          tmp_fea = tf.string_split(vals, fc.separator)
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [tmp_fea.values, input_0], Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          tmp_vals = tf.string_to_number(
+              tmp_fea.values,
+              tf.float32,
+              name='multi_raw_fea_to_flt_%s' % input_0)
+        if fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+          emb = tf.reshape(tmp_vals, [-1, fc.raw_input_dim])
+          if fc.combiner == 'max':
+            emb = tf.segment_max(emb, segment_ids)
+          elif fc.combiner == 'sum':
+            emb = tf.segment_sum(emb, segment_ids)
+          elif fc.combiner == 'min':
+            emb = tf.segment_min(emb, segment_ids)
+          elif fc.combiner == 'mean':
+            emb = tf.segment_mean(emb, segment_ids)
+          else:
+            assert False, 'unsupported combine operator: ' + fc.combiner
+          parsed_dict[feature_name] = emb
+        else:
+          parsed_dict[feature_name] = tf.sparse_to_dense(
+              tmp_fea.indices,
+              [tf.shape(field_dict[input_0])[0], fc.raw_input_dim],
+              tmp_vals,
+              default_value=0)
+      elif fc.HasField('seq_multi_sep') and fc.HasField('combiner'):
+        check_list = [
+            tf.py_func(check_string_to_number, [vals, input_0], Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          emb = tf.string_to_number(
+              vals, tf.float32, name='raw_fea_to_flt_%s' % input_0)
+        if fc.combiner == 'max':
+          emb = tf.segment_max(emb, segment_ids)
+        elif fc.combiner == 'sum':
+          emb = tf.segment_sum(emb, segment_ids)
+        elif fc.combiner == 'min':
+          emb = tf.segment_min(emb, segment_ids)
+        elif fc.combiner == 'mean':
+          emb = tf.segment_mean(emb, segment_ids)
+        else:
+          assert False, 'unsupported combine operator: ' + fc.combiner
+        parsed_dict[feature_name] = emb
+      else:
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [field_dict[input_0], input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[feature_name] = tf.string_to_number(
+              field_dict[input_0], tf.float32)
+    elif field_dict[input_0].dtype in [
+        tf.int32, tf.int64, tf.double, tf.float32
+    ]:
+      parsed_dict[feature_name] = tf.to_float(field_dict[input_0])
+    else:
+      assert False, 'invalid dtype[%s] for raw feature' % str(
+          field_dict[input_0].dtype)
+    if fc.max_val > fc.min_val:
+      parsed_dict[feature_name] = (parsed_dict[feature_name] - fc.min_val) / (
+          fc.max_val - fc.min_val)
+
+    if fc.HasField('normalizer_fn'):
+      logging.info('apply normalizer_fn %s to `%s`' %
+                   (fc.normalizer_fn, feature_name))
+      parsed_dict[feature_name] = self._normalizer_fn[feature_name](
+          parsed_dict[feature_name])
+
+    if not fc.boundaries and fc.num_buckets <= 1 and \
+        fc.embedding_dim > 0 and \
+        self._data_config.sample_weight != input_0:
+      # may be needed by wide and deep models to project
+      # raw values to a vector; might be better implemented
+      # by a ProjectionColumn later
+      sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+      indices_0 = tf.range(sample_num, dtype=tf.int64)
+      indices_1 = tf.range(fc.raw_input_dim, dtype=tf.int64)
+      indices_0 = indices_0[:, None]
+      indices_1 = indices_1[None, :]
+      indices_0 = tf.tile(indices_0, [1, fc.raw_input_dim])
+      indices_1 = tf.tile(indices_1, [sample_num, 1])
+      indices_0 = tf.reshape(indices_0, [-1, 1])
+      indices_1 = tf.reshape(indices_1, [-1, 1])
+      indices = tf.concat([indices_0, indices_1], axis=1)
+
+      tmp_parsed = parsed_dict[feature_name]
+      parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+          indices=indices,
+          values=indices_1[:, 0],
+          dense_shape=[sample_num, fc.raw_input_dim])
+      parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+          indices=indices,
+          values=tf.reshape(tmp_parsed, [-1]),
+          dense_shape=[sample_num, fc.raw_input_dim])
+      # self._appended_fields.append(input_0 + '_raw_proj_id')
+      # self._appended_fields.append(input_0 + '_raw_proj_val')
+
+  def _parse_seq_feature(self, fc, parsed_dict, field_dict):
+    input_0 = fc.input_names[0]
+    feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
+    field = field_dict[input_0]
+    sub_feature_type = fc.sub_feature_type
+    # Construct the output of SeqFeature according to the dimension of field_dict.
+    # When the input field exceeds 2 dimensions, convert SeqFeature to 2D output.
+    if len(field.get_shape()) < 2:
+      parsed_dict[feature_name] = tf.strings.split(field, fc.separator)
+      if fc.HasField('seq_multi_sep'):
+        indices = parsed_dict[feature_name].indices
+        values = parsed_dict[feature_name].values
+        multi_vals = tf.string_split(values, fc.seq_multi_sep)
+        indices_1 = multi_vals.indices
+        indices = tf.gather(indices, indices_1[:, 0])
+        out_indices = tf.concat([indices, indices_1[:, 1:]], axis=1)
+        # 3-dimensional sparse tensor
+        out_shape = tf.concat(
+            [parsed_dict[feature_name].dense_shape, multi_vals.dense_shape[1:]],
+            axis=0)
+        parsed_dict[feature_name] = tf.sparse.SparseTensor(
+            out_indices, multi_vals.values, out_shape)
+      if (fc.num_buckets > 1 and fc.max_val == fc.min_val):
+        check_list = [
+            tf.py_func(
+                check_string_to_number,
+                [parsed_dict[feature_name].values, input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[feature_name] = tf.sparse.SparseTensor(
+              parsed_dict[feature_name].indices,
+              tf.string_to_number(
+                  parsed_dict[feature_name].values,
+                  tf.int64,
+                  name='sequence_str_2_int_%s' % input_0),
+              parsed_dict[feature_name].dense_shape)
+      elif sub_feature_type == fc.RawFeature:
+        check_list = [
+            tf.py_func(
+                check_string_to_number,
+                [parsed_dict[feature_name].values, input_0],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          parsed_dict[feature_name] = tf.sparse.SparseTensor(
+              parsed_dict[feature_name].indices,
+              tf.string_to_number(
+                  parsed_dict[feature_name].values,
+                  tf.float32,
+                  name='sequence_str_2_float_%s' % input_0),
+              parsed_dict[feature_name].dense_shape)
+      if fc.num_buckets > 1 and fc.max_val > fc.min_val:
+        normalized_values = (parsed_dict[feature_name].values - fc.min_val) / (
+            fc.max_val - fc.min_val)
+        parsed_dict[feature_name] = tf.sparse.SparseTensor(
+            parsed_dict[feature_name].indices, normalized_values,
+            parsed_dict[feature_name].dense_shape)
+    else:
+      parsed_dict[feature_name] = field
+    if not fc.boundaries and fc.num_buckets <= 1 and \
+        self._data_config.sample_weight != input_0 and \
+        sub_feature_type == fc.RawFeature and \
+        fc.raw_input_dim == 1:
+      logging.info(
+          'Not set boundaries or num_buckets or hash_bucket_size, %s will process as two dimension sequence raw feature'
+          % feature_name)
+      parsed_dict[feature_name] = tf.sparse_to_dense(
+          parsed_dict[feature_name].indices,
+          [tf.shape(parsed_dict[feature_name])[0], fc.sequence_length],
+          parsed_dict[feature_name].values)
+      sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+      indices_0 = tf.range(sample_num, dtype=tf.int64)
+      indices_1 = tf.range(fc.sequence_length, dtype=tf.int64)
+      indices_0 = indices_0[:, None]
+      indices_1 = indices_1[None, :]
+      indices_0 = tf.tile(indices_0, [1, fc.sequence_length])
+      indices_1 = tf.tile(indices_1, [sample_num, 1])
+      indices_0 = tf.reshape(indices_0, [-1, 1])
+      indices_1 = tf.reshape(indices_1, [-1, 1])
+      indices = tf.concat([indices_0, indices_1], axis=1)
+      tmp_parsed = parsed_dict[feature_name]
+      parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+          indices=indices,
+          values=indices_1[:, 0],
+          dense_shape=[sample_num, fc.sequence_length])
+      parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+          indices=indices,
+          values=tf.reshape(tmp_parsed, [-1]),
+          dense_shape=[sample_num, fc.sequence_length])
+    elif (not fc.boundaries and fc.num_buckets <= 1 and
+          self._data_config.sample_weight != input_0 and
+          sub_feature_type == fc.RawFeature and fc.raw_input_dim > 1):
+      # for 3-dimensional sequence feature input.
+      logging.info('Not set boundaries or num_buckets or hash_bucket_size,'
+                   ' %s will process as three dimension sequence raw feature' %
+                   feature_name)
+      parsed_dict[feature_name] = tf.sparse_to_dense(
+          parsed_dict[feature_name].indices, [
+              tf.shape(parsed_dict[feature_name])[0], fc.sequence_length,
+              fc.raw_input_dim
+          ], parsed_dict[feature_name].values)
+      sample_num = tf.to_int64(tf.shape(parsed_dict[feature_name])[0])
+      indices_0 = tf.range(sample_num, dtype=tf.int64)
+      indices_1 = tf.range(fc.sequence_length, dtype=tf.int64)
+      indices_2 = tf.range(fc.raw_input_dim, dtype=tf.int64)
+      indices_0 = indices_0[:, None, None]
+      indices_1 = indices_1[None, :, None]
+      indices_2 = indices_2[None, None, :]
+      indices_0 = tf.tile(indices_0, [1, fc.sequence_length, fc.raw_input_dim])
+      indices_1 = tf.tile(indices_1, [sample_num, 1, fc.raw_input_dim])
+      indices_2 = tf.tile(indices_2, [sample_num, fc.sequence_length, 1])
+      indices_0 = tf.reshape(indices_0, [-1, 1])
+      indices_1 = tf.reshape(indices_1, [-1, 1])
+      indices_2 = tf.reshape(indices_2, [-1, 1])
+      indices = tf.concat([indices_0, indices_1, indices_2], axis=1)
+
+      tmp_parsed = parsed_dict[feature_name]
+      parsed_dict[feature_name + '_raw_proj_id'] = tf.SparseTensor(
+          indices=indices,
+          values=indices_1[:, 0],
+          dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
+      parsed_dict[feature_name + '_raw_proj_val'] = tf.SparseTensor(
+          indices=indices,
+          values=tf.reshape(parsed_dict[feature_name], [-1]),
+          dense_shape=[sample_num, fc.sequence_length, fc.raw_input_dim])
+      # self._appended_fields.append(input_0 + '_raw_proj_id')
+      # self._appended_fields.append(input_0 + '_raw_proj_val')
+
+  def _preprocess(self, field_dict):
+    """Preprocess the feature columns.
+
+    Preprocesses some feature columns, such as TagFeature or LookupFeature.
+    It is expected to handle both batch inputs and single inputs,
+    and can be customized in subclasses.
+
+    Args:
+      field_dict: string to tensor, tensors are dense,
+        could be of shape [batch_size], [batch_size, None], or of shape []
+
+    Returns:
+      output_dict: some of the tensors are transformed into sparse tensors,
+        such as input tensors of tag features and lookup features
+    """
+    parsed_dict = {}
+
+    if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT:
+      if self._mode != tf.estimator.ModeKeys.TRAIN:
+        self._sampler.set_eval_num_sample()
+      sampler_type = self._data_config.WhichOneof('sampler')
+      sampler_config = getattr(self._data_config, sampler_type)
+      item_ids = field_dict[sampler_config.item_id_field]
+      if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']:
+        sampled = self._sampler.get(item_ids)
+      elif sampler_type == 'negative_sampler_v2':
+        user_ids = field_dict[sampler_config.user_id_field]
+        sampled = self._sampler.get(user_ids, item_ids)
+      elif sampler_type.startswith('hard_negative_sampler'):
+        user_ids = field_dict[sampler_config.user_id_field]
+        sampled = self._sampler.get(user_ids, item_ids)
+      else:
+        raise ValueError('Unknown sampler %s' % sampler_type)
+      for k, v in sampled.items():
+        if k in field_dict:
+          field_dict[k] = tf.concat([field_dict[k], v], axis=0)
+        else:
+          print('appended fields: %s' % k)
+          parsed_dict[k] = v
+          self._appended_fields.append(k)
+
+    for fc in self._feature_configs:
+      feature_name = fc.feature_name
+      feature_type = fc.feature_type
+      if feature_type == fc.TagFeature:
+        self._parse_tag_feature(fc, parsed_dict, field_dict)
+      elif feature_type == fc.LookupFeature:
+        assert feature_name is not None and feature_name != ''
+        assert len(fc.input_names) == 2
+        parsed_dict[feature_name] = self._lookup_preprocess(fc, field_dict)
+      elif feature_type == fc.SequenceFeature:
+        self._parse_seq_feature(fc, parsed_dict, field_dict)
+      elif feature_type == fc.RawFeature:
+        self._parse_raw_feature(fc, parsed_dict, field_dict)
+      elif feature_type == fc.IdFeature:
+        self._parse_id_feature(fc, parsed_dict, field_dict)
+      elif feature_type == fc.ExprFeature:
+        self._parse_expr_feature(fc, parsed_dict, field_dict)
+      elif feature_type == fc.ComboFeature:
+        self._parse_combo_feature(fc, parsed_dict, field_dict)
+      else:
+        feature_name = fc.feature_name if fc.HasField(
+            'feature_name') else fc.input_names[0]
+        for input_id, input_name in enumerate(fc.input_names):
+          if input_id > 0:
+            key = feature_name + '_' + str(input_id)
+          else:
+            key = feature_name
+          parsed_dict[key] = field_dict[input_name]
+
+    label_dict = {}
+    for input_id, input_name in enumerate(self._label_fields):
+      if input_name not in field_dict:
+        continue
+      if input_name in self._label_udf_map:
+        udf, udf_class, dtype = self._label_udf_map[input_name]
+        if dtype is None or dtype == '':
+          logging.info('apply tensorflow function transform: %s' % udf_class)
+          field_dict[input_name] = udf(field_dict[input_name])
+        else:
+          assert dtype is not None, 'must set user_define_fn_res_type'
+          logging.info('apply py_func transform: %s' % udf_class)
+          field_dict[input_name] = tf.py_func(
+              udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
+          field_dict[input_name].set_shape(tf.TensorShape([None]))
+
+      if field_dict[input_name].dtype == tf.string:
+        if self._label_dim[input_id] > 1:
+          logging.info('will split labels[%d]=%s' % (input_id, input_name))
+          check_list = [
+              tf.py_func(
+                  check_split, [
+                      field_dict[input_name], self._label_sep[input_id],
+                      self._label_dim[input_id], input_name
+                  ],
+                  Tout=tf.bool)
+          ] if self._check_mode else []
+          with tf.control_dependencies(check_list):
+            label_dict[input_name] = tf.string_split(
+                field_dict[input_name], self._label_sep[input_id]).values
+            label_dict[input_name] = tf.reshape(label_dict[input_name],
+                                                [-1, self._label_dim[input_id]])
+        else:
+          label_dict[input_name] = field_dict[input_name]
+        check_list = [
+            tf.py_func(
+                check_string_to_number, [label_dict[input_name], input_name],
+                Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          label_dict[input_name] = tf.string_to_number(
+              label_dict[input_name], tf.float32, name=input_name)
+      else:
+        assert field_dict[input_name].dtype in [
+            tf.float32, tf.double, tf.int32, tf.int64
+        ], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
+        label_dict[input_name] = field_dict[input_name]
+
+    if self._mode != tf.estimator.ModeKeys.PREDICT:
+      for func_config in self._data_config.extra_label_func:
+        lbl_name = func_config.label_name
+        func_name = func_config.label_func
+        logging.info('generating new label `%s` by transform: %s' %
+                     (lbl_name, func_name))
+        lbl_fn = load_by_path(func_name)
+        label_dict[lbl_name] = lbl_fn(label_dict)
+
+    if self._data_config.HasField('sample_weight'):
+      parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
+          self._data_config.sample_weight]
+
+    if Input.DATA_OFFSET in field_dict:
+      parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+    return {'feature': parsed_dict, 'label': label_dict}
+
+  def _lookup_preprocess(self, fc, field_dict):
+    """Preprocess function for lookup features.
+
+    Args:
+      fc: FeatureConfig
+      field_dict: input dict
+
+    Returns:
+      output_dict: adds {feature_name: SparseTensor}, with
+        other items the same as in field_dict
+    """
+    max_sel_num = fc.lookup_max_sel_elem_num
+
+    def _lookup(args, pad=True):
+      one_key, one_map = args[0], args[1]
+      if len(one_map.get_shape()) == 0:
+        one_map = tf.expand_dims(one_map, axis=0)
+      kv_map = tf.string_split(one_map, fc.separator).values
+      kvs = tf.string_split(kv_map, fc.kv_separator)
+      kvs = tf.reshape(kvs.values, [-1, 2], name='kv_split_reshape')
+      keys, vals = kvs[:, 0], kvs[:, 1]
+      sel_ids = tf.where(tf.equal(keys, one_key))
+      sel_ids = tf.squeeze(sel_ids, axis=1)
+      sel_vals = tf.gather(vals, sel_ids)
+      if not pad:
+        return sel_vals
+      n = tf.shape(sel_vals)[0]
+      sel_vals = tf.pad(sel_vals, [[0, max_sel_num - n]])
+      len_msk = tf.sequence_mask(n, max_sel_num)
+      indices = tf.range(max_sel_num, dtype=tf.int64)
+      indices = indices * tf.to_int64(indices < tf.to_int64(n))
+      return sel_vals, len_msk, indices
+
+    key_field, map_field = fc.input_names[0], fc.input_names[1]
+    key_fields, map_fields = field_dict[key_field], field_dict[map_field]
+    if len(key_fields.get_shape()) == 0:
+      vals = _lookup((key_fields, map_fields), False)
+      n = tf.shape(vals)[0]
+      n = tf.to_int64(n)
+      indices_0 = tf.zeros([n], dtype=tf.int64)
+      indices_1 = tf.range(0, n, dtype=tf.int64)
+      indices = [
+          tf.expand_dims(indices_0, axis=1),
+          tf.expand_dims(indices_1, axis=1)
+      ]
+      indices = tf.concat(indices, axis=1)
+      return tf.sparse.SparseTensor(indices, vals, [1, n])
+
+    vals, masks, indices = tf.map_fn(
+        _lookup, [key_fields, map_fields], dtype=(tf.string, tf.bool, tf.int64))
+    batch_size = tf.to_int64(tf.shape(vals)[0])
+    vals = tf.boolean_mask(vals, masks)
+    indices_1 = tf.boolean_mask(indices, masks)
+    indices_0 = tf.range(0, batch_size, dtype=tf.int64)
+    indices_0 = tf.expand_dims(indices_0, axis=1)
+    indices_0 = indices_0 + tf.zeros([1, max_sel_num], dtype=tf.int64)
+    indices_0 = tf.boolean_mask(indices_0, masks)
+    indices = tf.concat(
+        [tf.expand_dims(indices_0, axis=1),
+         tf.expand_dims(indices_1, axis=1)],
+        axis=1)
+    shapes = tf.stack([batch_size, tf.reduce_max(indices_1) + 1])
+    return tf.sparse.SparseTensor(indices, vals, shapes)
+
+  @abstractmethod
+  def _build(self, mode, params):
+    raise NotImplementedError
+
+  def _pre_build(self, mode, params):
+    pass
+
+  def restore(self, checkpoint_path):
+    pass
+
+  def stop(self):
+    pass
+
+  def _safe_shard(self, dataset):
+    if self._data_config.chief_redundant:
+      return dataset.shard(
+          max(self._task_num - 1, 1), max(self._task_index - 1, 0))
+    else:
+      return dataset.shard(self._task_num, self._task_index)
+
+  def create_input(self, export_config=None):
+
+    def _input_fn(mode=None, params=None, config=None):
+      """Build input_fn for estimator.
+
+      Args:
+        mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
+        params: `dict` of hyper parameters, from Estimator
+        config: tf.estimator.RunConfig instance
+
+      Returns:
+        if mode is not None, return:
+          features: inputs to the model.
+          labels: groundtruth
+        else, return:
+          tf.estimator.export.ServingInputReceiver instance
+      """
+      self._pre_build(mode, params)
+      if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
+                  tf.estimator.ModeKeys.PREDICT):
+        # build dataset from self._input_path
+        self._mode = mode
+        dataset = self._build(mode, params)
+        return dataset
+      elif mode is None:  # serving_input_receiver_fn for exporting SavedModel
+        place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
+        place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+        if export_config.multi_placeholder:
+          with conditional(place_on_cpu, ops.device('/CPU:0')):
+            inputs, features = self.create_multi_placeholders(export_config)
+          return tf.estimator.export.ServingInputReceiver(features, inputs)
+        else:
+          with conditional(place_on_cpu, ops.device('/CPU:0')):
+            inputs, features = self.create_placeholders(export_config)
+          print('built feature placeholders. features: {}'.format(
+              features.keys()))
+          return tf.estimator.export.ServingInputReceiver(features, inputs)
+
+    _input_fn.input_creator = self
+    return _input_fn
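
For orientation, the Input base class above leaves only _build abstract: a subclass builds a tf.data.Dataset, while the base class contributes feature preprocessing (_preprocess), serving placeholders, and the estimator input_fn glue via create_input. The sketch below shows how the pieces might fit together. It is a minimal illustration under stated assumptions, not code shipped in this package: ToyCsvInput, _parse_line, my_model_fn, and the config objects are hypothetical placeholders, and the real CSV reader in easy_rec/python/input/csv_input.py is considerably more involved.

import tensorflow as tf

from easy_rec.python.input.input import Input

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class ToyCsvInput(Input):
  """Hypothetical minimal Input subclass; only _build is mandatory."""

  def _build(self, mode, params):
    # one text line per example, optionally sharded across workers
    dataset = tf.data.TextLineDataset(self._input_path.split(','))
    if mode == tf.estimator.ModeKeys.TRAIN:
      dataset = self._safe_shard(dataset)
      dataset = dataset.shuffle(buffer_size=1024)
      dataset = dataset.repeat(self.num_epochs)
    dataset = dataset.batch(self._batch_size)
    # _parse_line (not shown) would decode the CSV columns into the
    # field_dict consumed by Input._preprocess
    dataset = dataset.map(self._parse_line)
    dataset = dataset.map(self._preprocess)
    dataset = dataset.map(
        lambda x: (self._get_features(x), self._get_labels(x)))
    return dataset.prefetch(self._prefetch_size)


# data_config and feature_configs would come from a parsed pipeline config
train_input = ToyCsvInput(data_config, feature_configs, 'train.csv')
estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='/tmp/toy')
estimator.train(train_input.create_input())

When create_input is called with mode=None (no estimator mode), it instead returns a tf.estimator.export.ServingInputReceiver built from the placeholder methods above, which is how the exported SavedModel gets its serving signature.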