easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,664 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import collections
4
+ import logging
5
+ import sys
6
+
7
+ import tensorflow as tf
8
+ from tensorflow.python.ops import partitioned_variables
9
+ from tensorflow.python.platform import gfile
10
+
11
+ from easy_rec.python.builders import hyperparams_builder
12
+ from easy_rec.python.compat.feature_column import sequence_feature_column
13
+ from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
14
+ from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
15
+ from easy_rec.python.utils.proto_util import copy_obj
16
+
17
+ from easy_rec.python.compat.feature_column import feature_column_v2 as feature_column # NOQA
18
+
19
# 2**63 - 1; returned by `_get_hash_bucket_size` in place of the configured
# hash_bucket_size when ev_params are in effect.
MAX_HASH_BUCKET_SIZE = (1 << 63) - 1
20
+
21
+
22
class FeatureKeyError(KeyError):
    """KeyError raised when a feature name is missing from the wide/deep dict.

    `is_wide` / `is_deep` raise it, and `FeatureColumnParser.__init__` catches
    it to skip the offending feature config.
    """

    def __init__(self, feature_name):
        KeyError.__init__(self, feature_name)
26
+
27
+
28
class SharedEmbedding(object):
    """Placeholder for a column that takes part in a shared embedding.

    `FeatureColumnParser.__init__` replaces every instance found in its column
    dicts with the result of `_get_shared_embedding_column`.

    Attributes are plain copies of the constructor arguments:
      embedding_name: name of the shared embedding.
      index: presumably the position of the column inside the list of columns
        sharing `embedding_name` -- TODO confirm against the code that creates
        these placeholders.
      sequence_combiner: optional, defaults to None.
    """

    def __init__(self, embedding_name, index, sequence_combiner=None):
        self.sequence_combiner = sequence_combiner
        self.index = index
        self.embedding_name = embedding_name
34
+
35
+
36
# Immutable record of the five ev_params settings carried around by the
# feature column parser.
EVParams = collections.namedtuple(
    'EVParams',
    ('filter_freq', 'steps_to_live', 'use_cache', 'init_capacity',
     'max_capacity'))
