easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,119 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+ from tensorflow.python.keras.layers import Layer
5
+
6
+ from easy_rec.python.layers import multihead_cross_attention
7
+ from easy_rec.python.utils.activation import get_activation
8
+ from easy_rec.python.utils.shape_utils import get_shape_list
9
+
10
# Use the TF1-style API surface when running on TensorFlow 2.x. Compare the
# numeric major version: a plain string comparison like
# `tf.__version__ >= '2.0'` orders versions lexicographically (e.g. it would
# consider '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
12
+
13
+
14
class BST(Layer):
  """Behavior Sequence Transformer layer.

  Projects a sequence embedding to `hidden_size` when needed, optionally adds a
  target item (at the head or tail of the sequence), adds position embeddings
  and runs the transformer encoder from `multihead_cross_attention`.
  """

  def __init__(self, params, name='bst', reuse=None, **kwargs):
    super(BST, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.l2_reg = params.l2_regularizer
    self.config = params.get_pb_config()

  def encode(self,
             seq_input,
             max_position,
             hidden_dropout_prob=None,
             attention_probs_dropout_prob=None):
    """Add position embeddings and run the transformer encoder.

    Args:
      seq_input: sequence tensor to encode; presumably
        [batch_size, seq_length, hidden_size] — TODO confirm against callers.
      max_position: number of positions for the position embedding; also used
        to flatten the output when `output_all_token_embeddings` is set.
      hidden_dropout_prob: dropout rate for hidden layers. Defaults to
        `self.config.hidden_dropout_prob` when None.
      attention_probs_dropout_prob: dropout rate for attention probabilities.
        Defaults to `self.config.attention_probs_dropout_prob` when None.

    Returns:
      A [batch, max_position * hidden_size] tensor when
      `output_all_token_embeddings` is set, otherwise the encoder output at
      position 0.
    """
    if hidden_dropout_prob is None:
      hidden_dropout_prob = self.config.hidden_dropout_prob
    if attention_probs_dropout_prob is None:
      attention_probs_dropout_prob = self.config.attention_probs_dropout_prob

    seq_fea = multihead_cross_attention.embedding_postprocessor(
        seq_input,
        position_embedding_name=self.name,
        max_position_embeddings=max_position,
        reuse_position_embedding=self.reuse)

    # A position is attended to when any value of its (pre-position-embedding)
    # vector is non-zero.
    # NOTE(review): `call` may pass `seq_input` through a relu dense
    # projection first; once its bias is non-zero, zero padding rows can
    # become non-zero and be treated as valid here — confirm this is intended.
    n = tf.count_nonzero(seq_input, axis=-1)
    seq_mask = tf.cast(n > 0, tf.int32)

    attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
        from_tensor=seq_fea, to_mask=seq_mask)

    hidden_act = get_activation(self.config.hidden_act)
    attention_fea = multihead_cross_attention.transformer_encoder(
        seq_fea,
        hidden_size=self.config.hidden_size,
        num_hidden_layers=self.config.num_hidden_layers,
        num_attention_heads=self.config.num_attention_heads,
        attention_mask=attention_mask,
        intermediate_size=self.config.intermediate_size,
        intermediate_act_fn=hidden_act,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=self.config.initializer_range,
        name=self.name + '/transformer',
        reuse=self.reuse)
    # attention_fea shape: [batch_size, seq_length, hidden_size]
    if self.config.output_all_token_embeddings:
      out_fea = tf.reshape(attention_fea,
                           [-1, max_position * self.config.hidden_size])
    else:
      # Position 0 holds the target item when target_item_position == 'head'.
      # NOTE(review): with 'tail' this is the first sequence item instead.
      out_fea = attention_fea[:, 0, :]
    tf.logging.info('bst output shape: %s', out_fea.shape)
    return out_fea

  def call(self, inputs, training=None, **kwargs):
    """Encode `inputs` = [seq_input, seq_len(, target)].

    Args:
      inputs: list/tuple with at least the sequence embedding and its length;
        an optional third element is the target item embedding.
      training: dropout is applied only when this is truthy.

    Returns:
      The output of `encode`.
    """
    # Outside training, run without dropout. Keep the rates local rather than
    # writing 0.0 into `self.config`: the config object persists across
    # calls, so mutating it would also disable dropout for later
    # training-mode calls (e.g. graphs rebuilt after an evaluation).
    if training:
      hidden_dropout_prob = self.config.hidden_dropout_prob
      attention_probs_dropout_prob = self.config.attention_probs_dropout_prob
    else:
      hidden_dropout_prob = 0.0
      attention_probs_dropout_prob = 0.0

    assert isinstance(inputs, (list, tuple))
    assert len(inputs) >= 2
    # seq_input: [batch_size, seq_len, embed_size]
    seq_input = inputs[0]
    target = inputs[2] if len(inputs) > 2 else None
    max_position = self.config.max_position_embeddings
    # cur_batch_max_seq_len: the max sequence length in the current
    # mini-batch; all sequences are padded to this length.
    _, cur_batch_max_seq_len, seq_embed_size = get_shape_list(seq_input, 3)
    valid_len = tf.assert_less_equal(
        cur_batch_max_seq_len,
        max_position,
        message='sequence length is greater than `max_position_embeddings`:' +
        str(max_position) + ' in feature group:' + self.name +
        ', you should set `max_seq_len` in sequence feature configs')

    if self.config.output_all_token_embeddings:
      # Pad or truncate to exactly `max_position` steps so the flattened
      # output has a fixed width.
      seq_input = tf.cond(
          tf.constant(max_position) > cur_batch_max_seq_len, lambda: tf.pad(
              seq_input, [[0, 0], [0, max_position - cur_batch_max_seq_len],
                          [0, 0]], 'CONSTANT'),
          lambda: tf.slice(seq_input, [0, 0, 0], [-1, max_position, -1]))

    if seq_embed_size != self.config.hidden_size:
      seq_input = tf.layers.dense(
          seq_input,
          self.config.hidden_size,
          activation=tf.nn.relu,
          kernel_regularizer=self.l2_reg,
          name=self.name + '/seq_project',
          reuse=self.reuse)

    keep_target = self.config.target_item_position in ('head', 'tail')
    if target is not None and keep_target:
      target_size = target.shape.as_list()[-1]
      assert seq_embed_size == target_size, 'the embedding size of sequence and target item is not equal' \
                                            ' in feature group:' + self.name
      if target_size != self.config.hidden_size:
        target = tf.layers.dense(
            target,
            self.config.hidden_size,
            activation=tf.nn.relu,
            kernel_regularizer=self.l2_reg,
            name=self.name + '/target_project',
            reuse=self.reuse)
      # target_feature: [batch_size, 1, embed_size]
      target = tf.expand_dims(target, 1)
      # seq_input: [batch_size, seq_len+1, embed_size]
      if self.config.target_item_position == 'head':
        seq_input = tf.concat([target, seq_input], axis=1)
      else:
        seq_input = tf.concat([seq_input, target], axis=1)
      max_position += 1
    elif self.config.reserve_target_position:
      max_position += 1

    with tf.control_dependencies([valid_len]):
      return self.encode(
          seq_input,
          max_position,
          hidden_dropout_prob=hidden_dropout_prob,
          attention_probs_dropout_prob=attention_probs_dropout_prob)
