easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,303 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ from collections import OrderedDict
5
+
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.builders import loss_builder
9
+ from easy_rec.python.layers.dnn import DNN
10
+ from easy_rec.python.model.rank_model import RankModel
11
+ from easy_rec.python.protos import tower_pb2
12
+ from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
13
+ from easy_rec.python.protos.loss_pb2 import LossType
14
+
15
# Switch to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: comparing version strings is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
17
+
18
+
19
class MultiTaskModel(RankModel):
  """Rank model with one prediction tower per task on top of a backbone.

  Each entry of `model_params.task_towers` yields its own logits,
  predictions, metrics and losses; everything task specific is keyed by the
  tower name and exported with the suffix `_<tower_name>`.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Initialize the base rank model and the per-task bookkeeping.

    Args:
      model_config: model config, forwarded to `RankModel`.
      feature_configs: feature configs, forwarded to `RankModel`.
      features: input features, forwarded to `RankModel`.
      labels: optional labels, forwarded to `RankModel`.
      is_training: training-mode flag, forwarded to `RankModel`.
    """
    super(MultiTaskModel, self).__init__(model_config, feature_configs,
                                         features, labels, is_training)
    # Filled by `_init_towers`: tower configs, their count, and the mapping
    # tower name -> label name.
    self._task_towers = []
    self._task_num = None
    self._label_name_dict = {}

  def build_predict_graph(self):
    """Build the per-task towers on top of the backbone output.

    Returns:
      The prediction dict (`self._prediction_dict`) after the outputs of all
      task towers have been added to it.

    Raises:
      NotImplementedError: if the model has no backbone network.
      ValueError: if the backbone yields a list/tuple whose length differs
        from the number of task towers.
    """
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when backbone network do not exists'
      )
    model = self._model_config.WhichOneof('model')
    assert model == 'model_params', '`model_params` must be configured'
    config = self._model_config.model_params
    for out in config.outputs:
      self._outputs.append(out)

    self._init_towers(config.task_towers)

    backbone = self.backbone
    if isinstance(backbone, (list, tuple)):
      # One backbone output per task tower, matched by position.
      if len(backbone) != len(config.task_towers):
        raise ValueError(
            'The number of backbone outputs and task towers must be equal')
      task_input_list = backbone
    else:
      # A single backbone output is shared by every task tower.
      task_input_list = [backbone] * len(config.task_towers)

    tower_features = {}
    for i, task_tower_cfg in enumerate(config.task_towers):
      tower_name = task_tower_cfg.tower_name
      with tf.name_scope(tower_name):
        if task_tower_cfg.HasField('dnn'):
          tower_dnn = DNN(
              task_tower_cfg.dnn,
              self._l2_reg,
              name=tower_name,
              is_training=self._is_training)
          tower_output = tower_dnn(task_input_list[i])
        else:
          tower_output = task_input_list[i]
      tower_features[tower_name] = tower_output

    tower_outputs = {}
    relation_features = {}
    # bayes network
    for task_tower_cfg in config.task_towers:
      tower_name = task_tower_cfg.tower_name
      with tf.name_scope(tower_name):
        if task_tower_cfg.HasField('relation_dnn'):
          relation_dnn = DNN(
              task_tower_cfg.relation_dnn,
              self._l2_reg,
              name=tower_name + '/relation_dnn',
              is_training=self._is_training)
          tower_inputs = [tower_features[tower_name]]
          # `relation_features` only holds towers that were processed
          # earlier in this loop AND have a `relation_dnn`; any other name
          # listed here raises KeyError.
          for relation_tower_name in task_tower_cfg.relation_tower_names:
            tower_inputs.append(relation_features[relation_tower_name])
          relation_input = tf.concat(
              tower_inputs, axis=-1, name=tower_name + '/relation_input')
          relation_fea = relation_dnn(relation_input)
          relation_features[tower_name] = relation_fea
        else:
          relation_fea = tower_features[tower_name]

        output_logits = tf.layers.dense(
            relation_fea,
            task_tower_cfg.num_class,
            kernel_regularizer=self._l2_reg,
            name=tower_name + '/output')
        tower_outputs[tower_name] = output_logits

    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict

  def _init_towers(self, task_tower_configs):
    """Record the task tower configs and resolve the label of each tower.

    When labels are available, a tower uses its configured `label_name`, or,
    if none is set, the label found at the same position as the tower.

    Args:
      task_tower_configs: sequence of `tower_pb2.TaskTower` or
        `tower_pb2.BayesTaskTower` messages.
    """
    self._task_towers = task_tower_configs
    self._task_num = len(task_tower_configs)
    for i, task_tower_config in enumerate(task_tower_configs):
      assert isinstance(
          task_tower_config, (tower_pb2.TaskTower, tower_pb2.BayesTaskTower)), \
          'task_tower_config must be a instance of tower_pb2.TaskTower or tower_pb2.BayesTaskTower'
      tower_name = task_tower_config.tower_name

      # For label backward compatibility with list
      if self._labels is not None:
        if task_tower_config.HasField('label_name'):
          label_name = task_tower_config.label_name
        else:
          # If label name is not specified, task_tower and label will be
          # matched by order
          label_name = list(self._labels.keys())[i]
          logging.info('Task Tower [%s] use label [%s]', tower_name,
                       label_name)
        assert label_name in self._labels, 'label [%s] must exists in labels' % label_name
        self._label_name_dict[tower_name] = label_name

  def _add_to_prediction_dict(self, output):
    """Convert the logits of every tower into prediction dict entries.

    Args:
      output: dict mapping tower name to the logits of that tower.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      suffix = '_%s' % tower_name
      if len(task_tower_cfg.losses) == 0:
        # No explicit `losses`: fall back to the tower-level `loss_type`.
        loss_types = [task_tower_cfg.loss_type]
      else:
        loss_types = [loss.loss_type for loss in task_tower_cfg.losses]
      for loss_type in loss_types:
        self._prediction_dict.update(
            self._output_to_prediction_impl(
                output[tower_name],
                loss_type=loss_type,
                num_class=task_tower_cfg.num_class,
                suffix=suffix))

  def build_metric_graph(self, eval_config):
    """Build metric graph for multi task model.

    Args:
      eval_config: evaluation config; not read here, the metrics come from
        the `metrics_set` of each task tower.

    Returns:
      The metric dict (`self._metric_dict`) extended with the metrics of
      every task tower.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      # The loss types do not depend on the metric, so resolve them once.
      if len(task_tower_cfg.losses) > 0:
        loss_types = {loss.loss_type for loss in task_tower_cfg.losses}
      else:
        loss_types = {task_tower_cfg.loss_type}
      for metric in task_tower_cfg.metrics_set:
        self._metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=loss_types,
                label_name=self._label_name_dict[tower_name],
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))
    return self._metric_dict

  def build_loss_weight(self):
    """Compute the weight of every loss of every task tower.

    Returns:
      An `OrderedDict` mapping tower name to its loss weights, one entry per
      configured loss (a single entry when the tower has no `losses`). The
      entries are `loss.weight * tower.weight`, or, under the `Random`
      strategy, slices of a softmax over random normal values that replace
      the configured weights.
    """
    loss_weights = OrderedDict()
    num_loss = 0
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      losses = task_tower_cfg.losses
      if len(losses) > 0:
        loss_weights[tower_name] = [
            loss.weight * task_tower_cfg.weight for loss in losses
        ]
      else:
        loss_weights[tower_name] = [task_tower_cfg.weight]
      num_loss += len(loss_weights[tower_name])

    strategy = self._base_model_config.loss_weight_strategy
    if strategy == self._base_model_config.Random:
      weights = tf.nn.softmax(tf.random_normal([num_loss]))
      offset = 0
      # Only existing keys are reassigned, so mutating while iterating is
      # safe here.
      for name, configured in loss_weights.items():
        count = len(configured)
        loss_weights[name] = weights[offset:offset + count]
        offset += count
    return loss_weights

  def get_learnt_loss(self, loss_type, name, value):
    """Weight a loss by a learnt uncertainty variable.

    Args:
      loss_type: `LossType` of the loss; L2 style losses get an extra factor
        of 0.5 on the weighted term.
      name: loss name, used to name the variable and its summary.
      value: the unweighted loss.

    Returns:
      `exp(-u) * value + 0.5 * u` (`0.5 * exp(-u) * value + 0.5 * u` for
      `L2_LOSS` / `SIGMOID_L2_LOSS`), where `u` is a new float variable
      initialized to 0.

    Raises:
      ValueError: if the loss weight strategy is not `Uncertainty`.
    """
    strategy = self._base_model_config.loss_weight_strategy
    if strategy != self._base_model_config.Uncertainty:
      strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy)
      raise ValueError('Unsupported loss weight strategy: ' + strategy_name)

    uncertainty = tf.Variable(
        0, name='%s_loss_weight' % name, dtype=tf.float32)
    tf.summary.scalar('loss/%s_uncertainty' % name, uncertainty)
    if loss_type in {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
      return 0.5 * tf.exp(-uncertainty) * value + 0.5 * uncertainty
    return tf.exp(-uncertainty) * value + 0.5 * uncertainty

  def build_loss_graph(self):
    """Build loss graph for multi task model.

    Returns:
      The loss dict (`self._loss_dict`) extended with the weighted losses of
      every task tower and the knowledge distillation losses.
    """
    task_loss_weights = self.build_loss_weight()
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      loss_weight = 1.0
      if task_tower_cfg.use_sample_weight:
        loss_weight *= self._sample_weight

      # Re-weight samples by whether a label marks them as in task space.
      if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
          task_tower_cfg.HasField('task_space_indicator_label'):
        in_task_space = tf.to_float(
            self._labels[task_tower_cfg.task_space_indicator_label] > 0)
        loss_weight = loss_weight * (
            task_tower_cfg.in_task_space_weight * in_task_space +
            task_tower_cfg.out_task_space_weight * (1 - in_task_space))

      # Same re-weighting, driven by a feature equal to a configured value.
      if task_tower_cfg.HasField('task_space_indicator_name') and \
          task_tower_cfg.HasField('task_space_indicator_value'):
        in_task_space = tf.to_float(
            tf.equal(
                self._feature_dict[task_tower_cfg.task_space_indicator_name],
                task_tower_cfg.task_space_indicator_value))
        loss_weight = loss_weight * (
            task_tower_cfg.in_task_space_weight * in_task_space +
            task_tower_cfg.out_task_space_weight * (1 - in_task_space))

      task_loss_weight = task_loss_weights[tower_name]
      loss_dict = {}
      losses = task_tower_cfg.losses
      if len(losses) == 0:
        loss_dict = self._build_loss_impl(
            task_tower_cfg.loss_type,
            label_name=self._label_name_dict[tower_name],
            loss_weight=loss_weight,
            num_class=task_tower_cfg.num_class,
            suffix='_%s' % tower_name)
        for loss_name in loss_dict.keys():
          loss_dict[loss_name] = loss_dict[loss_name] * task_loss_weight[0]
      else:
        calibrate_loss = []
        for loss_idx, loss in enumerate(losses):
          if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
            # Penalize this tower predicting a higher probability than any
            # of the towers it is related to.
            y_t = self._prediction_dict['probs_%s' % tower_name]
            for relation_tower_name in task_tower_cfg.relation_tower_names:
              y_rt = self._prediction_dict['probs_%s' % relation_tower_name]
              cali_loss = tf.reduce_mean(tf.nn.relu(y_t - y_rt))
              calibrate_loss.append(cali_loss * loss.weight)
              logging.info('calibrate loss: %s -> %s', relation_tower_name,
                           tower_name)
            continue
          loss_param = loss.WhichOneof('loss_param')
          if loss_param is not None:
            loss_param = getattr(loss, loss_param)
          loss_ops = self._build_loss_impl(
              loss.loss_type,
              label_name=self._label_name_dict[tower_name],
              loss_weight=loss_weight,
              num_class=task_tower_cfg.num_class,
              suffix='_%s' % tower_name,
              loss_name=loss.loss_name,
              loss_param=loss_param)
          for loss_name, loss_value in loss_ops.items():
            if loss.learn_loss_weight:
              loss_dict[loss_name] = self.get_learnt_loss(
                  loss.loss_type, loss_name, loss_value)
            else:
              # `task_loss_weight` holds one weight per entry of `losses`,
              # so it is indexed by the position of `loss`, not by the
              # position of the op inside `loss_ops`.
              loss_dict[loss_name] = loss_value * task_loss_weight[loss_idx]
        if calibrate_loss:
          cali_loss = tf.add_n(calibrate_loss)
          # NOTE(review): this key carries no tower suffix, so a later tower
          # with a calibrate loss overwrites an earlier one in
          # `self._loss_dict` - confirm whether that is intended.
          loss_dict['order_calibrate_loss'] = cali_loss
          tf.summary.scalar('loss/order_calibrate_loss', cali_loss)
      self._loss_dict.update(loss_dict)

    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels, self._feature_dict)
    self._loss_dict.update(kd_loss_dict)

    return self._loss_dict

  def get_outputs(self):
    """Collect the names of the model outputs.

    Returns:
      A list without duplicates, in first-seen order: the configured extra
      outputs followed by the outputs of every task tower.
    """
    outputs = []
    if self._outputs:
      outputs.extend(self._outputs)
    for task_tower_cfg in self._task_towers:
      suffix = '_%s' % task_tower_cfg.tower_name
      if len(task_tower_cfg.losses) == 0:
        outputs.extend(
            self._get_outputs_impl(
                task_tower_cfg.loss_type,
                task_tower_cfg.num_class,
                suffix=suffix))
        continue
      for loss in task_tower_cfg.losses:
        # The calibrate loss is skipped: no outputs are collected for it.
        if loss.loss_type == LossType.ORDER_CALIBRATE_LOSS:
          continue
        outputs.extend(
            self._get_outputs_impl(
                loss.loss_type, task_tower_cfg.num_class, suffix=suffix))
    # De-duplicate while keeping a deterministic order (a plain `set` would
    # make the order of the result vary between runs).
    return list(OrderedDict.fromkeys(outputs))