39
+
40
+
41
+ class FeatureColumnParser(object):
42
+ """Parse and generate feature columns."""
43
+
44
+ def __init__(self,
45
+ feature_configs,
46
+ wide_deep_dict={},
47
+ wide_output_dim=-1,
48
+ ev_params=None):
49
+ """Initializes a `FeatureColumnParser`.
50
+
51
+ Args:
52
+ feature_configs: collections of
53
+ easy_rec.python.protos.feature_config_pb2.FeatureConfig
54
+ or easy_rec.python.protos.feature_config_pb2.FeatureConfigV2.features
55
+ wide_deep_dict: dict of {feature_name:WideOrDeep}, passed by
56
+ easy_rec.python.layers.input_layer.InputLayer, it is defined in
57
+ easy_rec.python.protos.easy_rec_model_pb2.EasyRecModel.feature_groups
58
+ wide_output_dim: output dimension for wide columns
59
+ ev_params: params used by EmbeddingVariable, which is provided by pai-tf
60
+ """
61
+ self._feature_configs = feature_configs
62
+ self._wide_output_dim = wide_output_dim
63
+ self._wide_deep_dict = wide_deep_dict
64
+ self._deep_columns = {}
65
+ self._wide_columns = {}
66
+ self._sequence_columns = {}
67
+
68
+ self._share_embed_names = {}
69
+ self._share_embed_infos = {}
70
+
71
+ self._vocab_size = {}
72
+
73
+ self._global_ev_params = None
74
+ if ev_params is not None:
75
+ self._global_ev_params = self._build_ev_params(ev_params)
76
+
77
+ def _cmp_embed_config(a, b):
78
+ return a.embedding_dim == b.embedding_dim and a.combiner == b.combiner and\
79
+ a.initializer == b.initializer and a.max_partitions == b.max_partitions and\
80
+ a.embedding_name == b.embedding_name
81
+
82
+ for config in self._feature_configs:
83
+ if not config.HasField('embedding_name'):
84
+ continue
85
+ embed_name = config.embedding_name
86
+
87
+ if embed_name in self._share_embed_names:
88
+ assert _cmp_embed_config(config, self._share_embed_infos[embed_name]),\
89
+ 'shared embed info of [%s] is not matched [%s] vs [%s]' % (
90
+ embed_name, config, self._share_embed_infos[embed_name])
91
+ self._share_embed_names[embed_name] += 1
92
+ if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
93
+ self._share_embed_infos[embed_name] = copy_obj(config)
94
+ else:
95
+ self._share_embed_names[embed_name] = 1
96
+ self._share_embed_infos[embed_name] = copy_obj(config)
97
+
98
+ # remove not shared embedding names
99
+ not_shared = [
100
+ x for x in self._share_embed_names if self._share_embed_names[x] == 1
101
+ ]
102
+ for embed_name in not_shared:
103
+ del self._share_embed_names[embed_name]
104
+ del self._share_embed_infos[embed_name]
105
+
106
+ logging.info('shared embeddings[num=%d]' % len(self._share_embed_names))
107
+ # for embed_name in self._share_embed_names:
108
+ # logging.info('\t%s: share_num[%d], share_info[%s]' %
109
+ # (embed_name, self._share_embed_names[embed_name],
110
+ # self._share_embed_infos[embed_name]))
111
+ self._deep_share_embed_columns = {
112
+ embed_name: [] for embed_name in self._share_embed_names
113
+ }
114
+ self._wide_share_embed_columns = {
115
+ embed_name: [] for embed_name in self._share_embed_names
116
+ }
117
+
118
+ self._feature_vocab_size = {}
119
+ for config in self._feature_configs:
120
+ assert isinstance(config, FeatureConfig)
121
+ try:
122
+ if config.feature_type == config.IdFeature:
123
+ self.parse_id_feature(config)
124
+ elif config.feature_type == config.TagFeature:
125
+ self.parse_tag_feature(config)
126
+ elif config.feature_type == config.RawFeature:
127
+ self.parse_raw_feature(config)
128
+ elif config.feature_type == config.ComboFeature:
129
+ self.parse_combo_feature(config)
130
+ elif config.feature_type == config.LookupFeature:
131
+ self.parse_lookup_feature(config)
132
+ elif config.feature_type == config.SequenceFeature:
133
+ self.parse_sequence_feature(config)
134
+ elif config.feature_type == config.ExprFeature:
135
+ self.parse_expr_feature(config)
136
+ elif config.feature_type != config.PassThroughFeature:
137
+ assert False, 'invalid feature type: %s' % config.feature_type
138
+ except FeatureKeyError:
139
+ pass
140
+
141
+ for embed_name in self._share_embed_names:
142
+ initializer = None
143
+ if self._share_embed_infos[embed_name].HasField('initializer'):
144
+ initializer = hyperparams_builder.build_initializer(
145
+ self._share_embed_infos[embed_name].initializer)
146
+
147
+ partitioner = self._build_partitioner(self._share_embed_infos[embed_name])
148
+
149
+ if self._share_embed_infos[embed_name].HasField('ev_params'):
150
+ ev_params = self._build_ev_params(
151
+ self._share_embed_infos[embed_name].ev_params)
152
+ else:
153
+ ev_params = self._global_ev_params
154
+
155
+ # for handling share embedding columns
156
+ if len(self._deep_share_embed_columns[embed_name]) > 0:
157
+ share_embed_fcs = feature_column.shared_embedding_columns(
158
+ self._deep_share_embed_columns[embed_name],
159
+ self._share_embed_infos[embed_name].embedding_dim,
160
+ initializer=initializer,
161
+ shared_embedding_collection_name=embed_name,
162
+ combiner=self._share_embed_infos[embed_name].combiner,
163
+ partitioner=partitioner,
164
+ ev_params=ev_params)
165
+ config = self._share_embed_infos[embed_name]
166
+ max_seq_len = config.max_seq_len if config.HasField(
167
+ 'max_seq_len') else -1
168
+ for fc in share_embed_fcs:
169
+ fc.max_seq_length = max_seq_len
170
+ self._deep_share_embed_columns[embed_name] = share_embed_fcs
171
+
172
+ # for handling wide share embedding columns
173
+ if len(self._wide_share_embed_columns[embed_name]) > 0:
174
+ share_embed_fcs = feature_column.shared_embedding_columns(
175
+ self._wide_share_embed_columns[embed_name],
176
+ self._wide_output_dim,
177
+ initializer=initializer,
178
+ shared_embedding_collection_name=embed_name + '_wide',
179
+ combiner='sum',
180
+ partitioner=partitioner,
181
+ ev_params=ev_params)
182
+ config = self._share_embed_infos[embed_name]
183
+ max_seq_len = config.max_seq_len if config.HasField(
184
+ 'max_seq_len') else -1
185
+ for fc in share_embed_fcs:
186
+ fc.max_seq_length = max_seq_len
187
+ self._wide_share_embed_columns[embed_name] = share_embed_fcs
188
+
189
+ for fc_name in self._deep_columns:
190
+ fc = self._deep_columns[fc_name]
191
+ if isinstance(fc, SharedEmbedding):
192
+ self._deep_columns[fc_name] = self._get_shared_embedding_column(fc)
193
+
194
+ for fc_name in self._wide_columns:
195
+ fc = self._wide_columns[fc_name]
196
+ if isinstance(fc, SharedEmbedding):
197
+ self._wide_columns[fc_name] = self._get_shared_embedding_column(
198
+ fc, deep=False)
199
+
200
+ for fc_name in self._sequence_columns:
201
+ fc = self._sequence_columns[fc_name]
202
+ if isinstance(fc, SharedEmbedding):
203
+ self._sequence_columns[fc_name] = self._get_shared_embedding_column(fc)
204
+
205
+ @property
206
+ def wide_columns(self):
207
+ return self._wide_columns
208
+
209
+ @property
210
+ def deep_columns(self):
211
+ return self._deep_columns
212
+
213
+ @property
214
+ def sequence_columns(self):
215
+ return self._sequence_columns
216
+
217
def is_wide(self, config):
    """Tell whether the feature described by `config` goes to the wide part.

    The feature is identified by `config.feature_name` when that field is set,
    otherwise by the first entry of `config.input_names`.

    Args:
      config: feature config exposing HasField, feature_name and input_names.

    Returns:
      True if the feature is marked WIDE or WIDE_AND_DEEP.

    Raises:
      FeatureKeyError: the feature name is not a key of the wide/deep dict.
    """
    name = (config.feature_name
            if config.HasField('feature_name') else config.input_names[0])
    if name not in self._wide_deep_dict:
        raise FeatureKeyError(name)
    placement = self._wide_deep_dict[name]
    return placement in (WideOrDeep.WIDE, WideOrDeep.WIDE_AND_DEEP)
