easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/model/easy_rec_model.py
@@ -0,0 +1,467 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+ import re
7
+ from abc import abstractmethod
8
+
9
+ import six
10
+ import tensorflow as tf
11
+ from tensorflow.python.framework import ops
12
+ from tensorflow.python.framework import tensor_shape
13
+ from tensorflow.python.ops import variables
14
+ from tensorflow.python.platform import gfile
15
+
16
+ from easy_rec.python.compat import regularizers
17
+ from easy_rec.python.layers import input_layer
18
+ from easy_rec.python.layers.backbone import Backbone
19
+ from easy_rec.python.utils import constant
20
+ from easy_rec.python.utils import estimator_utils
21
+ from easy_rec.python.utils import restore_filter
22
+ from easy_rec.python.utils.load_class import get_register_class_meta
23
+
24
# Optional distributed-embedding dependencies.  Every name is pre-bound to
# None so that a failed import can never leave `hvd` undefined.
hvd = None
dynamic_variable_ops = None
sok = None
try:
  import horovod.tensorflow as hvd
  from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
  from sparse_operation_kit import experiment as sok
except Exception:
  # a failure of any of the imports disables the sparse_operation_kit path
  dynamic_variable_ops = None
  sok = None

# Optional custom op library; `restore` calls its `load_kv_embed` op.
try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  logging.warning('load libload_embed.so failed: %s' % str(ex))
  load_embed_lib = None

# Compare the major version as an integer: a plain string comparison is
# lexicographic and would rank e.g. '10.0' below '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1

# registry of model classes, filled in by the metaclass created below
_EASY_REC_MODEL_CLASS_MAP = {}
_meta_type = get_register_class_meta(
    _EASY_REC_MODEL_CLASS_MAP, have_abstract_class=True)
47
+
48
+
49
+ class EasyRecModel(six.with_metaclass(_meta_type, object)):
50
+
51
def __init__(self,
             model_config,
             feature_configs,
             features,
             labels=None,
             is_training=False):
  """Set up shared model state, the input layer and the optional backbone.

  Args:
    model_config: model config message; stored both as
      `_base_model_config` and `_model_config`.
    feature_configs: feature configs handed to the input layer.
    features: mapping of input features; may carry a sample weight entry
      under the key `constant.SAMPLE_WEIGHT`.
    labels: labels of the batch; `None` marks the model as predicting.
    is_training: whether the graph is built for training.
  """
  self._base_model_config = model_config
  self._model_config = model_config
  self._is_training = is_training
  self._is_predicting = labels is None
  self._feature_dict = features

  # embedding variable parameters
  has_ev_params = model_config.HasField('ev_params')
  self._global_ev_params = model_config.ev_params if has_ev_params else None

  # regularizers are only created for strictly positive scales
  self._emb_reg = (
      regularizers.l2_regularizer(self.embedding_regularization)
      if self.embedding_regularization > 0 else None)
  self._l2_reg = (
      regularizers.l2_regularizer(self.l2_regularization)
      if self.l2_regularization > 0 else None)

  # only used by model with wide feature groups, e.g. WideAndDeep
  self._wide_output_dim = -1
  if self.has_backbone:
    wide_dim = Backbone.wide_embed_dim(model_config.backbone)
    if wide_dim:
      self._wide_output_dim = wide_dim
      logging.info('set `wide_output_dim` to %d' % wide_dim)

  self._feature_configs = feature_configs
  self.build_input_layer(model_config, feature_configs)

  self._labels = labels
  self._prediction_dict = {}
  self._loss_dict = {}
  self._metric_dict = {}

  # add sample weight from inputs, falling back to a constant weight of 1.0
  self._sample_weight = (
      features[constant.SAMPLE_WEIGHT]
      if constant.SAMPLE_WEIGHT in features else 1.0)

  self._backbone_output = None
  self._backbone_net = self.build_backbone_network()