@@ -0,0 +1,250 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Convenience blocks for using custom ops."""
4
+ import logging
5
+ import os
6
+
7
+ import tensorflow as tf
8
+ from tensorflow.python.framework import ops
9
+ from tensorflow.python.keras.layers import Layer
10
+
11
# Locate the prebuilt custom-op libraries under easy_rec/python/ops/<variant>.
curr_dir, _ = os.path.split(__file__)
parent_dir = os.path.dirname(curr_dir)
# `ops_idr` (sic) is the easy_rec/python directory; the name is kept because
# it is a module-level attribute.
ops_idr = os.path.dirname(parent_dir)
ops_dir = os.path.join(ops_idr, 'ops')
if 'PAI' in tf.__version__:
  ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
  ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
  if 'IS_ON_PAI' in os.environ:
    ops_dir = os.path.join(ops_dir, 'DeepRec')
  else:
    ops_dir = os.path.join(ops_dir, '1.15')
elif tf.__version__.startswith('2.12'):
  ops_dir = os.path.join(ops_dir, '2.12')
else:
  # Make the fallback visible instead of silently probing the base directory.
  logging.warning(
      'no prebuilt custom ops directory for tensorflow %s, trying %s',
      tf.__version__, ops_dir)

logging.info('ops_dir is %s', ops_dir)
custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
try:
  custom_ops = tf.load_op_library(custom_op_path)
  logging.info('load custom op from %s succeed', custom_op_path)
except Exception as ex:
  # Best effort: importing this module must not fail. Layers below access
  # attributes of `custom_ops` and will fail only when actually constructed.
  logging.warning('load custom op from %s failed: %s', custom_op_path, ex)
  custom_ops = None
36
+
37
+ # if tf.__version__ >= '2.0':
38
+ # tf = tf.compat.v1
39
+
40
+
41
class SeqAugmentOps(Layer):
  """Sequence-embedding data augmentation via the custom `my_seq_augment` op."""

  def __init__(self, params, name='sequence_aug', reuse=None, **kwargs):
    super(SeqAugmentOps, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.seq_aug_params = params.get_pb_config()
    # Raises AttributeError here if the custom op library failed to load
    # (`custom_ops` is None in that case).
    self.seq_augment = custom_ops.my_seq_augment

  def call(self, inputs, training=None, **kwargs):
    """Augment a sequence embedding.

    Args:
      inputs: list/tuple whose first two items are the sequence embedding and
        its lengths.

    Returns:
      The (augmented sequence, augmented length) pair produced by the op.
    """
    assert isinstance(
        inputs,
        (list, tuple)), 'the inputs of SeqAugmentOps must be type of list/tuple'
    assert len(inputs) >= 2, 'SeqAugmentOps must have at least 2 inputs'
    seq_input = inputs[0]
    seq_len = inputs[1]
    embedding_dim = int(seq_input.shape[-1])
    # A single trainable vector sized like one item embedding; presumably the
    # op substitutes it for masked items (see `mask_rate`).
    with tf.variable_scope(self.name, reuse=self.reuse):
      mask_vector = tf.get_variable(
          'mask', (embedding_dim,), dtype=tf.float32, trainable=True)
    lengths = tf.to_int32(seq_len)
    cfg = self.seq_aug_params
    with ops.device('/CPU:0'):
      augmented = self.seq_augment(seq_input, lengths, mask_vector,
                                   cfg.crop_rate, cfg.reorder_rate,
                                   cfg.mask_rate)
    aug_seq, aug_len = augmented
    return aug_seq, aug_len
67
+
68
+
69
class TextNormalize(Layer):
  """Normalize string tensors with the custom `text_normalize_op`."""

  def __init__(self, params, name='text_normalize', reuse=None, **kwargs):
    super(TextNormalize, self).__init__(name=name, **kwargs)
    self.txt_normalizer = custom_ops.text_normalize_op
    self.norm_parameter = params.get_or_default('norm_parameter', 0)
    self.remove_space = params.get_or_default('remove_space', False)

  def call(self, inputs, training=None, **kwargs):
    """Normalize one string tensor or a list/tuple of them.

    Returns:
      The single normalized tensor when exactly one input is given, otherwise
      a list with one normalized tensor per input.
    """
    # isinstance (rather than `type(inputs) in (tuple, list)`) also accepts
    # list/tuple subclasses such as namedtuples, which would otherwise be
    # passed to the op as a single input.
    texts = inputs if isinstance(inputs, (list, tuple)) else [inputs]
    with ops.device('/CPU:0'):
      result = [
          self.txt_normalizer(
              txt,
              parameter=self.norm_parameter,
              remove_space=self.remove_space) for txt in texts
      ]
    if len(result) == 1:
      return result[0]
    return result
89
+
90
+
91
class MappedDotProduct(Layer):
  """Query/document feature computed by the custom `mapped_dot_product` op.

  The raw feature can optionally be normalized (`normalize_fn`), bucketized
  (`boundaries`) and, when both `boundaries` and `embedding_dim` are set,
  embedded via a one-hot lookup into a trainable table.
  """

  def __init__(self, params, name='mapped_dot_product', reuse=None, **kwargs):
    super(MappedDotProduct, self).__init__(name=name, **kwargs)
    self.mapped_dot_product = custom_ops.mapped_dot_product
    self.bucketize = custom_ops.my_bucketize
    self.default_value = params.get_or_default('default_value', 0)
    self.separator = params.get_or_default('separator', '\035')
    self.norm_fn = params.get_or_default('normalize_fn', None)
    self.boundaries = list(params.get_or_default('boundaries', []))
    self.emb_dim = params.get_or_default('embedding_dim', 0)
    self.print_first_n = params.get_or_default('print_first_n', 0)
    self.summarize = params.get_or_default('summarize', None)
    if self.emb_dim > 0:
      # One row per bucket: len(boundaries) boundaries give len + 1 buckets.
      vocab_size = len(self.boundaries) + 1
      with tf.variable_scope(self.name, reuse=reuse):
        self.embedding_table = tf.get_variable(
            name='dot_product_emb_table',
            shape=[vocab_size, self.emb_dim],
            dtype=tf.float32)

  def call(self, inputs, training=None, **kwargs):
    """Compute the feature for inputs = [query, doc, ...].

    Returns:
      The embedded buckets when `embedding_dim > 0` and `boundaries` are set,
      otherwise the (possibly normalized/bucketized) feature with a trailing
      axis of size 1.
    """
    query, doc = inputs[:2]
    with ops.device('/CPU:0'):
      feature = self.mapped_dot_product(
          query=query,
          document=doc,
          feature_name=self.name,
          separator=self.separator,
          default_value=self.default_value)
    tf.summary.scalar(self.name, tf.reduce_mean(feature))
    if self.print_first_n:
      encode_q = tf.regex_replace(query, self.separator, ' ')
      # Print the document alongside the query (previously `query` was
      # encoded twice and the document was never shown).
      encode_t = tf.regex_replace(doc, self.separator, ' ')
      feature = tf.Print(
          feature, [encode_q, encode_t, feature],
          message=self.name,
          first_n=self.print_first_n,
          summarize=self.summarize)
    if self.norm_fn is not None:
      # SECURITY: `normalize_fn` comes from the model config and is executed
      # with eval(); only load configs from trusted sources.
      fn = eval(self.norm_fn)
      feature = fn(feature)
      tf.summary.scalar('normalized_%s' % self.name, tf.reduce_mean(feature))
      if self.print_first_n:
        feature = tf.Print(
            feature, [feature],
            message='normalized %s' % self.name,
            first_n=self.print_first_n,
            summarize=self.summarize)
    if self.boundaries:
      feature = self.bucketize(feature, boundaries=self.boundaries)
      tf.summary.histogram('bucketized_%s' % self.name, feature)
    if self.emb_dim > 0 and self.boundaries:
      vocab_size = len(self.boundaries) + 1
      one_hot_input_ids = tf.one_hot(feature, depth=vocab_size)
      return tf.matmul(one_hot_input_ids, self.embedding_table)
    return tf.expand_dims(feature, axis=-1)
148
+
149
+
150
class OverlapFeature(Layer):
  """Query/title overlap features computed by the custom `overlap_fg_op`.

  The op is asked for one value per configured method. When `boundaries` are
  set it is asked for int32 bucket ids, which can be embedded (one table block
  per method) when `embedding_dim > 0`.
  """

  def __init__(self, params, name='overlap_feature', reuse=None, **kwargs):
    super(OverlapFeature, self).__init__(name=name, **kwargs)
    self.overlap_feature = custom_ops.overlap_fg_op
    methods = params.get_or_default('methods', [])
    assert methods, 'overlap feature methods must be set'
    self.methods = [str(method) for method in methods]
    self.norm_fn = params.get_or_default('normalize_fn', None)
    self.boundaries = list(params.get_or_default('boundaries', []))
    self.separator = params.get_or_default('separator', '\035')
    self.default_value = params.get_or_default('default_value', '-1')
    self.emb_dim = params.get_or_default('embedding_dim', 0)
    self.print_first_n = params.get_or_default('print_first_n', 0)
    self.summarize = params.get_or_default('summarize', None)
    if self.emb_dim > 0:
      # (len(boundaries) + 1) bucket rows per method, stacked per method.
      vocab_size = len(self.boundaries) + 1
      vocab_size *= len(self.methods)
      with tf.variable_scope(self.name, reuse=reuse):
        self.embedding_table = tf.get_variable(
            name='overlap_emb_table',
            shape=[vocab_size, self.emb_dim],
            dtype=tf.float32)

  def call(self, inputs, training=None, **kwargs):
    """Compute overlap features for inputs = [query, title, ...].

    Returns:
      [batch_size, len(methods) * embedding_dim] embeddings when
      `embedding_dim > 0` and `boundaries` are set, otherwise the op output
      (optionally normalized).
    """
    query, title = inputs[:2]
    with ops.device('/CPU:0'):
      feature = self.overlap_feature(
          query=query,
          title=title,
          feature_name=self.name,
          separator=self.separator,
          default_value=self.default_value,
          boundaries=self.boundaries,
          methods=self.methods,
          dtype=tf.int32 if self.boundaries else tf.float32)

    for i, method in enumerate(self.methods):
      # NOTE: assumes column i of the op output corresponds to methods[i];
      # this is not verified here.
      if self.boundaries:
        tf.summary.histogram('bucketized_%s' % method, feature[:, i])
      else:
        tf.summary.scalar(method, tf.reduce_mean(feature[:, i]))
    if self.print_first_n:
      encode_q = tf.regex_replace(query, self.separator, ' ')
      # Print the title alongside the query (previously `query` was encoded
      # twice and the title was never shown).
      encode_t = tf.regex_replace(title, self.separator, ' ')
      feature = tf.Print(
          feature, [encode_q, encode_t, feature],
          message=self.name,
          first_n=self.print_first_n,
          summarize=self.summarize)
    if self.norm_fn is not None:
      # SECURITY: `normalize_fn` comes from the model config and is executed
      # with eval(); only load configs from trusted sources.
      fn = eval(self.norm_fn)
      feature = fn(feature)

    if self.emb_dim > 0 and self.boundaries:
      # The per-method vocabulary is small, so a one-hot matmul is used rather
      # than a gather (the original note says this is faster for small
      # vocabularies).
      batch_size = tf.shape(feature)[0]
      vocab_size = len(self.boundaries) + 1
      num_indices = len(self.methods)
      # Shift method i's bucket ids by i * vocab_size so every method indexes
      # its own block of the shared embedding table.
      offsets = tf.range(num_indices) * vocab_size  # [num_indices]
      offsets = tf.reshape(offsets, [1, num_indices])  # [1, num_indices]
      offsets = tf.tile(offsets,
                        [batch_size, 1])  # [batch_size, num_indices]
      shifted_indices = feature + offsets  # [batch_size, num_indices]
      flat_feature_ids = tf.reshape(shifted_indices, [-1])
      one_hot_ids = tf.one_hot(flat_feature_ids, depth=vocab_size * num_indices)
      feature_embeddings = tf.matmul(one_hot_ids, self.embedding_table)
      feature_embeddings = tf.reshape(feature_embeddings,
                                      [batch_size, num_indices * self.emb_dim])
      return feature_embeddings
    return feature
224
+
225
+
226
class EditDistance(Layer):
  """Embed the edit distance between two string inputs.

  The integer distance from the custom `my_edit_distance` op is clipped to
  [0, embedding_size - 1] and used as a row index into a trainable table.
  """

  def __init__(self, params, name='edit_distance', reuse=None, **kwargs):
    super(EditDistance, self).__init__(name=name, **kwargs)
    self.edit_distance = custom_ops.my_edit_distance
    self.txt_encoding = params.get_or_default('text_encoding', 'utf-8')
    self.emb_size = params.get_or_default('embedding_size', 512)
    table_dim = params.get_or_default('embedding_dim', 4)
    with tf.variable_scope(self.name, reuse=reuse):
      self.embedding_table = tf.get_variable('embedding_table',
                                             [self.emb_size, table_dim],
                                             tf.float32)

  def call(self, inputs, training=None, **kwargs):
    """Return the embedding of the clipped edit distance of inputs[0:2]."""
    left_text, right_text = inputs[:2]
    with ops.device('/CPU:0'):
      distance = self.edit_distance(
          left_text,
          right_text,
          normalize=False,
          dtype=tf.int32,
          encoding=self.txt_encoding)
    # Distances at or beyond the table size all share the last row.
    row_ids = tf.clip_by_value(distance, 0, self.emb_size - 1)
    return tf.nn.embedding_lookup(self.embedding_table, row_ids)
@@ -0,0 +1,133 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+ from tensorflow.python.keras.layers import Layer
5
+
6
+ from easy_rec.python.utils.shape_utils import get_shape_list
7
+
8
# Under TensorFlow 2.x, alias the v1 compatibility module as `tf` so the
# graph-mode APIs used below (tf.variable_scope, tf.get_variable, ...) resolve.
# NOTE: this is a plain string comparison; adequate for '1.x' vs '2.x'.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
10
+
11
+
12
def item_mask(aug_data, length, mask_emb, mask_rate):
  """Overwrite a random subset of the valid items with a mask embedding.

  Args:
    aug_data: padded item sequence whose first dim is the padded length
      (assumed [max_len, emb_dim] -- TODO confirm against callers).
    length: number of valid (non-padding) items at the front of aug_data.
    mask_emb: mask embedding with leading dim 1; tiled to every position.
    mask_rate: fraction of the valid items to mask (rounded down).

  Returns:
    (masked sequence, length) -- the length is returned unchanged.
  """
  n_masked = tf.cast(
      tf.math.floor(tf.cast(length, dtype=tf.float32) * mask_rate),
      dtype=tf.int32)
  padded_len = tf.shape(aug_data)[0]
  # `n_masked` True flags spread randomly over the first `length` slots.
  valid_flags = tf.random.shuffle(tf.sequence_mask(n_masked, length))
  # Padding slots are never masked (all-False tail).
  pad_flags = tf.sequence_mask(0, padded_len - length)
  should_mask = tf.concat([valid_flags, pad_flags], axis=0)

  tiled_mask = tf.tile(mask_emb, [padded_len, 1])
  return tf.where(should_mask, tiled_mask, aug_data), length
25
+
26
+
27
def item_crop(aug_data, length, crop_rate):
  """Keep a random contiguous window of the valid items, moved to the front.

  A window of ``floor(length * crop_rate)`` items starting at a random
  offset inside the first ``length`` rows of ``aug_data`` is copied to the
  start of the output; the remaining rows are zero-filled so the output keeps
  the padded length of ``aug_data``.

  Args:
    aug_data: padded item sequence, rank 2 ([max_len, emb_dim]).
    length: scalar int32, number of valid items at the front of aug_data.
    crop_rate: fraction of the valid items to keep (rounded down).

  Returns:
    (cropped sequence, number of kept items).
  """
  length1 = tf.cast(length, dtype=tf.float32)
  max_len, _ = get_shape_list(aug_data)
  max_length = tf.cast(max_len, dtype=tf.int32)

  num_left = tf.cast(tf.math.floor(length1 * crop_rate), dtype=tf.int32)
  # An int32 tf.random.uniform requires minval < maxval. For an empty
  # sequence (length == 0) the span `length - num_left` is 0, which raised
  # InvalidArgumentError; clamp it to 1 so crop_begin simply becomes 0.
  crop_begin = tf.random.uniform([],
                                 minval=0,
                                 maxval=tf.maximum(length - num_left, 1),
                                 dtype=tf.int32)
  zeros = tf.zeros_like(aug_data)
  x = aug_data[crop_begin:crop_begin + num_left]
  y = zeros[:max_length - num_left]
  cropped = tf.concat([x, y], axis=0)
  # If the window reaches max_length, take the tail of aug_data from
  # crop_begin and pad it with zeros instead.
  cropped_item_seq = tf.where(
      crop_begin + num_left < max_length, cropped,
      tf.concat([aug_data[crop_begin:], zeros[:crop_begin]], axis=0))
  return cropped_item_seq, num_left
45
+
46
+
47
def item_reorder(aug_data, length, reorder_rate):
  """Randomly permute a contiguous window of the valid items in place.

  A window of ``floor(length * reorder_rate)`` positions starting at a
  random offset inside the first ``length`` rows is shuffled; all other rows
  (including padding) keep their position.

  Args:
    aug_data: padded item sequence whose first dim is the padded length.
    length: scalar int32, number of valid items at the front of aug_data.
    reorder_rate: fraction of the valid items to shuffle (rounded down).

  Returns:
    (reordered sequence, length) -- the length is returned unchanged.
  """
  length1 = tf.cast(length, dtype=tf.float32)
  num_reorder = tf.cast(tf.math.floor(length1 * reorder_rate), dtype=tf.int32)
  # An int32 tf.random.uniform requires minval < maxval. The span
  # `length - num_reorder` is 0 for an empty sequence, and for every
  # sequence when reorder_rate >= 1; both used to raise InvalidArgumentError.
  # Clamping to 1 makes reorder_begin 0 in those cases.
  reorder_begin = tf.random.uniform([],
                                    minval=0,
                                    maxval=tf.maximum(length - num_reorder, 1),
                                    dtype=tf.int32)
  shuffle_index = tf.range(reorder_begin, reorder_begin + num_reorder)
  shuffle_index = tf.random.shuffle(shuffle_index)
  # Build a full permutation: identity outside the window, shuffled inside.
  x = tf.range(get_shape_list(aug_data)[0])
  left = tf.slice(x, [0], [reorder_begin])
  right = tf.slice(x, [reorder_begin + num_reorder], [-1])
  reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
  # Row i of aug_data is written to position reordered_item_index[i].
  reordered_item_seq = tf.scatter_nd(
      tf.expand_dims(reordered_item_index, axis=1), aug_data,
      tf.shape(aug_data))
  return reordered_item_seq, length
64
+
65
+
66
def augment_fn(x, aug_param, mask):
  """Apply one randomly chosen augmentation to a single sequence.

  Enabled transforms, in this order: crop when ``crop_rate < 1.0``, mask when
  ``mask_rate > 0``, reorder when ``reorder_rate > 0``. With none enabled the
  input is returned untouched; with exactly one it is always applied;
  otherwise one is drawn uniformly at random.

  Args:
    x: (seq, length) pair for one example.
    aug_param: config providing crop_rate, mask_rate and reorder_rate.
    mask: mask embedding forwarded to item_mask.

  Returns:
    (augmented seq, augmented length).
  """
  seq, length = x

  candidates = []
  if aug_param.crop_rate < 1.0:
    candidates.append(lambda: item_crop(seq, length, aug_param.crop_rate))
  if aug_param.mask_rate > 0:
    candidates.append(
        lambda: item_mask(seq, length, mask, aug_param.mask_rate))
  if aug_param.reorder_rate > 0:
    candidates.append(
        lambda: item_reorder(seq, length, aug_param.reorder_rate))

  if not candidates:
    return seq, length
  if len(candidates) == 1:
    return candidates[0]()

  choice = tf.random.uniform([],
                             minval=0,
                             maxval=len(candidates),
                             dtype=tf.int32)

  def pick(i):
    # Nested tf.cond chain; the last candidate is the final else-branch.
    if i == len(candidates) - 1:
      return candidates[i]()
    return tf.cond(tf.equal(choice, i), candidates[i], lambda: pick(i + 1))

  return pick(0)
101
+
102
+
103
def sequence_augment(seq_input, seq_len, mask, aug_param):
  """Augment each row of a batch of padded sequences independently.

  Args:
    seq_input: batch of padded sequence embeddings (assumed float32, since
      map_fn declares a float32 output -- TODO confirm).
    seq_len: per-row valid lengths; cast to int32.
    mask: mask embedding used by the mask transform.
    aug_param: augmentation config forwarded to augment_fn.

  Returns:
    (augmented batch reshaped to seq_input's dynamic shape, int32 lengths).
  """

  def _augment_row(row):
    return augment_fn(row, aug_param, mask)

  row_seqs, row_lens = tf.map_fn(
      _augment_row,
      elems=(seq_input, tf.cast(seq_len, dtype=tf.int32)),
      dtype=(tf.float32, tf.int32))
  return tf.reshape(row_seqs, tf.shape(seq_input)), row_lens
113
+
114
class SeqAugment(Layer):
  """Randomly augment a batch of sequence embeddings.

  Each sequence gets one of crop / mask / reorder, as configured by the
  layer's protobuf config (see sequence_augment / augment_fn).
  """

  def __init__(self, params, name='seq_aug', reuse=None, **kwargs):
    super(SeqAugment, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.seq_aug_params = params.get_pb_config()

  def call(self, inputs, training=None, **kwargs):
    """Return (augmented sequences, new lengths) for (seq_input, seq_len).

    NOTE: ``training`` is not consulted; augmentation is always applied.
    """
    assert isinstance(inputs, (list, tuple))
    seq_emb, seq_lengths = inputs[:2]

    emb_dim = int(seq_emb.shape[-1])
    with tf.variable_scope(self.name, reuse=self.reuse):
      # Learnable embedding that the mask transform writes over items.
      mask_vector = tf.get_variable(
          'mask', shape=[1, emb_dim], dtype=tf.float32, trainable=True)

    return sequence_augment(seq_emb, seq_lengths, mask_vector,
                            self.seq_aug_params)
@@ -0,0 +1,67 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+ from tensorflow.python.keras.layers import Layer
7
+
8
+ from easy_rec.python.layers.keras import MLP
9
+ from easy_rec.python.layers.utils import Parameter
10
+ from easy_rec.python.utils.shape_utils import get_shape_list
11
+
12
+
13
class DIN(Layer):
  """Deep Interest Network style attention pooling over a sequence.

  An MLP ('din_attention') scores every sequence item against the target
  (query) item from [q, k, q - k, q * k]. Scores at padded positions are
  masked out, normalized with softmax or sigmoid, and used to take a weighted
  sum of the sequence embeddings.
  """

  def __init__(self, params, name='din', reuse=None, **kwargs):
    super(DIN, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.l2_reg = params.l2_regularizer
    self.config = params.get_pb_config()
    # The attention MLP must emit one raw score per item ([B, L, 1]):
    # no final batch-norm, keep the bias, linear final activation.
    self.config.attention_dnn.use_final_bn = False
    self.config.attention_dnn.use_final_bias = True
    self.config.attention_dnn.final_activation = 'linear'
    mlp_params = Parameter.make_from_pb(self.config.attention_dnn)
    mlp_params.l2_regularizer = self.l2_reg
    self.din_layer = MLP(mlp_params, 'din_attention', reuse=self.reuse)

  def call(self, inputs, training=None, **kwargs):
    """Attend over the sequence ``keys`` with the target ``query``.

    Args:
      inputs: (keys, seq_len, query) with keys [B, L, E_seq], seq_len [B]
        and query [B, E_q], where E_q <= E_seq.
      training: forwarded to the attention MLP.

    Returns:
      [B, E_q] attention-pooled sequence embedding. When
      ``need_target_feature`` is set, it is concatenated with the query
      (zero-padded to E_seq if E_q < E_seq).
    """
    keys, seq_len, query = inputs
    assert query is not None, '[%s] target feature is empty' % self.name
    query_emb_size = int(query.shape[-1])
    seq_emb_size = keys.shape.as_list()[-1]
    if query_emb_size != seq_emb_size:
      logging.info(
          '<din> the embedding size of sequence [%d] and target item [%d] is not equal'
          ' in feature group: %s', seq_emb_size, query_emb_size, self.name)
      if query_emb_size < seq_emb_size:
        # Zero-pad the query so it combines element-wise with the keys.
        query = tf.pad(query, [[0, 0], [0, seq_emb_size - query_emb_size]])
      else:
        assert False, 'the embedding size of target item is larger than the one of sequence'

    _, max_seq_len, _ = get_shape_list(keys, 3)
    queries = tf.tile(tf.expand_dims(query, 1), [1, max_seq_len, 1])
    din_all = tf.concat([queries, keys, queries - keys, queries * keys],
                        axis=-1)
    output = self.din_layer(din_all, training)  # [B, L, 1]
    scores = tf.transpose(output, [0, 2, 1])  # [B, 1, L]

    # Push scores at padded positions to a huge negative value so they get
    # ~0 weight after softmax / sigmoid.
    seq_mask = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.bool)
    seq_mask = tf.expand_dims(seq_mask, 1)
    paddings = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(seq_mask, scores, paddings)  # [B, 1, L]
    if self.config.attention_normalizer == 'softmax':
      scores = tf.nn.softmax(scores)  # (B, 1, L)
    elif self.config.attention_normalizer == 'sigmoid':
      scores = scores / (seq_emb_size**0.5)
      scores = tf.nn.sigmoid(scores)
    else:
      raise ValueError('unsupported attention normalizer: ' +
                       self.config.attention_normalizer)

    if query_emb_size < seq_emb_size:
      # Pool only the leading E_q dims so the output matches the query size.
      keys = keys[:, :, :query_emb_size]  # [B, L, E]
    output = tf.squeeze(tf.matmul(scores, keys), axis=[1])
    if self.config.need_target_feature:
      output = tf.concat([output, query], axis=-1)
    # Was a leftover debug print() to stdout; route it through logging.
    logging.debug('<din> output shape: %s', output.shape)
    return output