227
+
228
def is_deep(self, config):
    """Tell whether the feature described by `config` goes to the deep part.

    The feature is identified by `config.feature_name` when that field is set,
    otherwise by the first entry of `config.input_names`.

    Args:
      config: feature config exposing HasField, feature_name and input_names.

    Returns:
      True if the feature is marked DEEP or WIDE_AND_DEEP.

    Raises:
      FeatureKeyError: the feature name is not a key of the wide/deep dict.
    """
    name = (config.feature_name
            if config.HasField('feature_name') else config.input_names[0])
    if name not in self._wide_deep_dict:
        raise FeatureKeyError(name)
    placement = self._wide_deep_dict[name]
    return placement in (WideOrDeep.DEEP, WideOrDeep.WIDE_AND_DEEP)
239
+
240
+ def get_feature_vocab_size(self, feature):
241
+ return self._feature_vocab_size.get(feature, 1)
242
+
243
+ def _get_vocab_size(self, vocab_path):
244
+ if vocab_path in self._vocab_size:
245
+ return self._vocab_size[vocab_path]
246
+ with gfile.GFile(vocab_path, 'r') as fin:
247
+ vocabulary_size = sum(1 for _ in fin)
248
+ self._vocab_size[vocab_path] = vocabulary_size
249
+ return vocabulary_size
250
+
251
+ def _get_hash_bucket_size(self, config):
252
+ if not config.HasField('hash_bucket_size'):
253
+ return -1
254
+ if self._global_ev_params is not None or config.HasField('ev_params'):
255
+ return MAX_HASH_BUCKET_SIZE
256
+ else:
257
+ return config.hash_bucket_size
258
+
259
def parse_id_feature(self, config):
    """Generate id feature columns.

    The categorical column is chosen in this order: hash bucket (when
    `_get_hash_bucket_size` is positive), vocab_list, vocab_file, and finally
    an identity column. Per the original documentation, the first three accept
    string input tensors and the identity column accepts integer input.

    The column is then registered as wide and/or deep according to `is_wide`
    and `is_deep`.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      FeatureKeyError: propagated from is_wide / is_deep when the feature is
        not present in the wide/deep dict.
    """
    if config.HasField('feature_name'):
        feature_name = config.feature_name
    else:
        feature_name = config.input_names[0]

    bucket_size = self._get_hash_bucket_size(config)
    if bucket_size > 0:
        id_fc = feature_column.categorical_column_with_hash_bucket(
            feature_name,
            hash_bucket_size=bucket_size,
            feature_name=feature_name)
    elif config.vocab_list:
        id_fc = feature_column.categorical_column_with_vocabulary_list(
            feature_name,
            default_value=0,
            vocabulary_list=config.vocab_list,
            feature_name=feature_name)
    elif config.vocab_file:
        id_fc = feature_column.categorical_column_with_vocabulary_file(
            feature_name,
            default_value=0,
            vocabulary_file=config.vocab_file,
            vocabulary_size=self._get_vocab_size(config.vocab_file),
            feature_name=feature_name)
    else:
        # With ev params in effect the configured num_buckets is replaced by
        # sys.maxsize.
        if self._global_ev_params or config.HasField('ev_params'):
            bucket_count = sys.maxsize
        else:
            bucket_count = config.num_buckets
        id_fc = feature_column.categorical_column_with_identity(
            feature_name,
            bucket_count,
            default_value=0,
            feature_name=feature_name)

    if self.is_wide(config):
        self._add_wide_embedding_column(id_fc, config)
    if self.is_deep(config):
        self._add_deep_embedding_column(id_fc, config)