101
+
102
def build_backbone_network(self):
  """Create the backbone network when the model config declares one.

  Returns:
    a `Backbone` built from the `backbone` config, the feature dict, the
    input layer and the l2 regularizer; `None` if there is no backbone.
  """
  if not self.has_backbone:
    return None
  backbone_config = self._base_model_config.backbone
  return Backbone(
      backbone_config,
      self._feature_dict,
      input_layer=self._input_layer,
      l2_reg=self._l2_reg)
110
+
111
@property
def has_backbone(self):
  """Whether the `backbone` field is set on the base model config."""
  base_config = self._base_model_config
  return base_config.HasField('backbone')
114
+
115
@property
def backbone(self):
  """Output of the backbone network, or `None` when there is no backbone.

  The backbone network is invoked on first access with the training flag
  and the loss / metric / prediction dicts, the labels and the sample
  weight as keyword arguments.  The result is stored in
  `_backbone_output` and returned directly on later accesses.

  Returns:
    whatever the backbone network returns, or `None` without a backbone.
  """
  # Compare with None instead of truth-testing: the cached output may be a
  # tensor, and TensorFlow does not allow using a graph tensor as a bool.
  if self._backbone_output is not None:
    return self._backbone_output
  if self._backbone_net is None:
    return None
  kwargs = {
      'loss_dict': self._loss_dict,
      'metric_dict': self._metric_dict,
      'prediction_dict': self._prediction_dict,
      'labels': self._labels,
      constant.SAMPLE_WEIGHT: self._sample_weight
  }
  # remember the output so the backbone is not invoked again on next access
  self._backbone_output = self._backbone_net(self._is_training, **kwargs)
  return self._backbone_output
129
+
130
@property
def embedding_regularization(self):
  """The `embedding_regularization` value of the base model config."""
  base_config = self._base_model_config
  return base_config.embedding_regularization
133
+
134
@property
def kd(self):
  """The `kd` field of the base model config."""
  base_config = self._base_model_config
  return base_config.kd
137
+
138
@property
def feature_groups(self):
  """The `feature_groups` field of the base model config."""
  base_config = self._base_model_config
  return base_config.feature_groups
141
+
142
@property
def l2_regularization(self):
  """L2 regularization scale read from the active model sub-config.

  The sub-config is the field selected by the `model` oneof of the base
  model config.  The deprecated `dense_regularization` field wins when it
  is explicitly set; otherwise `l2_regularization` is used if the
  sub-config has such an attribute.

  Returns:
    the configured scale, or 0.0 when no model is selected or neither
    field is available.
  """
  model_name = self._base_model_config.WhichOneof('model')
  if model_name is None:
    # nothing is selected in the oneof, so there is no sub-config to read;
    # getattr(config, None) would raise a TypeError here
    return 0.0
  model_config = getattr(self._base_model_config, model_name)
  if hasattr(model_config, 'dense_regularization') and \
      model_config.HasField('dense_regularization'):
    # backward compatibility
    logging.warning(
        'dense_regularization is deprecated, please use l2_regularization')
    return model_config.dense_regularization
  if hasattr(model_config, 'l2_regularization'):
    return model_config.l2_regularization
  return 0.0
156
+
157
def build_input_layer(self, model_config, feature_configs):
  """Create the `InputLayer` and store it as `self._input_layer`.

  Args:
    model_config: model config; supplies `feature_groups` and, when the
      field is set, `variational_dropout`.
    feature_configs: feature configs passed through to the input layer.
  """
  if model_config.HasField('variational_dropout'):
    variational_dropout_config = model_config.variational_dropout
  else:
    variational_dropout_config = None

  layer_kwargs = dict(
      wide_output_dim=self._wide_output_dim,
      ev_params=self._global_ev_params,
      embedding_regularizer=self._emb_reg,
      kernel_regularizer=self._l2_reg,
      variational_dropout_config=variational_dropout_config,
      is_training=self._is_training,
      is_predicting=self._is_predicting)
  self._input_layer = input_layer.InputLayer(
      feature_configs, model_config.feature_groups, **layer_kwargs)