@@ -0,0 +1,62 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.layers import dnn
7
+ from easy_rec.python.model.rank_model import RankModel
8
+
9
+ from easy_rec.python.protos.multi_tower_pb2 import MultiTower as MultiTowerConfig # NOQA
10
+
11
# Switch to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: comparing version strings is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class MultiTower(RankModel):
  """Rank model that runs one DNN tower per configured feature group.

  The tower outputs are concatenated, passed through a final DNN and
  projected to `self._num_class` logits.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the `multi_tower` config and build the tower inputs.

    Args:
      model_config: model config whose `model` oneof must be `multi_tower`.
      feature_configs: feature configs, forwarded to `RankModel`.
      features: input features, forwarded to `RankModel`.
      labels: optional labels, forwarded to `RankModel`.
      is_training: training-mode flag, forwarded to `RankModel`.
    """
    super(MultiTower, self).__init__(model_config, feature_configs, features,
                                     labels, is_training)
    model_kind = self._model_config.WhichOneof('model')
    assert model_kind == 'multi_tower', (
        'invalid model config: %s' % model_kind)
    # From here on `_model_config` refers to the `multi_tower` sub-message.
    self._model_config = self._model_config.multi_tower
    assert isinstance(self._model_config, MultiTowerConfig)

    tower_cfgs = self._model_config.towers
    self._tower_num = len(tower_cfgs)
    self._tower_features = []
    for tower_cfg in tower_cfgs:
      # The input layer returns a pair; only its first element is kept.
      group_feature, _ = self._input_layer(self._feature_dict,
                                           tower_cfg.input)
      self._tower_features.append(group_feature)

  def build_predict_graph(self):
    """Build the towers, the final DNN and the output layer.

    Returns:
      The prediction dict (`self._prediction_dict`) after the model output
      has been added to it.
    """
    encoded_towers = []
    for idx in range(self._tower_num):
      tower_cfg = self._model_config.towers[idx]
      # The feature group name doubles as the name prefix of the tower.
      group_name = tower_cfg.input
      normalized = tf.layers.batch_normalization(
          self._tower_features[idx],
          training=self._is_training,
          trainable=True,
          name='%s_fea_bn' % group_name)
      tower_net = dnn.DNN(tower_cfg.dnn, self._l2_reg, '%s_dnn' % group_name,
                          self._is_training)
      encoded_towers.append(tower_net(normalized))

    merged = tf.concat(encoded_towers, axis=1)
    final_net = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                        'final_dnn', self._is_training)
    logits = tf.layers.dense(final_net(merged), self._num_class, name='output')

    self._add_to_prediction_dict(logits)
    return self._prediction_dict