300
+
301
def parse_tag_feature(self, config):
    """Generate tag feature columns.

    The categorical column is chosen in this order: hash bucket (when
    `_get_hash_bucket_size` is positive), vocab_list, vocab_file, and finally
    an identity column based on num_buckets. Per the original documentation,
    the hash-bucket form takes a SparseTensor of string and the num_buckets
    form a SparseTensor of integer; tag feature preprocessing is done in
    easy_rec/python/input/input.py: Input._preprocess.

    When the config has more than one input name, or sets `kv_separator`, the
    column is wrapped into a weighted categorical column whose weights are
    read from the key `<feature_name>_w`.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      FeatureKeyError: propagated from is_wide / is_deep when the feature is
        not present in the wide/deep dict.
    """
    if config.HasField('feature_name'):
        feature_name = config.feature_name
    else:
        feature_name = config.input_names[0]

    bucket_size = self._get_hash_bucket_size(config)
    if bucket_size > 0:
        column = feature_column.categorical_column_with_hash_bucket(
            feature_name,
            bucket_size,
            dtype=tf.string,
            feature_name=feature_name)
    elif config.vocab_list:
        column = feature_column.categorical_column_with_vocabulary_list(
            feature_name,
            default_value=0,
            vocabulary_list=config.vocab_list,
            feature_name=feature_name)
    elif config.vocab_file:
        column = feature_column.categorical_column_with_vocabulary_file(
            feature_name,
            default_value=0,
            vocabulary_file=config.vocab_file,
            vocabulary_size=self._get_vocab_size(config.vocab_file),
            feature_name=feature_name)
    else:
        # With ev params in effect the configured num_buckets is replaced by
        # sys.maxsize.
        if self._global_ev_params or config.HasField('ev_params'):
            bucket_count = sys.maxsize
        else:
            bucket_count = config.num_buckets
        column = feature_column.categorical_column_with_identity(
            feature_name,
            bucket_count,
            default_value=0,
            feature_name=feature_name)

    # Both conditions lead to the same weighted wrapper.
    if len(config.input_names) > 1 or config.HasField('kv_separator'):
        column = feature_column.weighted_categorical_column(
            column, weight_feature_key=feature_name + '_w', dtype=tf.float32)

    if self.is_wide(config):
        self._add_wide_embedding_column(column, config)
    if self.is_deep(config):
        self._add_deep_embedding_column(column, config)