169
+
170
@abstractmethod
def build_predict_graph(self):
  """Build the prediction graph.

  Marked abstract; this base implementation does nothing.
  """
173
+
174
@abstractmethod
def build_loss_graph(self):
  """Build the loss graph.

  Marked abstract; this base implementation does nothing.
  """
177
+
178
def build_metric_graph(self, eval_config):
  """Return the metric dict of the model.

  Args:
    eval_config: evaluation config; not used by this base implementation.

  Returns:
    the `_metric_dict` attribute, unchanged.
  """
  metric_dict = self._metric_dict
  return metric_dict
180
+
181
@abstractmethod
def get_outputs(self):
  """Return the names of the output nodes.

  Marked abstract; this base implementation does nothing.  The result is
  iterated by `build_output_dict` as keys of the prediction dict.
  """
184
+
185
def build_output_dict(self):
  """For exporting: get standard output nodes.

  Returns:
    dict mapping every name from `get_outputs()` to its entry in the
    prediction dict.

  Raises:
    KeyError: if an output name is missing from the prediction dict.
  """
  output_names = self.get_outputs()
  predictions = self._prediction_dict
  exported = {}
  for node_name in output_names:
    if node_name in predictions:
      exported[node_name] = predictions[node_name]
      continue
    raise KeyError(
        'output node {} not in prediction_dict, can not be exported'.format(
            node_name))
  return exported
195
+
196
def build_feature_output_dict(self):
  """For exporting: get output feature nodes.

  Every input feature is converted to strings (sparse features are
  densified with '' as default value) and joined along the last axis
  with ',' as separator.

  Returns:
    dict mapping 'feature_' + feature name to the joined string tensor.
  """
  exported = {}
  for feature_name, feature_value in self._feature_dict.items():
    out_name = 'feature_' + feature_name
    if isinstance(feature_value, tf.SparseTensor):
      values = feature_value.values
      if values.dtype != tf.string:
        values = tf.as_string(values)
      feature_value = tf.sparse_to_dense(feature_value.indices,
                                         feature_value.dense_shape, values,
                                         '')
    elif feature_value.dtype != tf.string:
      feature_value = tf.as_string(feature_value)
    exported[out_name] = tf.reduce_join(
        feature_value, axis=-1, separator=',')
  return exported
214
+
215
def build_rtp_output_dict(self):
  """For exporting: get output nodes for RTP inference.

  Returns:
    an empty dict in this base implementation.
  """
  return dict()