@@ -0,0 +1,190 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import math
5
+
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.compat import regularizers
9
+ from easy_rec.python.layers import dnn
10
+ from easy_rec.python.layers import layer_norm
11
+ from easy_rec.python.layers import seq_input_layer
12
+ from easy_rec.python.model.rank_model import RankModel
13
+
14
+ from easy_rec.python.protos.multi_tower_pb2 import MultiTower as MultiTowerConfig # NOQA
15
+
16
# Switch to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison
# would order e.g. '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
18
+
19
+
20
class MultiTowerBST(RankModel):
  """Multi-tower rank model with extra self-attention sequence towers.

  Plain towers are batch-normalized and fed through a DNN. Each entry of
  `bst_towers` is processed by `bst` (multi-head attention, add & norm,
  feed-forward, add & norm). All tower outputs are concatenated and passed
  through a final DNN and a dense output layer.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and build plain and sequence tower features.

    Args:
      model_config: model config whose `model` oneof must be `multi_tower`;
        its `seq_att_groups` configure the sequence input layer.
      feature_configs: feature configs forwarded to the base class and to
        the sequence input layer.
      features: input features forwarded to the base class.
      labels: optional labels forwarded to the base class.
      is_training: training flag forwarded to the base class.
    """
    super(MultiTowerBST, self).__init__(model_config, feature_configs, features,
                                        labels, is_training)
    self._seq_input_layer = seq_input_layer.SeqInputLayer(
        feature_configs,
        model_config.seq_att_groups,
        embedding_regularizer=self._emb_reg,
        ev_params=self._global_ev_params)
    assert self._model_config.WhichOneof('model') == 'multi_tower', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    # From here on self._model_config refers to the multi_tower sub-message.
    self._model_config = self._model_config.multi_tower
    assert isinstance(self._model_config, MultiTowerConfig)

    self._tower_features = []
    self._tower_num = len(self._model_config.towers)
    for tower_id in range(self._tower_num):
      tower = self._model_config.towers[tower_id]
      tower_feature, _ = self._input_layer(self._feature_dict, tower.input)
      self._tower_features.append(tower_feature)

    self._bst_tower_features = []
    self._bst_tower_num = len(self._model_config.bst_towers)

    logging.info('all tower num: {0}'.format(self._tower_num +
                                             self._bst_tower_num))
    logging.info('bst tower num: {0}'.format(self._bst_tower_num))

    for tower_id in range(self._bst_tower_num):
      tower = self._model_config.bst_towers[tower_id]
      tower_feature = self._seq_input_layer(self._feature_dict, tower.input)
      # Regularize both the current-item embedding and the history sequence
      # embedding of this tower.
      regularizers.apply_regularization(
          self._emb_reg, weights_list=[tower_feature['key']])
      regularizers.apply_regularization(
          self._emb_reg, weights_list=[tower_feature['hist_seq_emb']])
      self._bst_tower_features.append(tower_feature)

  def dnn_net(self, net, dnn_units, name):
    """Apply a stack of dense layers inside the variable scope `name`.

    Every layer except the last uses ReLU; the last layer is linear.

    Args:
      net: input tensor.
      dnn_units: list with the number of units of each dense layer.
      name: variable scope name; also the prefix of the layer names. The
        scope is opened with AUTO_REUSE, so calls with the same name reuse
        the same variables.

    Returns:
      The output tensor of the last dense layer.
    """
    dnn_units_len = len(dnn_units)
    with tf.variable_scope(name_or_scope=name, reuse=tf.AUTO_REUSE):
      for idx, units in enumerate(dnn_units):
        if idx + 1 < dnn_units_len:
          net = tf.layers.dense(
              net,
              units=units,
              activation=tf.nn.relu,
              name='%s_%d' % (name, idx))
        else:
          net = tf.layers.dense(
              net, units=units, activation=None, name='%s_%d' % (name, idx))
    return net

  def attention_net(self, net, dim, cur_seq_len, seq_size, name):
    """Single-head self-attention over a padded sequence.

    Args:
      net: input tensor, [B, seq_size, dim] per the shape notes below.
      dim: output size of the query / key / value projections.
      cur_seq_len: lengths of the history part of each sequence.
      seq_size: total sequence size (history plus the current item).
      name: prefix of the projection scopes.

    Returns:
      The attention output, [B, seq_size, dim].
    """
    query_net = self.dnn_net(net, [dim], name + '_query')  # B, seq_len,dim
    key_net = self.dnn_net(net, [dim], name + '_key')
    value_net = self.dnn_net(net, [dim], name + '_value')
    scores = tf.matmul(
        query_net, key_net, transpose_b=True)  # [B, seq_size, seq_size]

    # History positions beyond cur_seq_len are masked out; the last position
    # (the current item) is always kept.
    hist_mask = tf.sequence_mask(
        cur_seq_len, maxlen=seq_size - 1)  # [B, seq_size-1]
    cur_id_mask = tf.ones(
        tf.stack([tf.shape(hist_mask)[0], 1]), dtype=tf.bool)  # [B, 1]
    mask = tf.concat([hist_mask, cur_id_mask], axis=1)  # [B, seq_size]
    masks = tf.reshape(tf.tile(mask, [1, seq_size]),
                       (-1, seq_size, seq_size))  # [B, seq_size, seq_size]
    # Masked scores are replaced by a large negative value so that they
    # vanish after the softmax.
    padding = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(masks, scores, padding)  # [B, seq_size, seq_size]

    # Normalize the masked scores; note that no 1/sqrt(dim) scaling is
    # applied before the softmax.
    scores = tf.nn.softmax(scores)  # (B, seq_size, seq_size)
    att_res_net = tf.matmul(scores, value_net)  # [B, seq_size, emb_dim]
    return att_res_net

  def multi_head_att_net(self, id_cols, head_count, emb_dim, seq_len, seq_size):
    """Multi-head self-attention built from slices of the embedding axis.

    The embedding axis is cut into chunks of ceil(emb_dim / head_count)
    columns (the last chunk may be smaller); each chunk goes through
    `attention_net`, and the concatenated results are projected back to
    `emb_dim`.

    Args:
      id_cols: input tensor, sliced along its third axis.
      head_count: number of heads used to derive the chunk width.
      emb_dim: size of the third axis of `id_cols`, a Python int.
      seq_len: lengths of the history part of each sequence.
      seq_size: total sequence size (history plus the current item).

    Returns:
      The projected multi-head attention output.
    """
    multi_head_attention_res = []
    # Integer ceiling division. math.ceil(emb_dim / head_count) would floor
    # first under Python 2 integer division and produce a different split
    # than under Python 3 whenever emb_dim is not a multiple of head_count.
    part_cols_emd_dim = -(-emb_dim // head_count)
    for start_idx in range(0, emb_dim, part_cols_emd_dim):
      if start_idx + part_cols_emd_dim > emb_dim:
        part_cols_emd_dim = emb_dim - start_idx
      part_id_col = tf.slice(id_cols, [0, 0, start_idx],
                             [-1, -1, part_cols_emd_dim])
      part_attention_net = self.attention_net(
          part_id_col,
          part_cols_emd_dim,
          seq_len,
          seq_size,
          name='multi_head_%d' % start_idx)
      multi_head_attention_res.append(part_attention_net)
    multi_head_attention_res_net = tf.concat(multi_head_attention_res, axis=2)
    multi_head_attention_res_net = self.dnn_net(
        multi_head_attention_res_net, [emb_dim], name='multi_head_attention')
    return multi_head_attention_res_net

  def add_and_norm(self, net_1, net_2, emb_dim, name):
    """Add two tensors and layer-normalize the sum.

    Args:
      net_1: first tensor.
      net_2: second tensor.
      emb_dim: size passed to `layer_norm.LayerNormalization`.
      name: currently unused.

    Returns:
      The normalized sum.
    """
    net = tf.add(net_1, net_2)
    # layer = tf.keras.layers.LayerNormalization(axis=2)
    layer = layer_norm.LayerNormalization(emb_dim)
    net = layer(net)
    return net

  def bst(self, bst_fea, seq_size, head_count, name):
    """Encode one sequence tower.

    Args:
      bst_fea: dict with the keys 'key' (current item embedding),
        'hist_seq_emb' (history embeddings) and 'hist_seq_len'.
      seq_size: total sequence size, i.e. history steps plus the current item.
      head_count: number of attention heads.
      name: currently unused.
        NOTE(review): the scopes created below do not include the tower
        name, so several bst towers appear to share the same variables
        through AUTO_REUSE - confirm this is intended.

    Returns:
      A tensor reshaped to [-1, seq_size * emb_dim].
    """
    cur_id, hist_id_col, seq_len = bst_fea['key'], bst_fea[
        'hist_seq_emb'], bst_fea['hist_seq_len']

    cur_batch_max_seq_len = tf.shape(hist_id_col)[1]

    # Pad or truncate the history to seq_size - 1 steps so that appending the
    # current item yields exactly seq_size steps.
    hist_id_col = tf.cond(
        tf.constant(seq_size) > cur_batch_max_seq_len, lambda: tf.pad(
            hist_id_col, [[0, 0], [0, seq_size - cur_batch_max_seq_len - 1],
                          [0, 0]], 'CONSTANT'),
        lambda: tf.slice(hist_id_col, [0, 0, 0], [-1, seq_size - 1, -1]))
    all_ids = tf.concat([hist_id_col, tf.expand_dims(cur_id, 1)],
                        axis=1)  # b, seq_size, emb_dim

    emb_dim = int(all_ids.shape[2])
    attention_net = self.multi_head_att_net(all_ids, head_count, emb_dim,
                                            seq_len, seq_size)

    tmp_net = self.add_and_norm(
        all_ids, attention_net, emb_dim, name='add_and_norm_1')
    feed_forward_net = self.dnn_net(tmp_net, [emb_dim], 'feed_forward_net')
    net = self.add_and_norm(
        tmp_net, feed_forward_net, emb_dim, name='add_and_norm_2')
    bst_output = tf.reshape(net, [-1, seq_size * emb_dim])
    return bst_output

  def build_predict_graph(self):
    """Build the prediction graph and return the prediction dict."""
    tower_fea_arr = []
    for tower_id in range(self._tower_num):
      tower_fea = self._tower_features[tower_id]
      tower = self._model_config.towers[tower_id]
      tower_name = tower.input
      tower_fea = tf.layers.batch_normalization(
          tower_fea,
          training=self._is_training,
          trainable=True,
          name='%s_fea_bn' % tower_name)
      tower_dnn = dnn.DNN(tower.dnn, self._l2_reg, '%s_dnn' % tower_name,
                          self._is_training)
      tower_fea = tower_dnn(tower_fea)
      tower_fea_arr.append(tower_fea)

    for tower_id in range(self._bst_tower_num):
      tower_fea = self._bst_tower_features[tower_id]
      tower = self._model_config.bst_towers[tower_id]
      tower_name = tower.input
      tower_seq_len = tower.seq_len
      tower_multi_head_size = tower.multi_head_size
      tower_fea = self.bst(
          tower_fea,
          seq_size=tower_seq_len,
          head_count=tower_multi_head_size,
          name=tower_name)
      tower_fea_arr.append(tower_fea)

    all_fea = tf.concat(tower_fea_arr, axis=1)
    final_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg, 'final_dnn',
                        self._is_training)
    all_fea = final_dnn(all_fea)
    output = tf.layers.dense(all_fea, self._num_class, name='output')

    self._add_to_prediction_dict(output)

    return self._prediction_dict