350
+
351
def parse_raw_feature(self, config):
    """Build the feature column(s) for a raw (numeric) feature.

    A numeric column of shape ``(config.raw_input_dim,)`` is always created.
    It is bucketized when ``config.boundaries`` is set, or when
    ``config.num_buckets > 1`` and ``config.max_val > config.min_val``; the
    bucketized column is then registered as a wide and/or deep embedding
    column. Otherwise a weighted identity column is built on the
    ``<feature_name>_raw_proj_id`` / ``<feature_name>_raw_proj_val`` keys;
    on the deep side the plain numeric column is stored instead when
    ``config.embedding_dim`` is not positive.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
    """
    if config.HasField('feature_name'):
        feature_name = config.feature_name
    else:
        feature_name = config.input_names[0]

    numeric_fc = feature_column.numeric_column(
        key=feature_name,
        shape=(config.raw_input_dim,),
        feature_name=feature_name)

    bounds = None
    if config.boundaries:
        bounds = sorted(config.boundaries)
    elif config.num_buckets > 1 and config.max_val > config.min_val:
        # the feature values are already normalized into [0, 1]
        bucket_total = float(config.num_buckets)
        bounds = []
        for bucket_idx in range(config.num_buckets):
            bounds.append(bucket_idx / bucket_total)
        logging.info('discrete %s into %d buckets' %
                     (feature_name, config.num_buckets))

    if bounds:
        # Discretized path: the bucketized column feeds both towers.
        try:
            bucket_fc = feature_column.bucketized_column(numeric_fc, bounds)
        except Exception as e:
            logging.error('bucketized_column [%s] with bounds %s error' %
                          (numeric_fc.name, str(bounds)))
            raise e
        if self.is_wide(config):
            self._add_wide_embedding_column(bucket_fc, config)
        if self.is_deep(config):
            self._add_deep_embedding_column(bucket_fc, config)
        return

    # Continuous path: project the raw values through a weighted id column.
    proj_id_fc = feature_column.categorical_column_with_identity(
        feature_name + '_raw_proj_id',
        config.raw_input_dim,
        default_value=0,
        feature_name=feature_name)
    weighted_fc = feature_column.weighted_categorical_column(
        proj_id_fc,
        weight_feature_key=feature_name + '_raw_proj_val',
        dtype=tf.float32)
    if self.is_wide(config):
        self._add_wide_embedding_column(weighted_fc, config)
    if not self.is_deep(config):
        return
    if config.embedding_dim > 0:
        self._add_deep_embedding_column(weighted_fc, config)
    else:
        # No embedding requested: keep the raw numeric column as-is.
        self._deep_columns[feature_name] = numeric_fc
406
+
407
def parse_expr_feature(self, config):
    """Build the feature column for an expression feature.

    Creates a numeric column of shape ``(1,)`` keyed by the feature name.
    The column is registered as a wide embedding column when
    ``self.is_wide(config)`` holds, and stored unchanged in
    ``self._deep_columns`` when ``self.is_deep(config)`` holds. No
    bucketization is performed here.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
    """
    if config.HasField('feature_name'):
        name = config.feature_name
    else:
        name = config.input_names[0]

    expr_fc = feature_column.numeric_column(
        name, shape=(1,), feature_name=name)

    if self.is_wide(config):
        self._add_wide_embedding_column(expr_fc, config)
    if self.is_deep(config):
        self._deep_columns[name] = expr_fc
423
+
424
def parse_combo_feature(self, config):
    """Build the feature column for a combo (crossed) feature.

    When ``config.combo_join_sep`` is empty, a crossed column is built over
    the keys ``feature_name``, ``feature_name_1``, ``feature_name_2``, ...
    (one key per entry of ``config.input_names``). Otherwise a single hashed
    categorical column keyed by ``feature_name`` is used. The result is
    registered as a wide and/or deep embedding column.

    ``feature_name`` is ``None`` when the config does not set it; there is
    no fallback to ``input_names[0]`` in this method.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig
    """
    feature_name = None
    if config.HasField('feature_name'):
        feature_name = config.feature_name
    assert len(config.input_names) >= 2

    if len(config.combo_join_sep) != 0:
        combo_fc = feature_column.categorical_column_with_hash_bucket(
            feature_name,
            hash_bucket_size=self._get_hash_bucket_size(config),
            feature_name=feature_name)
    else:
        # First key is the bare feature name, later ones get an index suffix.
        cross_keys = [
            feature_name if pos == 0 else feature_name + '_' + str(pos)
            for pos in range(len(config.input_names))
        ]
        combo_fc = feature_column.crossed_column(
            cross_keys,
            self._get_hash_bucket_size(config),
            hash_key=None,
            feature_name=feature_name)

    if self.is_wide(config):
        self._add_wide_embedding_column(combo_fc, config)
    if self.is_deep(config):
        self._add_deep_embedding_column(combo_fc, config)