218
+
219
def restore(self,
            ckpt_path,
            include_global_step=False,
            ckpt_var_map_path='',
            force_restore_shape_compatible=False):
  """Restore variables from ckpt_path.

  steps:
    1. list the variables in graph that need to be restored
    2. inspect checkpoint and find the variables that could restore from
       checkpoint, substitute scope names in case necessary
    3. call tf.train.init_from_checkpoint to restore the variables

  Args:
    ckpt_path: checkpoint path to restore from
    include_global_step: whether to restore global_step variable
    ckpt_var_map_path: variable map from graph variables to variables in a
      checkpoint; each line consists of: variable name in graph, variable
      name in ckpt
    force_restore_shape_compatible: if variable shape is incompatible, clip
      or pad variables in checkpoint, and then restore

  Returns:
    IncompatibleShapeRestoreHook if force_restore_shape_compatible else None
  """

  def _restore_embedding_variable(ev_var):
    # replace the initializer of one EmbeddingVariable by a restore op
    from tensorflow.python.training import saver
    names_to_saveables = saver.BaseSaverBuilder.OpListToDict([ev_var])
    saveable_objects = [
        s for name, op in names_to_saveables.items()
        for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name)
    ]
    ev_var._initializer_op = saveable_objects[0].restore([ckpt_path], None)

  name2var_map = self._get_restore_vars(ckpt_var_map_path)
  logging.info('start to restore from %s' % ckpt_path)

  ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
  ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
  if not include_global_step:
    ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)

  vars_in_ckpt = {}
  incompatible_shape_var_map = {}
  fail_restore_vars = []
  for variable_name, variable in sorted(name2var_map.items()):
    if variable_name in ckpt_var2shape_map:
      print('restore %s' % variable_name)
      ckpt_var_shape = ckpt_var2shape_map[variable_name]
      if isinstance(variable, list):
        # a list holds the parts of a partitioned variable: the full shape
        # is the shape of the first part with dim 0 summed over all parts
        shape_arr = [x.get_shape() for x in variable]
        var_shape = list(shape_arr[0])
        for x in shape_arr[1:]:
          var_shape[0] += x[0]
        var_shape = tensor_shape.TensorShape(var_shape)
        variable = variables.PartitionedVariable(
            variable_name,
            var_shape,
            variable[0].dtype,
            variable,
            partitions=[len(variable)] + [1] * (len(var_shape) - 1))
      else:
        var_shape = variable.shape.as_list()
      if ckpt_var_shape == var_shape:
        vars_in_ckpt[variable_name] = list(variable) if isinstance(
            variable, variables.PartitionedVariable) else variable
      elif len(ckpt_var_shape) == len(var_shape):
        if force_restore_shape_compatible:
          # create a variable compatible with checkpoint to restore
          dtype = variable[0].dtype if isinstance(variable,
                                                  list) else variable.dtype
          with tf.variable_scope('incompatible_shape_restore'):
            tmp_var = tf.get_variable(
                name=variable_name + '_T_E_M_P',
                shape=ckpt_var_shape,
                trainable=False,
                # add to a special collection for easy reference
                # by tf.get_collection('T_E_M_P_RESTROE')
                collections=['T_E_M_P_RESTROE'],
                dtype=dtype)
          vars_in_ckpt[variable_name] = tmp_var
          incompatible_shape_var_map[variable] = tmp_var
          print('incompatible restore %s[%s, %s]' %
                (variable_name, str(var_shape), str(ckpt_var_shape)))
        else:
          logging.warning(
              'Variable [%s] is available in checkpoint, but '
              'incompatible shape with model variable.', variable_name)
      else:
        logging.warning(
            'Variable [%s] is available in checkpoint, but '
            'incompatible shape dims with model variable.', variable_name)
    elif 'EmbeddingVariable' in str(type(variable)):
      # skipped silently when the checkpoint has no keys for the variable
      if '%s-keys' % variable_name not in ckpt_var2shape_map:
        continue
      print('restore embedding_variable %s' % variable_name)
      _restore_embedding_variable(variable)
    elif isinstance(variable, list) and 'EmbeddingVariable' in str(
        type(variable[0])):
      # only the keys of the first part are probed in the checkpoint
      if '%s/part_0-keys' % variable_name not in ckpt_var2shape_map:
        continue
      print('restore partitioned embedding_variable %s' % variable_name)
      for part_var in variable:
        _restore_embedding_variable(part_var)
    elif sok is not None and isinstance(variable, sok.DynamicVariable):
      print('restore dynamic_variable %s' % variable_name)
      keys, vals = load_embed_lib.load_kv_embed(
          task_index=hvd.rank(),
          task_num=hvd.size(),
          embed_dim=variable._dimension,
          var_name='embed-' + variable.name.replace('/', '__'),
          ckpt_path=ckpt_path)
      # run the original initializer first, then assign the loaded values
      with ops.control_dependencies([variable._initializer_op]):
        variable._initializer_op = dynamic_variable_ops.dummy_var_assign(
            variable.handle, keys, vals)
    else:
      fail_restore_vars.append(variable_name)
  for variable_name in fail_restore_vars:
    # variables with 'Momentum' in the name are not reported
    if 'Momentum' not in variable_name:
      logging.warning('Variable [%s] is not available in checkpoint',
                      variable_name)

  tf.train.init_from_checkpoint(ckpt_path, vars_in_ckpt)

  if force_restore_shape_compatible:
    return estimator_utils.IncompatibleShapeRestoreHook(
        incompatible_shape_var_map)
  else:
    return None