@@ -0,0 +1,130 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.compat import regularizers
8
+ from easy_rec.python.layers import dnn
9
+ from easy_rec.python.layers import seq_input_layer
10
+ from easy_rec.python.model.rank_model import RankModel
11
+
12
+ from easy_rec.python.protos.multi_tower_pb2 import MultiTower as MultiTowerConfig # NOQA
13
+
14
# Switch to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison
# would order e.g. '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
16
+
17
+
18
class MultiTowerDIN(RankModel):
  """Multi-tower rank model with extra attention-pooled sequence towers.

  Plain towers are batch-normalized and fed through a DNN. Each entry of
  `din_towers` is pooled by `din`, which weights the history embeddings by
  scores computed against the current item. All tower outputs are
  concatenated and passed through a final DNN and a dense output layer.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and build plain and sequence tower features.

    Args:
      model_config: model config whose `model` oneof must be `multi_tower`;
        its `seq_att_groups` configure the sequence input layer.
      feature_configs: feature configs forwarded to the base class and to
        the sequence input layer.
      features: input features forwarded to the base class.
      labels: optional labels forwarded to the base class.
      is_training: training flag forwarded to the base class.
    """
    super(MultiTowerDIN, self).__init__(model_config, feature_configs, features,
                                        labels, is_training)
    self._seq_input_layer = seq_input_layer.SeqInputLayer(
        feature_configs,
        model_config.seq_att_groups,
        embedding_regularizer=self._emb_reg,
        ev_params=self._global_ev_params)
    assert self._model_config.WhichOneof('model') == 'multi_tower', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    # From here on self._model_config refers to the multi_tower sub-message.
    self._model_config = self._model_config.multi_tower
    assert isinstance(self._model_config, MultiTowerConfig)

    self._tower_features = []
    self._tower_num = len(self._model_config.towers)
    for tower_id in range(self._tower_num):
      tower = self._model_config.towers[tower_id]
      tower_feature, _ = self._input_layer(self._feature_dict, tower.input)
      self._tower_features.append(tower_feature)

    self._din_tower_features = []
    self._din_tower_num = len(self._model_config.din_towers)

    logging.info('all tower num: {0}'.format(self._tower_num +
                                             self._din_tower_num))
    logging.info('din tower num: {0}'.format(self._din_tower_num))

    for tower_id in range(self._din_tower_num):
      tower = self._model_config.din_towers[tower_id]
      tower_feature = self._seq_input_layer(self._feature_dict, tower.input)

      # apply regularization for sequence feature key in seq_input_layer.

      regularizers.apply_regularization(
          self._emb_reg, weights_list=[tower_feature['hist_seq_emb']])
      self._din_tower_features.append(tower_feature)

  def din(self, dnn_config, deep_fea, name):
    """Pool a history sequence with attention against the current item.

    Args:
      dnn_config: config of the DNN that scores each history step.
      deep_fea: dict with the keys 'key' (current item embedding),
        'hist_seq_emb' (history embeddings) and 'hist_seq_len'.
      name: name passed to the scoring DNN.

    Returns:
      The attention-pooled history embedding concatenated with the current
      item embedding along axis 1.
    """
    cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
        'hist_seq_emb'], deep_fea['hist_seq_len']

    seq_max_len = tf.shape(hist_id_col)[1]
    emb_dim = hist_id_col.shape[2]

    # Repeat the current item embedding once per history step.
    cur_ids = tf.tile(cur_id, [1, seq_max_len])
    cur_ids = tf.reshape(cur_ids,
                         tf.shape(hist_id_col))  # (B, seq_max_len, emb_dim)

    din_net = tf.concat(
        [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
        axis=-1)  # (B, seq_max_len, emb_dim*4)

    din_layer = dnn.DNN(
        dnn_config,
        self._l2_reg,
        name,
        self._is_training,
        last_layer_no_activation=True,
        last_layer_no_batch_norm=True)
    din_net = din_layer(din_net)
    scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B, 1, ?)

    seq_len = tf.expand_dims(seq_len, 1)
    # Pin the mask width to the time dimension of hist_id_col. Without
    # maxlen, tf.sequence_mask sizes the mask from the largest length in the
    # batch, which makes tf.where fail if that differs from seq_max_len.
    mask = tf.sequence_mask(seq_len, maxlen=seq_max_len)
    # Masked scores are replaced by a large negative value so that they
    # vanish after the softmax.
    padding = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(mask, scores, padding)  # [B, 1, seq_max_len]

    # Normalize the masked scores; no additional scaling is applied.
    scores = tf.nn.softmax(scores)  # (B, 1, seq_max_len)
    hist_din_emb = tf.matmul(scores, hist_id_col)  # [B, 1, emb_dim]
    hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim])  # [B, emb_dim]
    din_output = tf.concat([hist_din_emb, cur_id], axis=1)
    return din_output

  def build_predict_graph(self):
    """Build the prediction graph and return the prediction dict."""
    tower_fea_arr = []
    for tower_id in range(self._tower_num):
      tower_fea = self._tower_features[tower_id]
      tower = self._model_config.towers[tower_id]
      tower_name = tower.input
      tower_fea = tf.layers.batch_normalization(
          tower_fea,
          training=self._is_training,
          trainable=True,
          name='%s_fea_bn' % tower_name)
      dnn_layer = dnn.DNN(tower.dnn, self._l2_reg, '%s_dnn' % tower_name,
                          self._is_training)
      tower_fea = dnn_layer(tower_fea)
      tower_fea_arr.append(tower_fea)

    for tower_id in range(self._din_tower_num):
      tower_fea = self._din_tower_features[tower_id]
      tower = self._model_config.din_towers[tower_id]
      tower_name = tower.input
      tower_fea = self.din(tower.dnn, tower_fea, name='%s_dnn' % tower_name)
      tower_fea_arr.append(tower_fea)

    all_fea = tf.concat(tower_fea_arr, axis=1)
    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                              'final_dnn', self._is_training)
    all_fea = final_dnn_layer(all_fea)
    output = tf.layers.dense(all_fea, self._num_class, name='output')

    self._add_to_prediction_dict(output)

    return self._prediction_dict