456
+
457
def parse_lookup_feature(self, config):
    """Build the feature column for a lookup feature.

    A string-typed hashed categorical column is created and registered as a
    wide and/or deep embedding column.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      AssertionError: if ``config.hash_bucket_size`` is not set.
    """
    if config.HasField('feature_name'):
        name = config.feature_name
    else:
        name = config.input_names[0]
    assert config.HasField('hash_bucket_size')

    bucket_size = self._get_hash_bucket_size(config)
    lookup_fc = feature_column.categorical_column_with_hash_bucket(
        name, bucket_size, dtype=tf.string, feature_name=name)

    if self.is_wide(config):
        self._add_wide_embedding_column(lookup_fc, config)
    if self.is_deep(config):
        self._add_deep_embedding_column(lookup_fc, config)
477
+
478
def parse_sequence_feature(self, config):
    """Generate sequence feature columns.

    For the IdFeature sub-type a sequence categorical column is created,
    choosing (in this order) hash bucket, vocab_list, vocab_file or identity.
    For the RawFeature sub-type a sequence numeric column of shape ``(1,)``
    is created and then either bucketized (``boundaries`` or
    ``num_buckets``), replaced by a weighted identity column when
    ``embedding_dim > 0``, or wrapped by
    ``sequence_numeric_column_with_raw_column``.

    The resulting column goes through ``_add_deep_embedding_column`` when
    ``config.embedding_dim > 0``; otherwise it is stored directly in
    ``self._sequence_columns``.

    Args:
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      AssertionError: if ``sub_feature_type`` is neither IdFeature nor
        RawFeature, or if ``hash_bucket_size > 0`` is combined with the
        RawFeature sub-type.
    """
    feature_name = config.feature_name if config.HasField('feature_name') \
        else config.input_names[0]
    sub_feature_type = config.sub_feature_type
    assert sub_feature_type in [config.IdFeature, config.RawFeature], \
        'Current sub_feature_type only support IdFeature and RawFeature.'
    if sub_feature_type == config.IdFeature:
        if config.HasField('hash_bucket_size'):
            hash_bucket_size = self._get_hash_bucket_size(config)
            fc = sequence_feature_column.sequence_categorical_column_with_hash_bucket(
                feature_name,
                hash_bucket_size,
                dtype=tf.string,
                feature_name=feature_name)
        elif config.vocab_list:
            fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_list(
                feature_name,
                default_value=0,
                vocabulary_list=config.vocab_list,
                feature_name=feature_name)
        elif config.vocab_file:
            fc = sequence_feature_column.sequence_categorical_column_with_vocabulary_file(
                feature_name,
                default_value=0,
                vocabulary_file=config.vocab_file,
                vocabulary_size=self._get_vocab_size(config.vocab_file),
                feature_name=feature_name)
        else:
            # With embedding variables (global or per-feature ev_params) the
            # id space is left unbounded by using sys.maxsize buckets.
            use_ev = self._global_ev_params or config.HasField('ev_params')
            num_buckets = sys.maxsize if use_ev else config.num_buckets
            fc = sequence_feature_column.sequence_categorical_column_with_identity(
                feature_name,
                num_buckets,
                default_value=0,
                feature_name=feature_name)
    else:  # raw feature
        bounds = None
        fc = sequence_feature_column.sequence_numeric_column(
            feature_name, shape=(1,), feature_name=feature_name)
        if config.hash_bucket_size > 0:
            # This branch is only reached for RawFeature, so the assertion
            # below always fires: hash_bucket_size is rejected for raw
            # sub-features. The local hash_bucket_size is not used afterwards.
            hash_bucket_size = self._get_hash_bucket_size(config)
            assert sub_feature_type == config.IdFeature, \
                'You should set sub_feature_type to IdFeature to use hash_bucket_size.'
        elif config.boundaries:
            bounds = list(config.boundaries)
            bounds.sort()
        elif config.num_buckets > 1 and config.max_val > config.min_val:
            # the feature values are already normalized into [0, 1]
            bounds = [
                x / float(config.num_buckets) for x in range(0, config.num_buckets)
            ]
            logging.info('sequence feature discrete %s into %d buckets' %
                         (feature_name, config.num_buckets))
        if bounds:
            try:
                fc = sequence_feature_column.sequence_numeric_column_with_bucketized_column(
                    fc, bounds)
            except Exception as e:
                # Log the offending bounds before propagating the error.
                logging.error(
                    'sequence features bucketized_column [%s] with bounds %s error' %
                    (feature_name, str(bounds)))
                raise e
        elif config.hash_bucket_size <= 0:
            if config.embedding_dim > 0:
                # Continuous values with an embedding: weighted identity
                # column on the <feature_name>_raw_proj_id / _raw_proj_val keys.
                tmp_id_col = sequence_feature_column.sequence_categorical_column_with_identity(
                    feature_name + '_raw_proj_id',
                    config.raw_input_dim,
                    default_value=0,
                    feature_name=feature_name)
                wgt_fc = sequence_feature_column.sequence_weighted_categorical_column(
                    tmp_id_col,
                    weight_feature_key=feature_name + '_raw_proj_val',
                    dtype=tf.float32)
                fc = wgt_fc
            else:
                fc = sequence_feature_column.sequence_numeric_column_with_raw_column(
                    fc, config.sequence_length)

    # Register the column: embedded when an embedding size is configured,
    # otherwise stored as a plain sequence column.
    if config.embedding_dim > 0:
        self._add_deep_embedding_column(fc, config)
    else:
        self._sequence_columns[feature_name] = fc
