easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,73 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ if tf.__version__ >= '2.0':
6
+ tf = tf.compat.v1
7
+
8
+
9
class SENet:
  """Squeeze and Excite Network.

  Input shape
    - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
      The ``embedding_size`` of each field can have different value, but each
      must be divisible by ``num_squeeze_group``.

  Args:
    num_fields: int, number of fields (used to size the squeeze layer).
    num_squeeze_group: int, number of groups each field embedding is split
      into for squeezing.
    reduction_ratio: int, reduction ratio for squeeze.
    l2_reg: passed unchanged as ``kernel_regularizer`` to the squeeze
      ("reduce") dense layer; presumably a regularizer callable or None
      (NOTE(review): confirm with callers).
    name: str, name of the layer (prefix of the dense layer names).
  """

  def __init__(self,
               num_fields,
               num_squeeze_group,
               reduction_ratio,
               l2_reg,
               name='SENet'):
    self.num_fields = num_fields
    self.num_squeeze_group = num_squeeze_group
    self.reduction_ratio = reduction_ratio
    self._l2_reg = l2_reg
    self._name = name

  def __call__(self, inputs):
    """Re-weight the concatenated field embeddings.

    Args:
      inputs: list of 2D tensors ``[batch_size, dim_i]``.

    Returns:
      Tensor ``[batch_size, sum(dim_i)]``: the inputs concatenated on the last
      axis and multiplied element-wise by the learned excitation weights.

    Raises:
      ValueError: if some ``dim_i`` is not divisible by ``num_squeeze_group``.
    """
    g = self.num_squeeze_group
    f = self.num_fields
    r = self.reduction_ratio
    # Each field contributes 2 * g squeezed statistics (max and mean per group).
    reduction_size = max(1, f * g * 2 // r)

    emb_dims = [int(emb.shape[-1]) for emb in inputs]
    # Fail early with a readable message instead of an opaque reshape error.
    for idx, dim in enumerate(emb_dims):
      if dim % g != 0:
        raise ValueError(
            '%s: embedding size %d of input %d is not divisible by '
            'num_squeeze_group=%d' % (self._name, dim, idx, g))
    emb_size = sum(emb_dims)

    # [B, dim] -> [B, g, dim // g]
    group_embs = [
        tf.reshape(emb, [-1, g, dim // g]) for emb, dim in zip(inputs, emb_dims)
    ]

    squeezed = []
    for emb in group_embs:
      squeezed.append(tf.reduce_max(emb, axis=-1))  # [B, g]
      squeezed.append(tf.reduce_mean(emb, axis=-1))  # [B, g]
    z = tf.concat(squeezed, axis=1)  # [B, len(inputs) * g * 2]

    reduced = tf.layers.dense(
        inputs=z,
        units=reduction_size,
        kernel_regularizer=self._l2_reg,
        activation='relu',
        name='%s/reduce' % self._name)

    # One (linear) weight per embedding element of the concatenated input.
    excited_weights = tf.layers.dense(
        inputs=reduced,
        units=emb_size,
        kernel_initializer='glorot_normal',
        name='%s/excite' % self._name)

    # Re-weight
    concat_inputs = tf.concat(inputs, axis=-1)
    output = concat_inputs * excited_weights

    return output
@@ -0,0 +1,134 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.python.framework import ops
8
+ from tensorflow.python.ops import variable_scope
9
+
10
+ from easy_rec.python.compat import regularizers
11
+ from easy_rec.python.compat.feature_column import feature_column
12
+ from easy_rec.python.feature_column.feature_column import FeatureColumnParser
13
+ from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
14
+
15
+ if tf.__version__ >= '2.0':
16
+ tf = tf.compat.v1
17
+
18
+
19
class SeqInputLayer(object):
  """Builds the key / history-sequence tensors of sequence feature groups.

  Each group config is looked up by ``group_name``; its ``seq_att_map`` entries
  list the key features, history-sequence features and auxiliary history
  features whose tensors are produced by ``__call__``.
  """

  def __init__(self,
               feature_configs,
               feature_groups_config,
               embedding_regularizer=None,
               ev_params=None):
    self._feature_groups_config = {
        x.group_name: x for x in feature_groups_config
    }
    wide_and_deep_dict = self.get_wide_deep_dict()
    self._fc_parser = FeatureColumnParser(
        feature_configs, wide_and_deep_dict, ev_params=ev_params)
    self._embedding_regularizer = embedding_regularizer

  @staticmethod
  def _seq_embed_summary_name(input_name):
    """Summary name from a tensor name: 'a/b/c:0' -> 'sequence_feature/a/b'."""
    input_name = input_name.split(':')[0]
    input_name = input_name.split('/')[:2]
    return 'sequence_feature/' + '/'.join(input_name)

  def __call__(self,
               features,
               group_name,
               feature_name_to_output_tensors=None,
               allow_key_search=True,
               scope_name=None):
    """Build the attention inputs of one sequence feature group.

    Args:
      features: input features, wrapped in ``feature_column._LazyBuilder``.
      group_name: name of the sequence feature group to build.
      feature_name_to_output_tensors: optional mapping from key feature name to
        an already-built tensor, reused instead of a fresh lookup. ``None``
        (the default) means an empty mapping.
      allow_key_search: when a key is mapped to ``None``, look it up from the
        feature columns if True, otherwise raise.
      scope_name: variable / name scope; defaults to ``group_name``.

    Returns:
      dict with
        'key': key tensors concatenated on the last axis,
        'hist_seq_emb': history sequence embeddings concatenated on the last
          axis,
        'hist_seq_len': lengths returned for the first history sequence,
        'aux_hist_seq_emb_list': list of auxiliary history embeddings of all
          ``seq_att_map`` entries.

    Raises:
      ValueError: if a key is mapped to ``None`` while ``allow_key_search`` is
        False.
    """
    if feature_name_to_output_tensors is None:
      feature_name_to_output_tensors = {}
    # Copy so that the parser's deep_columns dict is not mutated on each call.
    feature_column_dict = dict(self._fc_parser.deep_columns)
    feature_column_dict.update(self._fc_parser.sequence_columns)

    builder = feature_column._LazyBuilder(features)

    feature_dict = self._feature_groups_config[group_name]
    tf_summary = feature_dict.tf_summary
    if tf_summary:
      logging.info('Write sequence feature to tensorflow summary.')

    if scope_name is None:
      scope_name = group_name
    # name_scope is specified to avoid adding _1 _2 after scope_name
    with variable_scope.variable_scope(
        scope_name,
        reuse=variable_scope.AUTO_REUSE), ops.name_scope(scope_name + '/'):
      key_tensors = []
      hist_tensors = []
      # Accumulated over all seq_att_map entries, like keys and hists.
      aux_hist_emb_list = []
      check_op_list = []
      for x in feature_dict.seq_att_map:
        cur_key_tensors = []
        for key in x.key:
          if key not in feature_name_to_output_tensors or (
              feature_name_to_output_tensors[key] is None and allow_key_search):
            qfc = feature_column_dict[key]
            with variable_scope.variable_scope(qfc._var_scope_name):
              tmp_key_tensor = qfc._get_dense_tensor(builder)
            regularizers.apply_regularization(
                self._embedding_regularizer, weights_list=[tmp_key_tensor])
            cur_key_tensors.append(tmp_key_tensor)
          elif feature_name_to_output_tensors[key] is None:
            # Explicit raise: an `assert` here would be stripped under -O and
            # the key would then be silently dropped.
            raise ValueError(
                'When allow_key_search is False, key: %s should defined in same feature group.'
                % key)
          else:
            cur_key_tensors.append(feature_name_to_output_tensors[key])
        key_tensors.extend(cur_key_tensors)

        # Summaries cover only this entry's tensors, so none is written twice.
        if tf_summary:
          for key_tensor in cur_key_tensors:
            tf.summary.histogram(
                self._seq_embed_summary_name(key_tensor.name), key_tensor)

        # Each element is an (embedding, length) pair.
        cur_hist_seqs = []
        for hist_seq in x.hist_seq:
          seq_fc = feature_column_dict[hist_seq]
          with variable_scope.variable_scope(seq_fc._var_scope_name):
            cur_hist_seqs.append(seq_fc._get_sequence_dense_tensor(builder))
        hist_tensors.extend(cur_hist_seqs)

        for aux_hist_seq in x.aux_hist_seq:
          seq_fc = feature_column_dict[aux_hist_seq]
          with variable_scope.variable_scope(seq_fc._var_scope_name):
            aux_hist_embedding, _ = seq_fc._get_sequence_dense_tensor(builder)
          aux_hist_emb_list.append(aux_hist_embedding)

        if tf_summary:
          for hist_embed, hist_seq_len in cur_hist_seqs:
            tf.summary.histogram(
                self._seq_embed_summary_name(hist_embed.name), hist_embed)
            tf.summary.histogram(
                self._seq_embed_summary_name(hist_seq_len.name), hist_seq_len)

        # All history sequences of one entry must have the same lengths.
        for idx in range(1, len(cur_hist_seqs)):
          check_op = tf.assert_equal(
              cur_hist_seqs[0][1],
              cur_hist_seqs[idx][1],
              message='SequenceFeature Error: The size of %s not equal to the size of %s.'
              % (x.hist_seq[idx], x.hist_seq[0]))
          check_op_list.append(check_op)

      with tf.control_dependencies(check_op_list):
        features = {
            'key': tf.concat(key_tensors, axis=-1),
            'hist_seq_emb': tf.concat([x[0] for x in hist_tensors], axis=-1),
            # identity creates a new op, so the length checks also gate this
            # output (control_dependencies only affects ops created inside).
            'hist_seq_len': tf.identity(hist_tensors[0][1]),
            'aux_hist_seq_emb_list': aux_hist_emb_list
        }
    return features

  def get_wide_deep_dict(self):
    """Mark every key and hist_seq feature of every group as DEEP."""
    wide_and_deep_dict = {}
    for group_name_config in self._feature_groups_config.values():
      for x in group_name_config.seq_att_map:
        for key in x.key:
          wide_and_deep_dict[key] = WideOrDeep.DEEP
        for hist_seq in x.hist_seq:
          wide_and_deep_dict[hist_seq] = WideOrDeep.DEEP
    return wide_and_deep_dict
@@ -0,0 +1,249 @@
1
+ import logging
2
+ import os
3
+
4
+ import tensorflow as tf
5
+ from tensorflow.python.framework import ops
6
+
7
+ from easy_rec.python.compat import regularizers
8
+ from easy_rec.python.layers import dnn
9
+ from easy_rec.python.layers import seq_input_layer
10
+ from easy_rec.python.utils import conditional
11
+
12
+ if tf.__version__ >= '2.0':
13
+ tf = tf.compat.v1
14
+
15
+
16
class SequenceFeatureLayer(object):
  """DIN-style target attention over sequence feature groups.

  For every configured sequence attention group, the key and history tensors
  are built by ``SeqInputLayer``. The history is then pooled with attention
  scores produced by a DNN over [key, hist, key - hist, key * hist].
  """

  def __init__(self,
               feature_configs,
               feature_groups_config,
               ev_params=None,
               embedding_regularizer=None,
               kernel_regularizer=None,
               is_training=False,
               is_predicting=False):
    # Collect the sequence_features sub-configs of every feature group.
    self._seq_feature_groups_config = []
    for x in feature_groups_config:
      for y in x.sequence_features:
        self._seq_feature_groups_config.append(y)
    self._seq_input_layer = None
    if len(self._seq_feature_groups_config) > 0:
      self._seq_input_layer = seq_input_layer.SeqInputLayer(
          feature_configs,
          self._seq_feature_groups_config,
          embedding_regularizer=embedding_regularizer,
          ev_params=ev_params)
    self._embedding_regularizer = embedding_regularizer
    self._kernel_regularizer = kernel_regularizer
    self._is_training = is_training
    self._is_predicting = is_predicting

  @staticmethod
  def _env_flag(name):
    """Read a flag from environment variable ``name``.

    Unset or empty -> False. Otherwise the value is parsed as a Python literal
    (e.g. ``True``, ``False``, ``1``, ``0``) with ``ast.literal_eval`` instead
    of ``eval``, so an environment value cannot execute code. A non-literal
    value raises ``ValueError``.
    """
    import ast
    value = os.getenv(name)
    if not value:
      return False
    return ast.literal_eval(value.strip())

  def negative_sampler_target_attention(self,
                                        dnn_config,
                                        deep_fea,
                                        concat_features,
                                        name,
                                        need_key_feature=True,
                                        allow_key_transform=False):
    """Target attention when sampled negative keys follow the positive keys.

    The first ``B`` rows of ``deep_fea['key']`` (B = history batch size) are
    used as the positives. The remaining rows are treated as negatives shared
    by every user (layout inferred from the slicing below; NOTE(review): confirm
    with the sampler). Every history is scored against ``N = 1 + num_neg``
    candidates.

    Returns:
      (din_output, concat_features): ``din_output`` is ``[B, N, D]`` and
      ``concat_features`` is tiled to ``[B, N, F]``.
    """
    cur_id = deep_fea['key']
    hist_id_col = deep_fea['hist_seq_emb']
    seq_len = deep_fea['hist_seq_len']
    aux_hist_emb_list = deep_fea['aux_hist_seq_emb_list']

    seq_max_len = tf.shape(hist_id_col)[1]
    seq_emb_dim = hist_id_col.shape[2]
    batch_size = tf.shape(hist_id_col)[0]

    pos_feature = cur_id[:batch_size]
    neg_feature = cur_id[batch_size:]
    # [B, N, key_dim]: row b = its own positive followed by all negatives.
    cur_id = tf.concat([
        pos_feature[:, tf.newaxis, :],
        tf.tile(neg_feature[tf.newaxis, :, :], multiples=[batch_size, 1, 1])
    ],
                       axis=1)  # noqa: E126
    neg_num_add_1 = tf.shape(cur_id)[1]
    # Static last dim: comparing a dynamic tf.shape value with `!=` in graph
    # mode is object inequality (always True), unlike target_attention.
    cur_id_dim = cur_id.shape[-1]

    def _expand_per_candidate(hist):
      # [B, L, d] -> [B * N, L, d]; row b * N + n holds hist[b].
      # Assumes `hist` has the same sequence length L as hist_seq_emb.
      hist_dim = hist.shape[-1]
      tiled = tf.tile(hist, multiples=[1, neg_num_add_1, 1])
      return tf.reshape(tiled,
                        [batch_size * neg_num_add_1, seq_max_len, hist_dim])

    hist_id_col = _expand_per_candidate(hist_id_col)

    concat_features = tf.tile(
        concat_features[:, tf.newaxis, :], multiples=[1, neg_num_add_1, 1])
    # Repeat every length N times consecutively so it lines up with the
    # batch-major row order b * N + n used above; tiling the 1-D tensor as a
    # whole would give [s0..sB-1, s0..sB-1, ...] and misalign the masks.
    seq_len = tf.reshape(
        tf.tile(seq_len[:, tf.newaxis], multiples=[1, neg_num_add_1]), [-1])

    if allow_key_transform and (cur_id_dim != seq_emb_dim):
      cur_id = tf.layers.dense(
          cur_id, seq_emb_dim, name='sequence_key_transform_layer')

    cur_ids = tf.tile(cur_id, [1, 1, seq_max_len])
    cur_ids = tf.reshape(
        cur_ids,
        tf.shape(hist_id_col))  # (B * neg_num_add_1, seq_max_len, seq_emb_dim)

    din_net = tf.concat(
        [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
        axis=-1)  # (B * neg_num_add_1, seq_max_len, seq_emb_dim*4)

    din_layer = dnn.DNN(
        dnn_config,
        self._kernel_regularizer,
        name,
        self._is_training,
        last_layer_no_activation=True,
        last_layer_no_batch_norm=True)
    din_net = din_layer(din_net)
    scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B * N, 1, L)

    seq_len = tf.expand_dims(seq_len, 1)
    # maxlen keeps the mask shape equal to the scores shape.
    mask = tf.sequence_mask(seq_len, maxlen=seq_max_len)
    padding = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(mask, scores,
                      padding)  # [B*neg_num_add_1, 1, seq_max_len]

    # Scale
    scores = tf.nn.softmax(scores)  # (B * neg_num_add_1, 1, seq_max_len)
    hist_din_emb = tf.matmul(scores,
                             hist_id_col)  # [B * neg_num_add_1, 1, seq_emb_dim]
    hist_din_emb = tf.reshape(
        hist_din_emb,
        [batch_size, neg_num_add_1, seq_emb_dim])  # [B, neg_num_add_1, seq_emb_dim]
    if len(aux_hist_emb_list) > 0:
      all_hist_dim_emb = [hist_din_emb]
      for hist_col in aux_hist_emb_list:
        aux_hist_dim = hist_col.shape[-1]
        # Aux histories need the same per-candidate expansion as hist_id_col
        # so that the batch dims match the [B * N, 1, L] scores.
        cur_aux_hist = tf.matmul(scores, _expand_per_candidate(hist_col))
        outputs = tf.reshape(cur_aux_hist,
                             [batch_size, neg_num_add_1, aux_hist_dim])
        all_hist_dim_emb.append(outputs)
      hist_din_emb = tf.concat(all_hist_dim_emb, axis=2)
    if not need_key_feature:
      return hist_din_emb, concat_features
    din_output = tf.concat([hist_din_emb, cur_id], axis=2)
    return din_output, concat_features

  def target_attention(self,
                       dnn_config,
                       deep_fea,
                       name,
                       need_key_feature=True,
                       allow_key_transform=False,
                       transform_dnn=False):
    """Pool the history sequence with attention against the key.

    Args:
      dnn_config: DNN config producing one attention logit per position.
      deep_fea: dict with 'key', 'hist_seq_emb', 'hist_seq_len' and
        'aux_hist_seq_emb_list' (as produced by SeqInputLayer).
      name: name of the attention DNN and suffix of the transform layers.
      need_key_feature: also append the key to the output if True.
      allow_key_transform: align the key dim with the history dim when they
        differ, by zero padding or by dense projections.
      transform_dnn: force the dense projections even when padding would do.

    Returns:
      ``[B, D]`` tensor: pooled history (plus pooled aux histories and the key
      when requested) concatenated on axis 1.
    """
    cur_id = deep_fea['key']
    hist_id_col = deep_fea['hist_seq_emb']
    seq_len = deep_fea['hist_seq_len']
    aux_hist_emb_list = deep_fea['aux_hist_seq_emb_list']

    seq_max_len = tf.shape(hist_id_col)[1]
    seq_emb_dim = hist_id_col.shape[2]
    cur_id_dim = cur_id.shape[-1]

    # The key may carry extra negative-sampler rows after the B rows that have
    # a history; keep only the first B. Done before the optional transform so
    # both branches get a key batch that matches the history batch.
    cur_id = cur_id[:tf.shape(hist_id_col)[0], ...]

    if allow_key_transform and (cur_id_dim != seq_emb_dim):
      if seq_emb_dim > cur_id_dim and not transform_dnn:
        cur_id = tf.pad(cur_id, [[0, 0], [0, seq_emb_dim - cur_id_dim]])
      else:
        cur_key_layer_name = 'sequence_key_transform_layer_' + name
        cur_id = tf.layers.dense(cur_id, seq_emb_dim, name=cur_key_layer_name)
        cur_fea_layer_name = 'sequence_fea_transform_layer_' + name
        hist_id_col = tf.layers.dense(
            hist_id_col, seq_emb_dim, name=cur_fea_layer_name)

    cur_ids = tf.tile(cur_id, [1, seq_max_len])
    cur_ids = tf.reshape(cur_ids,
                         tf.shape(hist_id_col))  # (B, seq_max_len, seq_emb_dim)

    din_net = tf.concat(
        [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
        axis=-1)  # (B, seq_max_len, seq_emb_dim*4)

    din_layer = dnn.DNN(
        dnn_config,
        self._kernel_regularizer,
        name,
        self._is_training,
        last_layer_no_activation=True,
        last_layer_no_batch_norm=True)
    din_net = din_layer(din_net)
    scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B, 1, ?)

    seq_len = tf.expand_dims(seq_len, 1)
    # maxlen keeps the mask shape equal to the scores shape.
    mask = tf.sequence_mask(seq_len, maxlen=seq_max_len)
    padding = tf.ones_like(scores) * (-2**32 + 1)
    scores = tf.where(mask, scores, padding)  # [B, 1, seq_max_len]

    # Scale
    scores = tf.nn.softmax(scores)  # (B, 1, seq_max_len)
    hist_din_emb = tf.matmul(scores, hist_id_col)  # [B, 1, seq_emb_dim]
    hist_din_emb = tf.reshape(hist_din_emb,
                              [-1, seq_emb_dim])  # [B, seq_emb_dim]
    if len(aux_hist_emb_list) > 0:
      all_hist_dim_emb = [hist_din_emb]
      for hist_col in aux_hist_emb_list:
        aux_hist_dim = hist_col.shape[-1]
        cur_aux_hist = tf.matmul(scores, hist_col)
        outputs = tf.reshape(cur_aux_hist, [-1, aux_hist_dim])
        all_hist_dim_emb.append(outputs)
      hist_din_emb = tf.concat(all_hist_dim_emb, axis=1)
    if not need_key_feature:
      return hist_din_emb
    din_output = tf.concat([hist_din_emb, cur_id], axis=1)
    return din_output

  def __call__(self,
               features,
               concat_features,
               all_seq_att_map_config,
               feature_name_to_output_tensors=None,
               negative_sampler=False,
               scope_name=None):
    """Run target attention for every sequence attention group config.

    Returns:
      (concat_features, all_seq_fea): ``concat_features`` is tiled per
      candidate when ``negative_sampler`` is True and is otherwise returned
      unchanged. ``all_seq_fea`` holds one attention output per group config.
    """
    logging.info('use sequence feature layer.')
    if feature_name_to_output_tensors is None:
      # SeqInputLayer does `key in mapping`, which fails on None.
      feature_name_to_output_tensors = {}
    # Loop-invariant; parsed safely instead of eval() on the raw env value.
    place_on_cpu = self._env_flag('place_embedding_on_cpu')
    all_seq_fea = []
    # process all sequence features
    for seq_att_map_config in all_seq_att_map_config:
      group_name = seq_att_map_config.group_name
      allow_key_search = seq_att_map_config.allow_key_search
      need_key_feature = seq_att_map_config.need_key_feature
      allow_key_transform = seq_att_map_config.allow_key_transform
      transform_dnn = seq_att_map_config.transform_dnn

      with conditional(self._is_predicting and place_on_cpu,
                       ops.device('/CPU:0')):
        seq_features = self._seq_input_layer(features, group_name,
                                             feature_name_to_output_tensors,
                                             allow_key_search, scope_name)

      # apply regularization for sequence feature key in seq_input_layer.

      regularizers.apply_regularization(
          self._embedding_regularizer,
          weights_list=[seq_features['hist_seq_emb']])
      seq_dnn_config = None
      if seq_att_map_config.HasField('seq_dnn'):
        seq_dnn_config = seq_att_map_config.seq_dnn
      else:
        logging.info(
            'seq_dnn not set in seq_att_groups, will use default settings')
        # If not set seq_dnn, will use default settings
        from easy_rec.python.protos.dnn_pb2 import DNN
        seq_dnn_config = DNN()
        seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
      cur_target_attention_name = 'seq_dnn' + group_name
      if negative_sampler:
        seq_fea, concat_features = self.negative_sampler_target_attention(
            seq_dnn_config,
            seq_features,
            concat_features,
            name=cur_target_attention_name,
            need_key_feature=need_key_feature,
            allow_key_transform=allow_key_transform)
      else:
        seq_fea = self.target_attention(
            seq_dnn_config,
            seq_features,
            name=cur_target_attention_name,
            need_key_feature=need_key_feature,
            allow_key_transform=allow_key_transform,
            transform_dnn=transform_dnn)
      all_seq_fea.append(seq_fea)
    return concat_features, all_seq_fea