352
+
353
def _get_restore_vars(self, ckpt_var_map_path):
  """Collect the graph variables to restore, keyed by checkpoint name.

  Args:
    ckpt_var_map_path: variable map from graph variables to variables in a
      checkpoint; each line consists of: variable name in graph, variable
      name in ckpt.  An empty string means no map file: the restore filter
      of the model is applied instead.

  Returns:
    dict mapping variable name to the variable to restore; the parts of a
    partitioned variable are grouped into one list under the name without
    the `/part_N` component.
  """
  # here must use global_variables, because variables such as moving_mean
  # and moving_variance is usually not trainable in detection models
  all_vars = tf.global_variables()
  PARTITION_PATTERN = '/part_[0-9]+'
  VAR_SUFIX_PATTERN = ':[0-9]$'

  name2var = {}
  for one_var in all_vars:
    var_name = re.sub(VAR_SUFIX_PATTERN, '', one_var.name)
    if re.search(PARTITION_PATTERN,
                 var_name) and one_var._save_slice_info is not None:
      var_name = re.sub(PARTITION_PATTERN, '', var_name)
      is_part = True
    else:
      is_part = False
    if var_name in name2var:
      # only the parts of a partitioned variable may share a name
      assert is_part, 'multiple vars: %s' % var_name
      name2var[var_name].append(one_var)
    else:
      name2var[var_name] = [one_var] if is_part else one_var

  if ckpt_var_map_path != '':
    if not gfile.Exists(ckpt_var_map_path):
      logging.warning('%s not exist' % ckpt_var_map_path)
      return name2var

    # load var map
    name_map = {}
    with gfile.GFile(ckpt_var_map_path, 'r') as fin:
      for one_line in fin:
        one_line = one_line.strip()
        line_tok = [x for x in one_line.split() if x != '']
        if len(line_tok) != 2:
          logging.warning('Failed to process: %s' % one_line)
          continue
        name_map[line_tok[0]] = line_tok[1]
    # re-key mapped variables by their checkpoint name; renames are
    # collected first because the dict must not change while iterating
    update_map = {}
    old_keys = []
    for var_name in name2var:
      if var_name in name_map:
        in_ckpt_name = name_map[var_name]
        update_map[in_ckpt_name] = name2var[var_name]
        old_keys.append(var_name)
    for tmp_key in old_keys:
      del name2var[tmp_key]
    name2var.update(update_map)
    return name2var
  else:
    var_filter, scope_update = self.get_restore_filter()
    if var_filter is not None:
      # Filter on the variable name (the dict key).  The previous version
      # iterated the string keys as `var`, read `var.name` from them and
      # used the stale loop variable `var_name` as key.
      # NOTE(review): assumes `keep` accepts a variable name -- confirm
      # against restore_filter.py.
      name2var = {
          var_name: one_var
          for var_name, one_var in name2var.items()
          if var_filter.keep(var_name)
      }
    # drop scope prefix if necessary
    if scope_update is not None:
      name2var = {
          scope_update(var_name): name2var[var_name] for var_name in name2var
      }
    return name2var
424
+
425
def get_restore_filter(self):
  """Get restore variable filter.

  One exclusive `KeywordFilter` is created per entry of `restore_filters`
  in the base model config and the filters are combined with logical AND.

  Return:
    filter: the combined filter, or None if no restore filter is configured
    scope_drop: always None in this implementation
  """
  patterns = self._base_model_config.restore_filters
  if not patterns:
    return None, None

  for pattern in patterns:
    logging.info('restore will filter out pattern %s' % pattern)

  keyword_filters = [
      restore_filter.KeywordFilter(pattern, True) for pattern in patterns
  ]
  combined = restore_filter.CombineFilter(keyword_filters,
                                          restore_filter.Logical.AND)
  return combined, None