565
+
566
def _build_partitioner(self, config):
    """Choose the variable partitioner for a feature's embedding.

    Args:
      config: feature config; ``max_partitions`` and ``ev_params`` are read.

    Returns:
      None when ``config.max_partitions`` is not greater than 1. Otherwise a
      fixed-size partitioner when embedding variables are in use (global
      ev params set or ``ev_params`` present on the config), else a
      min-max variable partitioner.
    """
    shard_num = config.max_partitions
    if shard_num <= 1:
        return None

    uses_ev = self._global_ev_params is not None or config.HasField('ev_params')
    if uses_ev:
        # pai embedding_variable should use fixed_size_partitioner
        return partitioned_variables.fixed_size_partitioner(num_shards=shard_num)
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=shard_num)
577
+
578
def _add_shared_embedding_column(self, embedding_name, fc, deep=True):
    """Queue ``fc`` under a shared embedding and return a handle to it.

    Args:
      embedding_name: name of the shared embedding group.
      fc: feature column appended to that group's list.
      deep: use the deep registry when True, the wide registry otherwise.

    Returns:
      A SharedEmbedding built from the embedding name, the position of
      ``fc`` in the group's list, and ``None`` as third argument.
    """
    if deep:
        registry = self._deep_share_embed_columns
    else:
        registry = self._wide_share_embed_columns
    group = registry[embedding_name]
    # The position is taken before appending, so it indexes the new entry.
    position = len(group)
    group.append(fc)
    return SharedEmbedding(embedding_name, position, None)
586
+
587
def _get_shared_embedding_column(self, fc_handle, deep=True):
    """Resolve a shared-embedding handle to its stored column.

    Args:
      fc_handle: object exposing ``embedding_name``, ``index`` and
        ``sequence_combiner``.
      deep: look in the deep registry when True, the wide registry otherwise.

    Returns:
      The stored column, with its ``sequence_combiner`` attribute overwritten
      by the handle's value.
    """
    group_name = fc_handle.embedding_name
    position = fc_handle.index
    if deep:
        registry = self._deep_share_embed_columns
    else:
        registry = self._wide_share_embed_columns
    shared_fc = registry[group_name][position]
    shared_fc.sequence_combiner = fc_handle.sequence_combiner
    return shared_fc
595
+
596
def _add_wide_embedding_column(self, fc, config):
    """Register the wide-side column for a feature.

    The wide column is modelled as an embedding of width
    ``self._wide_output_dim`` with a ``'sum'`` combiner; the original note
    states this is more efficient than an indicator column for sparse
    features. If ``config.embedding_name`` names a shared wide embedding,
    ``fc`` is queued there and a handle is stored instead.

    Args:
      fc: feature column to wrap.
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      AssertionError: if ``self._wide_output_dim`` is not positive.
    """
    if config.HasField('feature_name'):
        feature_name = config.feature_name
    else:
        feature_name = config.input_names[0]
    assert self._wide_output_dim > 0, 'wide_output_dim is not set'

    if config.embedding_name in self._wide_share_embed_columns:
        self._wide_columns[feature_name] = self._add_shared_embedding_column(
            config.embedding_name, fc, deep=False)
        return

    if config.HasField('initializer'):
        initializer = hyperparams_builder.build_initializer(config.initializer)
    else:
        initializer = None

    # Per-feature ev_params take precedence over the global setting.
    if config.HasField('ev_params'):
        ev_params = self._build_ev_params(config.ev_params)
    else:
        ev_params = self._global_ev_params

    self._wide_columns[feature_name] = feature_column.embedding_column(
        fc,
        self._wide_output_dim,
        combiner='sum',
        initializer=initializer,
        partitioner=self._build_partitioner(config),
        ev_params=ev_params)
624
+
625
def _add_deep_embedding_column(self, fc, config):
    """Generate deep feature columns.

    Records ``fc.num_buckets`` in ``self._feature_vocab_size`` and wraps
    ``fc`` either in a shared-embedding handle (when
    ``config.embedding_name`` is a key of ``self._deep_share_embed_columns``)
    or in a dedicated embedding column of width ``config.embedding_dim``.
    The result is stored in ``self._sequence_columns`` for SequenceFeature
    configs and in ``self._deep_columns`` otherwise.

    Args:
      fc: feature column to embed; its ``num_buckets`` attribute is read.
      config: instance of easy_rec.python.protos.feature_config_pb2.FeatureConfig

    Raises:
      AssertionError: if ``config.embedding_dim`` is not positive.
    """
    feature_name = config.feature_name if config.HasField('feature_name') \
        else config.input_names[0]
    assert config.embedding_dim > 0, 'embedding_dim is not set for %s' % feature_name
    self._feature_vocab_size[feature_name] = fc.num_buckets
    if config.embedding_name in self._deep_share_embed_columns:
        # Shared embedding: queue the column and keep only a handle to it.
        fc = self._add_shared_embedding_column(config.embedding_name, fc)
    else:
        initializer = None
        if config.HasField('initializer'):
            initializer = hyperparams_builder.build_initializer(config.initializer)
        # Per-feature ev_params take precedence over the global setting.
        if config.HasField('ev_params'):
            ev_params = self._build_ev_params(config.ev_params)
        else:
            ev_params = self._global_ev_params
        fc = feature_column.embedding_column(
            fc,
            config.embedding_dim,
            combiner=config.combiner,
            initializer=initializer,
            partitioner=self._build_partitioner(config),
            ev_params=ev_params)
        # -1 is used when max_seq_len is not configured.
        # NOTE(review): nesting was reconstructed from a flattened diff view;
        # this assignment is assumed to apply only to non-shared columns —
        # confirm against the upstream file.
        fc.max_seq_length = config.max_seq_len if config.HasField(
            'max_seq_len') else -1

    if config.feature_type != config.SequenceFeature:
        self._deep_columns[feature_name] = fc
    else:
        # Sequence features carry their combiner on the column itself.
        if config.HasField('sequence_combiner'):
            fc.sequence_combiner = config.sequence_combiner
        self._sequence_columns[feature_name] = fc
657
+
658
def _build_ev_params(self, ev_params):
    """Convert an ev_params config into an EVParams value.

    Args:
      ev_params: config object exposing ``filter_freq``, ``steps_to_live``,
        ``use_cache``, ``init_capacity`` and ``max_capacity``.

    Returns:
      EVParams built from those fields, in that order; a non-positive
      ``steps_to_live`` is passed as ``None``.
    """
    steps_to_live = ev_params.steps_to_live
    if steps_to_live <= 0:
        steps_to_live = None
    return EVParams(ev_params.filter_freq, steps_to_live, ev_params.use_cache,
                    ev_params.init_capacity, ev_params.max_capacity)