445
+
446
def get_grouped_vars(self, opt_num):
  """Group the vars into different optimization groups.

  Each group will be optimized by a separate optimizer.  A trainable
  variable belongs to the embedding group when its name starts with
  'input_layer' or contains '/embedding_weights'.

  Args:
    opt_num: number of optimizers from easyrec config; must be 2.

  Return:
    list of two lists of variables: [embedding vars, all other vars].
  """
  assert opt_num == 2, 'could only support 2 optimizers, one for embedding, one for the other layers'

  embedding_vars, deep_vars = [], []
  for tmp_var in variables.trainable_variables():
    is_embedding = tmp_var.name.startswith(
        'input_layer') or '/embedding_weights' in tmp_var.name
    target_group = embedding_vars if is_embedding else deep_vars
    target_group.append(tmp_var)
  return [embedding_vars, deep_vars]
@@ -0,0 +1,242 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.layers import dnn
8
+ from easy_rec.python.model.multi_task_model import MultiTaskModel
9
+ from easy_rec.python.protos.esmm_pb2 import ESMM as ESMMConfig
10
+ from easy_rec.python.protos.loss_pb2 import LossType
11
+
12
# Switch to the TF1-compatible API when running on TensorFlow 2 or newer.
# The major version is compared numerically because a plain string comparison
# is lexicographic (for example '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
# Module-level alias of the losses namespace of whichever API is active.
losses = tf.losses
15
+
16
+
17
class ESMM(MultiTaskModel):
  """ESMM multi-task model made of two towers over shared input features.

  The main tower is called the cvr tower and the auxiliary tower the ctr
  tower. The ctr tower must be a binary classifier; the cvr tower is either a
  binary classifier (LossType.CLASSIFICATION) or a regressor
  (LossType.L2_LOSS). The element-wise product of the cvr and ctr predictions
  is stored in the prediction dict as 'probs_ctcvr' (classification) or
  'y_ctcvr' (otherwise).
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and build the shared input features.

    Args:
      model_config: model config whose 'model' oneof must be 'esmm'.
      feature_configs: forwarded unchanged to MultiTaskModel.
      features: forwarded unchanged to MultiTaskModel.
      labels: forwarded unchanged to MultiTaskModel.
      is_training: forwarded unchanged to MultiTaskModel.
    """
    super(ESMM, self).__init__(model_config, feature_configs, features, labels,
                               is_training)
    assert self._model_config.WhichOneof('model') == 'esmm', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.esmm
    assert isinstance(self._model_config, ESMMConfig)

    self._group_num = len(self._model_config.groups)
    self._group_features = []
    if self.has_backbone:
      logging.info('use bottom backbone network')
    elif self._group_num > 0:
      logging.info('group_num: {0}'.format(self._group_num))
      # one input-layer output per configured feature group, in config order
      for group in self._model_config.groups:
        group_feature, _ = self._input_layer(self._feature_dict, group.input)
        self._group_features.append(group_feature)
    else:
      # no explicit groups: fall back to the feature group named 'all'
      group_feature, _ = self._input_layer(self._feature_dict, 'all')
      self._group_features.append(group_feature)

    # This model only supports two tasks (cvr+ctr or playtime+ctr).
    # In order to be consistent with the paper,
    # we call these two towers cvr_tower (main tower) and ctr_tower (aux tower).
    self._cvr_tower_cfg = self._model_config.cvr_tower
    self._ctr_tower_cfg = self._model_config.ctr_tower
    self._init_towers([self._cvr_tower_cfg, self._ctr_tower_cfg])

    assert self._model_config.ctr_tower.loss_type == LossType.CLASSIFICATION, \
        'ctr tower must be binary classification.'
    for task_tower_cfg in self._task_towers:
      assert task_tower_cfg.num_class == 1, 'Does not support multiclass classification problem'

  def build_loss_graph(self):
    """Build loss graph.

    Returns:
      self._loss_dict: Weighted loss of ctr and cvr.

    Raises:
      ValueError: if the cvr tower loss type is neither CLASSIFICATION nor
        L2_LOSS (the same condition rejected by get_outputs).
    """
    cvr_tower_name = self._cvr_tower_cfg.tower_name
    ctr_tower_name = self._ctr_tower_cfg.tower_name
    cvr_label_name = self._label_name_dict[cvr_tower_name]
    ctr_label_name = self._label_name_dict[ctr_tower_name]
    cvr_loss_type = self._cvr_tower_cfg.loss_type
    if cvr_loss_type == LossType.CLASSIFICATION:
      # Cast each label before multiplying so that labels with different
      # dtypes can be combined, as the L2 branch and the metric graph do.
      ctcvr_label = tf.cast(self._labels[cvr_label_name], tf.float32) * tf.cast(
          self._labels[ctr_label_name], tf.float32)
      cvr_losses = tf.keras.backend.binary_crossentropy(
          ctcvr_label, self._prediction_dict['probs_ctcvr'])
      # NOTE(review): unlike the L2 branch, self._sample_weight is not applied
      # here - confirm this is intended.
      cvr_loss = tf.reduce_sum(cvr_losses, name='ctcvr_loss')
      # The weight defaults to 1.
      self._loss_dict['weighted_cross_entropy_loss_%s' %
                      cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss

    elif cvr_loss_type == LossType.L2_LOSS:
      logging.info('l2 loss is used')
      cvr_dtype = self._labels[cvr_label_name].dtype
      ctcvr_label = self._labels[cvr_label_name] * tf.cast(
          self._labels[ctr_label_name], cvr_dtype)
      cvr_loss = tf.losses.mean_squared_error(
          labels=ctcvr_label,
          predictions=self._prediction_dict['y_ctcvr'],
          weights=self._sample_weight)
      self._loss_dict['weighted_l2_loss_%s' %
                      cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss
    else:
      # Fail loudly instead of silently training without any cvr loss.
      raise ValueError('invalid cvr_tower loss type: %s' % str(cvr_loss_type))

    # The ctr tower is always trained with sigmoid cross entropy on its logits.
    _labels = tf.cast(self._labels[ctr_label_name], tf.float32)
    _logits = self._prediction_dict['logits_%s' % ctr_tower_name]
    cross = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=_labels, logits=_logits, name='ctr_loss')
    ctr_loss = tf.reduce_sum(cross)
    self._loss_dict['weighted_cross_entropy_loss_%s' %
                    ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Build metric graph.

    For every metric configured on the cvr tower two variants are built: one
    on the ctcvr label (cvr label * ctr label) and one on the cvr label
    restricted to the samples whose ctr label is positive. Metrics configured
    on the ctr tower are built on the ctr label directly.

    Args:
      eval_config: Evaluation configuration (not read by this method).

    Returns:
      metric_dict: Metrics of ctr, cvr and ctcvr.
    """
    metric_dict = {}

    cvr_tower_name = self._cvr_tower_cfg.tower_name
    ctr_tower_name = self._ctr_tower_cfg.tower_name
    cvr_label_name = self._label_name_dict[cvr_tower_name]
    ctr_label_name = self._label_name_dict[ctr_tower_name]
    for metric in self._cvr_tower_cfg.metrics_set:
      # CTCVR metric
      ctcvr_label_name = cvr_label_name + '_ctcvr'
      cvr_dtype = self._labels[cvr_label_name].dtype
      self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
          self._labels[ctr_label_name], cvr_dtype)
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._cvr_tower_cfg.loss_type,
              label_name=ctcvr_label_name,
              num_class=self._cvr_tower_cfg.num_class,
              suffix='_ctcvr'))

      # CVR metric, evaluated only on samples with a positive ctr label
      cvr_label_masked_name = cvr_label_name + '_masked'
      ctr_mask = self._labels[ctr_label_name] > 0
      self._labels[cvr_label_masked_name] = tf.boolean_mask(
          self._labels[cvr_label_name], ctr_mask)
      pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
      pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
      self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
          self._prediction_dict[pred_name], ctr_mask)
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._cvr_tower_cfg.loss_type,
              label_name=cvr_label_masked_name,
              num_class=self._cvr_tower_cfg.num_class,
              suffix='_%s_masked' % cvr_tower_name))

    for metric in self._ctr_tower_cfg.metrics_set:
      # CTR metric
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=self._ctr_tower_cfg.loss_type,
              label_name=ctr_label_name,
              num_class=self._ctr_tower_cfg.num_class,
              suffix='_%s' % ctr_tower_name))
    return metric_dict

  def _add_to_prediction_dict(self, output):
    """Register tower outputs and add the combined ctcvr prediction.

    Args:
      output: tower outputs, forwarded to the parent implementation;
        build_predict_graph passes a dict keyed by tower name.
    """
    super(ESMM, self)._add_to_prediction_dict(output)
    ctr_probs = self._prediction_dict['probs_%s' %
                                      self._ctr_tower_cfg.tower_name]
    if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION:
      # pctcvr = pctr * pcvr
      self._prediction_dict['probs_ctcvr'] = tf.multiply(
          self._prediction_dict['probs_%s' % self._cvr_tower_cfg.tower_name],
          ctr_probs)
    else:
      # the cvr tower predicts a value 'y' that is scaled by pctr
      self._prediction_dict['y_ctcvr'] = tf.multiply(
          self._prediction_dict['y_%s' % self._cvr_tower_cfg.tower_name],
          ctr_probs)

  def build_predict_graph(self):
    """Forward function.

    Returns:
      self._prediction_dict: Prediction result of two tasks.
    """
    if self.has_backbone:
      all_fea = self.backbone
    elif self._group_num > 0:
      group_fea_arr = []
      # Both towers share the underlying network.
      for group_fea, group in zip(self._group_features,
                                  self._model_config.groups):
        group_name = group.input
        dnn_model = dnn.DNN(group.dnn, self._l2_reg, group_name,
                            self._is_training)
        group_fea_arr.append(dnn_model(group_fea))
      all_fea = tf.concat(group_fea_arr, axis=1)
    else:
      all_fea = self._group_features[0]

    # cvr (main) tower: DNN followed by a single-unit dense output
    cvr_tower_name = self._cvr_tower_cfg.tower_name
    dnn_model = dnn.DNN(
        self._cvr_tower_cfg.dnn,
        self._l2_reg,
        name=cvr_tower_name,
        is_training=self._is_training)
    cvr_tower_output = dnn_model(all_fea)
    cvr_tower_output = tf.layers.dense(
        inputs=cvr_tower_output,
        units=1,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn_output' % cvr_tower_name)

    # ctr (aux) tower: same structure on the same shared features
    ctr_tower_name = self._ctr_tower_cfg.tower_name
    dnn_model = dnn.DNN(
        self._ctr_tower_cfg.dnn,
        self._l2_reg,
        name=ctr_tower_name,
        is_training=self._is_training)
    ctr_tower_output = dnn_model(all_fea)
    ctr_tower_output = tf.layers.dense(
        inputs=ctr_tower_output,
        units=1,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn_output' % ctr_tower_name)

    tower_outputs = {
        cvr_tower_name: cvr_tower_output,
        ctr_tower_name: ctr_tower_output
    }
    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict

  def get_outputs(self):
    """Get model outputs.

    Returns:
      outputs: The list of tensor names output by the model.

    Raises:
      ValueError: if the cvr tower loss type is neither CLASSIFICATION nor
        L2_LOSS.
    """
    outputs = super(ESMM, self).get_outputs()
    if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION:
      outputs.append('probs_ctcvr')
    elif self._cvr_tower_cfg.loss_type == LossType.L2_LOSS:
      outputs.append('y_ctcvr')
    else:
      raise ValueError('invalid cvr_tower loss type: %s' %
                       str(self._cvr_tower_cfg.loss_type))
    return outputs