easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,251 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import itertools
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.python.keras.layers import Dense
8
+ from tensorflow.python.keras.layers import Layer
9
+
10
+ from easy_rec.python.layers.keras.blocks import MLP
11
+ from easy_rec.python.layers.keras.layer_norm import LayerNormalization
12
+ from easy_rec.python.layers.utils import Parameter
13
+
14
+
15
class SENet(Layer):
  """SENET Layer used in FiBiNET.

  Each field embedding is split into ``num_squeeze_group`` groups; the max and
  mean of every group are concatenated into a squeeze vector, which is passed
  through a two-layer excitation network whose output re-weights the
  concatenated field embeddings.

  Input shape
    - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
      The ``embedding_size`` of each field can have different value.

  Output shape
    - A 2D tensor with shape: ``(batch_size,sum_of_embedding_size)``.

  References:
    1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
      Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='SENet', reuse=None, **kwargs):
    """Store the layer config and create the output layer-normalization.

    Args:
      params: parameter object exposing ``get_pb_config()``; the returned
        config is read for ``num_squeeze_group``, ``reduction_ratio``,
        ``use_skip_connection`` and ``use_output_layer_norm``.
      name: name of the layer.
      reuse: stored on the instance as ``self.reuse``; not used by this class.
      **kwargs: forwarded to the base ``Layer`` constructor.
    """
    super(SENet, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.reuse = reuse
    # Compare the major version numerically: a plain string comparison such as
    # ``tf.__version__ >= '2.0'`` is lexicographic and would misorder '10.0'.
    tf_major_version = int(tf.__version__.split('.')[0])
    if tf_major_version >= 2:
      self.layer_norm = tf.keras.layers.LayerNormalization(name='output_ln')
    else:
      self.layer_norm = LayerNormalization(name='output_ln')

  def build(self, input_shape):
    """Validate the field shapes and create the excitation layers.

    Args:
      input_shape: one shape per field; every shape must be rank 2 and its
        last dimension must be a positive multiple of ``num_squeeze_group``.
    """
    g = self.config.num_squeeze_group
    emb_size = 0
    for shape in input_shape:
      assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
      dim = int(shape[-1])
      assert dim >= g and dim % g == 0, 'field embedding dimension %d must be divisible by %d' % (
          dim, g)
      emb_size += dim

    r = self.config.reduction_ratio
    field_size = len(input_shape)
    # The squeeze vector has ``field_size * g * 2`` entries (max and mean per
    # group); the bottleneck shrinks it by ``r`` but never below one unit.
    reduction_size = max(1, field_size * g * 2 // r)
    self.reduce_layer = Dense(
        units=reduction_size,
        activation='relu',
        kernel_initializer='he_normal',
        name='W1')
    # One excitation weight per element of the concatenated embeddings.
    self.excite_layer = Dense(
        units=emb_size, kernel_initializer='glorot_normal', name='W2')
    super(SENet, self).build(input_shape)  # Be sure to call this somewhere!

  def call(self, inputs, **kwargs):
    """Re-weight the concatenated field embeddings.

    Args:
      inputs: list of 2D field embedding tensors.
      **kwargs: unused.

    Returns:
      The concatenated embeddings multiplied by the excitation weights, with
      the optional skip connection and layer normalization applied.
    """
    g = self.config.num_squeeze_group

    # Squeeze
    # the embedding dimension must be divisible by g
    group_embs = [
        tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs
    ]

    squeezed = []
    for emb in group_embs:
      squeezed.append(tf.reduce_max(emb, axis=-1))  # [B, g]
      squeezed.append(tf.reduce_mean(emb, axis=-1))  # [B, g]
    z = tf.concat(squeezed, axis=1)  # [bs, field_size * num_groups * 2]

    # Excitation
    a1 = self.reduce_layer(z)
    weights = self.excite_layer(a1)

    # Re-weight
    inputs = tf.concat(inputs, axis=-1)
    output = inputs * weights

    # Fuse, add skip-connection
    if self.config.use_skip_connection:
      output += inputs

    # Layer Normalization
    if self.config.use_output_layer_norm:
      output = self.layer_norm(output)
    return output
94
+
95
+
96
def _full_interaction(v_i, v_j):
  """Compute the per-example inner product of two embeddings with a batched matmul.

  Args:
    v_i: left operand; expanded with a new axis at position 1.
    v_j: right operand; expanded with a new trailing axis.

  Returns:
    The matmul result with axis 1 squeezed out.
  """
  # For [bs, dim] inputs: [bs, 1, dim] x [bs, dim, 1] = [bs, 1, 1],
  # and squeezing axis 1 leaves [bs, 1].
  row_vector = tf.expand_dims(v_i, axis=1)
  column_vector = tf.expand_dims(v_j, axis=-1)
  product = tf.matmul(row_vector, column_vector)
  return tf.squeeze(product, axis=1)
101
+
102
+
103
class BiLinear(Layer):
  """BilinearInteraction Layer used in FiBiNET.

  Input shape
    - A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
      Its length is ``field_size``.
      The ``embedding_size`` of each field can have different value.

  Output shape
    - 2D tensor with shape: ``(batch_size,output_size)``.

  Attributes:
    num_output_units: the number of output units
    type: ['all', 'each', 'interaction'], types of bilinear functions used in this layer
    use_plus: whether to use bi-linear+

  References:
    1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
      Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='bilinear', reuse=None, **kwargs):
    """Read the layer parameters and create the output projection.

    Args:
      params: parameter object; must provide ``num_output_units`` and may
        provide ``use_plus`` (default True) and ``type`` (default
        'interaction').
      name: name of the layer.
      reuse: stored on the instance as ``self.reuse``; not used by this class.
      **kwargs: forwarded to the base ``Layer`` constructor.

    Raises:
      NotImplementedError: if ``type`` is not 'all', 'each' or 'interaction'.
    """
    super(BiLinear, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    params.check_required(['num_output_units'])
    bilinear_plus = params.get_or_default('use_plus', True)
    self.output_size = params.num_output_units
    self.bilinear_type = params.get_or_default('type', 'interaction').lower()
    if self.bilinear_type not in ['all', 'each', 'interaction']:
      raise NotImplementedError(
          "bilinear_type only support: ['all', 'each', 'interaction']")
    # bi-linear+ combines a pair with an inner product (one value per pair);
    # plain bi-linear uses an element-wise product.
    if bilinear_plus:
      self.func = _full_interaction
    else:
      self.func = tf.multiply
    self.output_layer = Dense(self.output_size, name='output')

  def build(self, input_shape):
    """Validate the field shapes and create the bilinear projection layers.

    Args:
      input_shape: list or tuple with one rank-2 shape per field.

    Raises:
      TypeError: if ``input_shape`` is not a list or tuple.
      ValueError: if the embedding dimensions differ and the bilinear type is
        not 'interaction'.
    """
    if not isinstance(input_shape, (tuple, list)):
      raise TypeError('input of BiLinear layer must be a list')
    field_num = len(input_shape)
    logging.info('Bilinear Layer with %d inputs', field_num)
    if field_num > 200:
      logging.warning('Too many inputs for bilinear layer: %d', field_num)
    equal_dim = True
    _dim = input_shape[0][-1]
    for shape in input_shape:
      assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
      if shape[-1] != _dim:
        equal_dim = False
    if not equal_dim and self.bilinear_type != 'interaction':
      raise ValueError(
          'all embedding dimensions must be same when not use bilinear type: interaction'
      )
    dim = int(_dim)

    if self.bilinear_type == 'all':
      # A single projection shared by every field.
      self.dot_layer = Dense(dim, name='all')
    elif self.bilinear_type == 'each':
      # One projection per left-hand field; the last field is never a left
      # operand, hence ``field_num - 1`` layers.
      self.dot_layers = [
          Dense(dim, name='each_%d' % i) for i in range(field_num - 1)
      ]
    else:  # interaction
      # One projection per field pair, stored in ``itertools.combinations``
      # order; ``call`` must index this list by pair position.
      self.dot_layers = [
          Dense(
              units=int(input_shape[j][-1]), name='interaction_%d_%d' % (i, j))
          for i, j in itertools.combinations(range(field_num), 2)
      ]
    super(BiLinear, self).build(input_shape)  # Be sure to call this somewhere!

  def call(self, inputs, **kwargs):
    """Compute the pairwise bilinear interactions and project them.

    Args:
      inputs: list of 2D field embedding tensors.
      **kwargs: unused.

    Returns:
      The output of the final Dense layer applied to the concatenated
      pairwise interactions.
    """
    embeddings = inputs
    field_num = len(embeddings)
    pairs = list(itertools.combinations(range(field_num), 2))

    # bi-linear+: dimension of `p` is [bs, f*(f-1)/2]
    # bi-linear:
    # - when equal_dim=True, dimension of `p` is [bs, f*(f-1)/2*k], k is embedding size
    # - when equal_dim=False, dimension of `p` is [bs, (k_2+k_3+...+k_f)+...+(k_i+k_{i+1}+...+k_f)+...+k_f],
    # - where k_i is the embedding size of the ith field
    if self.bilinear_type == 'all':
      v_dot = [self.dot_layer(v_i) for v_i in embeddings[:-1]]
      p = [self.func(v_dot[i], embeddings[j]) for i, j in pairs]
    elif self.bilinear_type == 'each':
      v_dot = [self.dot_layers[i](v_i) for i, v_i in enumerate(embeddings[:-1])]
      p = [self.func(v_dot[i], embeddings[j]) for i, j in pairs]
    else:  # interaction
      # `dot_layers` is a flat list built in the same combinations order, so
      # the k-th pair owns the k-th layer. Indexing it with
      # `i * field_num + j` selects the wrong layer and runs past the end of
      # the list on the last pair.
      p = [
          self.func(self.dot_layers[k](embeddings[i]), embeddings[j])
          for k, (i, j) in enumerate(pairs)
      ]

    return self.output_layer(tf.concat(p, axis=-1))
204
+
205
+
206
class FiBiNet(Layer):
  """FiBiNet++:Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction.

  Runs the inputs through a SENet layer and, when configured, a BiLinear
  layer, concatenates the results and optionally feeds them to a final MLP.

  References:
    - [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
  """

  def __init__(self, params, name='fibinet', reuse=None, **kwargs):
    """Build the sub-layers described by the protobuf config.

    Args:
      params: parameter object exposing ``get_pb_config()`` and
        ``l2_regularizer``.
      name: name of the layer; also used as the prefix of sub-layer names.
      reuse: stored as ``self.reuse`` and passed on to every sub-layer.
      **kwargs: forwarded to the base ``Layer`` constructor.
    """
    super(FiBiNet, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self._config = params.get_pb_config()
    config = self._config

    # SENet is always part of the block.
    self.senet_layer = SENet(
        Parameter.make_from_pb(config.senet),
        name=self.name + '/senet',
        reuse=self.reuse)

    # The bilinear branch exists only when the config sets it.
    if config.HasField('bilinear'):
      self.bilinear_layer = BiLinear(
          Parameter.make_from_pb(config.bilinear),
          name=self.name + '/bilinear',
          reuse=self.reuse)

    # Optional final MLP; it inherits the l2 regularizer of this layer.
    final_mlp = None
    if config.HasField('mlp'):
      mlp_params = Parameter.make_from_pb(config.mlp)
      mlp_params.l2_regularizer = params.l2_regularizer
      final_mlp = MLP(mlp_params, name=self.name + '/mlp', reuse=reuse)
    self.final_mlp = final_mlp

  def call(self, inputs, training=None, **kwargs):
    """Apply SENet, the optional bilinear branch and the optional MLP.

    Args:
      inputs: passed unchanged to the SENet and BiLinear sub-layers.
      training: forwarded to the final MLP when one is configured.
      **kwargs: unused.

    Returns:
      The SENet output (concatenated with the bilinear output when that branch
      is configured), passed through the final MLP if there is one.
    """
    branch_outputs = [self.senet_layer(inputs)]
    if self._config.HasField('bilinear'):
      branch_outputs.append(self.bilinear_layer(inputs))

    if len(branch_outputs) == 1:
      feature = branch_outputs[0]
    else:
      feature = tf.concat(branch_outputs, axis=-1)

    if self.final_mlp is None:
      return feature
    return self.final_mlp(feature, training=training)
@@ -0,0 +1,416 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.utils.activation import get_activation
6
+
7
+
8
class FM(tf.keras.layers.Layer):
    """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias.

    References
      - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
    Input shape.
      - List (or tuple) of 2D tensors with shape: ``(batch_size,embedding_size)``.
      - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)``
    Output shape
      - 2D tensor with shape: ``(batch_size, 1)`` by default.
      - 2D tensor with shape: ``(batch_size, embedding_size)`` when the
        ``use_variant`` parameter is true (the sum over the embedding axis is
        skipped).
    """

    def __init__(self, params, name='fm', reuse=None, **kwargs):
        """Creates the layer.

        Args:
          params: parameter object; `use_variant` (default False) is read via
            `get_or_default`.
          name: layer name, also used as the name scope of the ops.
          reuse: accepted for signature compatibility; not used by this layer.
          **kwargs: forwarded to the base layer constructor.
        """
        super(FM, self).__init__(name=name, **kwargs)
        self.use_variant = params.get_or_default('use_variant', False)

    def call(self, inputs, **kwargs):
        """Computes 0.5 * ((sum of fields)^2 - sum of (fields^2)).

        Args:
          inputs: a list/tuple of 2D tensors with equal last dimension, or a
            single 3D tensor.

        Returns:
          The cross term, summed over the embedding axis unless `use_variant`
          is set.

        Raises:
          ValueError: if the tensors of a list/tuple input do not all share
            the same embedding dimension.
        """
        # isinstance (instead of an exact `type(...) == list` test) so that
        # tuples and list subclasses are stacked too rather than being
        # mistaken for a tensor.
        if isinstance(inputs, (list, tuple)):
            emb_dims = set(int(x.shape[-1]) for x in inputs)
            if len(emb_dims) != 1:
                dims = ','.join([str(d) for d in emb_dims])
                raise ValueError('all embedding dim must be equal in FM layer:' +
                                 dims)
            with tf.name_scope(self.name):
                fea = tf.stack(inputs, axis=1)
        else:
            assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors'
            fea = inputs

        with tf.name_scope(self.name):
            # Axis 1 is the field axis in both input forms.
            square_of_sum = tf.square(tf.reduce_sum(fea, axis=1))
            sum_of_square = tf.reduce_sum(tf.square(fea), axis=1)
            cross_term = tf.subtract(square_of_sum, sum_of_square)
            if self.use_variant:
                # Variant: keep one value per embedding dimension.
                cross_term = 0.5 * cross_term
            else:
                cross_term = 0.5 * tf.reduce_sum(
                    cross_term, axis=-1, keepdims=True)
        return cross_term
45
+
46
+
47
class DotInteraction(tf.keras.layers.Layer):
    """Dot interaction layer of DLRM model.

    See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf,
    section 2.1.3. Sparse activations and dense activations are combined.
    Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the
    same dimension and the output is a batch of Tensors with all distinct
    pairwise dot products of the form dot(e_i, e_j) for i <= j if
    self_interaction is True, otherwise dot(e_i, e_j) for i < j.

    Attributes:
      self_interaction: Boolean indicating if features should self-interact.
        If it is True, then the diagonal entries of the interaction matrix are
        also taken.
      skip_gather: An optimization flag. If it's set then the upper triangle
        part of the dot interaction matrix dot(e_i, e_j) is set to 0 (the
        diagonal is zeroed as well when self_interaction is False). The
        resulting activations will be of dimension
        [num_features * num_features] from which half will be zeros.
        Otherwise activations will be only the lower triangle part of the
        interaction matrix. The latter saves space but is much slower.
      name: String name of the layer.
    """

    def __init__(self, params, name=None, reuse=None, **kwargs):
        """Creates the layer.

        Args:
          params: parameter object; `self_interaction` and `skip_gather`
            (both default False) are read via `get_or_default`.
          name: layer name.
          reuse: accepted for signature compatibility; not used by this layer.
          **kwargs: forwarded to the base layer constructor.
        """
        super(DotInteraction, self).__init__(name=name, **kwargs)
        self._self_interaction = params.get_or_default('self_interaction', False)
        self._skip_gather = params.get_or_default('skip_gather', False)

    def call(self, inputs, **kwargs):
        """Performs the interaction operation on the tensors in the list.

        The tensors represent transformed dense features and embedded
        categorical features.
        Pre-condition: The tensors should all have the same shape.

        Args:
          inputs: List of features with shapes [batch_size, feature_dim], or a
            single 3D tensor.

        Returns:
          activations: Tensor representing interacted features. It has a
            dimension `num_features * num_features` if skip_gather is True,
            otherwise `num_features * (num_features + 1) / 2` if
            self_interaction is True and
            `num_features * (num_features - 1) / 2` if self_interaction is
            False.

        Raises:
          ValueError: if the list of input tensors cannot be stacked.
        """
        if isinstance(inputs, (list, tuple)):
            # concat_features shape: batch_size, num_features, feature_dim
            try:
                concat_features = tf.stack(inputs, axis=1)
            except (ValueError, tf.errors.InvalidArgumentError) as e:
                # The trailing space keeps the two literals from fusing into
                # "originalerror message".
                raise ValueError(
                    'Input tensors` dimensions must be equal, original '
                    'error message: {}'.format(e))
        else:
            assert inputs.shape.ndims == 3, 'input of dot func must be a 3D tensor or a list of 2D tensors'
            concat_features = inputs

        batch_size = tf.shape(concat_features)[0]

        # Interact features, select lower-triangular portion, and re-shape.
        xactions = tf.matmul(concat_features, concat_features, transpose_b=True)
        num_features = xactions.shape[-1]
        ones = tf.ones_like(xactions)
        if self._self_interaction:
            # Selecting lower-triangular portion including the diagonal.
            lower_tri_mask = tf.linalg.band_part(ones, -1, 0)
            upper_tri_mask = ones - lower_tri_mask
            out_dim = num_features * (num_features + 1) // 2
        else:
            # Selecting lower-triangular portion not including the diagonal.
            upper_tri_mask = tf.linalg.band_part(ones, 0, -1)
            lower_tri_mask = ones - upper_tri_mask
            out_dim = num_features * (num_features - 1) // 2

        if self._skip_gather:
            # Setting upper triangle part of the interaction matrix to zeros.
            activations = tf.where(
                condition=tf.cast(upper_tri_mask, tf.bool),
                x=tf.zeros_like(xactions),
                y=xactions)
            out_dim = num_features * num_features
        else:
            activations = tf.boolean_mask(xactions, lower_tri_mask)
        activations = tf.reshape(activations, (batch_size, out_dim))
        return activations
129
+
130
+
131
class Cross(tf.keras.layers.Layer):
    """Cross Layer in Deep & Cross Network to learn explicit feature interactions.

    A layer that creates explicit and bounded-degree feature interactions
    efficiently. The `call` method accepts `inputs` as a tuple of size 2
    tensors. The first input `x0` is the base layer that contains the original
    features (usually the embedding layer); the second input `xi` is the output
    of the previous `Cross` layer in the stack, i.e., the i-th `Cross`
    layer. For the first `Cross` layer in the stack, x0 = xi; a single tensor
    may be passed instead of a tuple in that case.

    The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi,
    where .* designates elementwise multiplication, W could be a full-rank
    matrix, or a low-rank matrix U*V to reduce the computational cost, and
    diag_scale increases the diagonal of W to improve training stability (
    especially for the low-rank case).

    References:
        1. [R. Wang et al.](https://arxiv.org/pdf/2008.13535.pdf)
          See Eq. (1) for full-rank and Eq. (2) for low-rank version.
        2. [R. Wang et al.](https://arxiv.org/pdf/1708.05123.pdf)

    Example (`params` is a parameter object exposing `get_or_default`):

        ```python
        # after embedding layer in a functional model:
        input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
        x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)(input)
        x1 = Cross(params, name='cross_1')((x0, x0))
        x2 = Cross(params, name='cross_2')((x0, x1))
        logits = tf.keras.layers.Dense(units=10)(x2)
        model = tf.keras.Model(input, logits)
        ```

    Keys read from `params` (via `get_or_default`):
        projection_dim: project dimension to reduce the computational cost.
          Default is `None` such that a full (`input_dim` by `input_dim`)
          matrix W is used. If enabled, a low-rank matrix W = U*V will be
          used, where U is of size `input_dim` by `projection_dim` and V is of
          size `projection_dim` by `input_dim`. `projection_dim` needs to be
          smaller than `input_dim`/2 to improve the model efficiency. In
          practice, we've observed that `projection_dim` = d/4 consistently
          preserved the accuracy of a full-rank version.
        diag_scale: a non-negative float used to increase the diagonal of the
          kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an
          identity matrix. Default 0.0.
        use_bias: whether to add a bias term for this layer. If set to False,
          no bias term will be used. Default True.
        preactivation: Activation applied to output matrix of the layer,
          before multiplication with the input. Can be used to control the
          scale of the layer's outputs and improve stability.
        kernel_initializer: Initializer to use on the kernel matrix. Default
          'truncated_normal'.
        bias_initializer: Initializer to use on the bias vector. Default
          'zeros'.
        kernel_regularizer: Regularizer to use on the kernel matrix.
        bias_regularizer: Regularizer to use on bias vector.

    Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs, or
      a single (batch_size, `input_dim`) tensor.
    Output shape: A single (batch_size, `input_dim`) dimensional output.
    """

    def __init__(self, params, name='cross', reuse=None, **kwargs):
        """Reads the layer options from `params`.

        Args:
          params: parameter object exposing `get_or_default`.
          name: layer name.
          reuse: accepted for signature compatibility; not used by this layer.
          **kwargs: forwarded to the base layer constructor.

        Raises:
          ValueError: if `diag_scale` is negative.
        """
        super(Cross, self).__init__(name=name, **kwargs)
        self._projection_dim = params.get_or_default('projection_dim', None)
        self._diag_scale = params.get_or_default('diag_scale', 0.0)
        self._use_bias = params.get_or_default('use_bias', True)
        preactivation = params.get_or_default('preactivation', None)
        preact = get_activation(preactivation)
        self._preactivation = tf.keras.activations.get(preact)
        kernel_initializer = params.get_or_default('kernel_initializer',
                                                   'truncated_normal')
        self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        bias_initializer = params.get_or_default('bias_initializer', 'zeros')
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        kernel_regularizer = params.get_or_default('kernel_regularizer', None)
        self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = params.get_or_default('bias_regularizer', None)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._input_dim = None
        self._supports_masking = True

        if self._diag_scale < 0:  # pytype: disable=unsupported-operands
            raise ValueError(
                '`diag_scale` should be non-negative. Got `diag_scale` = {}'.format(
                    self._diag_scale))

    def build(self, input_shape):
        """Creates the dense sub-layer(s) sized after the input's last dim.

        Args:
          input_shape: either a list/tuple of the two input shapes (tuple
            input) or the single shape of `x0` (single-tensor input, and the
            explicit `self.build(x0.shape)` call made from `call`).
        """
        # A list of shapes has a shape as its first element; a single shape
        # has a plain dimension there. Indexing `[0][-1]` unconditionally
        # fails on the single-shape form.
        first = input_shape[0]
        if isinstance(first, (list, tuple, tf.TensorShape)):
            last_dim = first[-1]
        else:
            last_dim = input_shape[-1]

        if self._projection_dim is None:
            self._dense = tf.keras.layers.Dense(
                last_dim,
                kernel_initializer=_clone_initializer(self._kernel_initializer),
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                use_bias=self._use_bias,
                dtype=self.dtype,
                activation=self._preactivation,
            )
        else:
            # Low-rank form: U projects down (no bias), V projects back up.
            self._dense_u = tf.keras.layers.Dense(
                self._projection_dim,
                kernel_initializer=_clone_initializer(self._kernel_initializer),
                kernel_regularizer=self._kernel_regularizer,
                use_bias=False,
                dtype=self.dtype,
            )
            self._dense_v = tf.keras.layers.Dense(
                last_dim,
                kernel_initializer=_clone_initializer(self._kernel_initializer),
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                use_bias=self._use_bias,
                dtype=self.dtype,
                activation=self._preactivation,
            )
        super(Cross, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        """Computes the feature cross.

        Args:
          inputs: The input tensor(x0, x)
            - x0: The input tensor
            - x: Optional second input tensor. If provided, the layer will
              compute crosses between x0 and x; if not provided, the layer
              will compute crosses between x0 and itself.

        Returns:
          Tensor of crosses.

        Raises:
          ValueError: if the last dimensions of `x0` and `x` differ.
        """
        if isinstance(inputs, (list, tuple)):
            x0, x = inputs
        else:
            x0, x = inputs, inputs

        if not self.built:
            self.build(x0.shape)

        if x0.shape[-1] != x.shape[-1]:
            raise ValueError(
                '`x0` and `x` dimension mismatch! Got `x0` dimension {}, and x '
                'dimension {}. This case is not supported yet.'.format(
                    x0.shape[-1], x.shape[-1]))

        if self._projection_dim is None:
            prod_output = self._dense(x)
        else:
            prod_output = self._dense_v(self._dense_u(x))

        # prod_output = tf.cast(prod_output, self.compute_dtype)

        if self._diag_scale:
            prod_output = prod_output + self._diag_scale * x

        return x0 * prod_output + x

    def get_config(self):
        """Returns the base layer config extended with this layer's options."""
        config = {
            'projection_dim':
                self._projection_dim,
            'diag_scale':
                self._diag_scale,
            'use_bias':
                self._use_bias,
            'preactivation':
                tf.keras.activations.serialize(self._preactivation),
            'kernel_initializer':
                tf.keras.initializers.serialize(self._kernel_initializer),
            'bias_initializer':
                tf.keras.initializers.serialize(self._bias_initializer),
            'kernel_regularizer':
                tf.keras.regularizers.serialize(self._kernel_regularizer),
            'bias_regularizer':
                tf.keras.regularizers.serialize(self._bias_regularizer),
        }
        base_config = super(Cross, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
309
+
310
+
311
class CIN(tf.keras.layers.Layer):
    """Compressed Interaction Network(CIN) module in xDeepFM model.

    CIN layer is aimed at achieving high-order feature interactions at
    vector-wise level rather than bit-wise level.

    Keys read from `params` (via `get_or_default`):
      hidden_feature_sizes: non-empty list of int, one entry per CIN layer.
      kernel_regularizer: regularizer applied to every kernel (default None).
      bias_regularizer: regularizer applied to every bias (default None).

    Reference:
    [xDeepFM](https://arxiv.org/pdf/1803.05170)
      xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems
    """

    def __init__(self, params, name='cin', reuse=None, **kwargs):
        """Reads the layer options from `params`.

        Args:
          params: parameter object exposing `get_or_default`.
          name: layer name, also used as the variable scope of the weights.
          reuse: accepted for signature compatibility; not used by this layer.
          **kwargs: forwarded to the base layer constructor.
        """
        super(CIN, self).__init__(name=name, **kwargs)
        self._name = name
        self._hidden_feature_sizes = list(
            params.get_or_default('hidden_feature_sizes', []))

        assert isinstance(self._hidden_feature_sizes, list) and len(
            self._hidden_feature_sizes
        ) > 0, 'parameter hidden_feature_sizes must be a list of int with length greater than 0'

        kernel_regularizer = params.get_or_default('kernel_regularizer', None)
        self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = params.get_or_default('bias_regularizer', None)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        """Creates one kernel and one bias variable per hidden layer.

        Args:
          input_shape: shape of the 3D input; its second entry is used as the
            size of layer 0.

        Raises:
          ValueError: if `input_shape` does not have exactly 3 dimensions.
        """
        if len(input_shape) != 3:
            raise ValueError(
                'Unexpected inputs dimensions %d, expect to be 3 dimensions' %
                (len(input_shape)))

        # Index 0 is the input field count; index i + 1 is hidden layer i.
        hidden_feature_sizes = [input_shape[1]] + list(self._hidden_feature_sizes)
        num_layers = len(self._hidden_feature_sizes)
        tfv1 = tf.compat.v1 if tf.__version__ >= '2.0' else tf
        with tfv1.variable_scope(self._name):
            self.kernel_list = [
                tfv1.get_variable(
                    name='cin_kernel_%d' % i,
                    shape=[
                        hidden_feature_sizes[i + 1], hidden_feature_sizes[i],
                        hidden_feature_sizes[0]
                    ],
                    initializer=tf.initializers.he_normal(),
                    regularizer=self._kernel_regularizer,
                    trainable=True) for i in range(num_layers)
            ]
            self.bias_list = [
                tfv1.get_variable(
                    name='cin_bias_%d' % i,
                    shape=[hidden_feature_sizes[i + 1]],
                    initializer=tf.keras.initializers.Zeros,
                    regularizer=self._bias_regularizer,
                    trainable=True) for i in range(num_layers)
            ]

        super(CIN, self).build(input_shape)

    def call(self, input, **kwargs):
        """Computes the compressed feature maps.

        Args:
          input: The 3D input tensor with shape (b, h0, d), where b is
            batch_size, h0 is the number of features, d is the feature
            embedding dimension.

        Returns:
          2D tensor of compressed feature map with shape (b, featuremap_num),
          where b is the batch_size, featuremap_num is sum of the hidden layer
          sizes
        """
        x_0 = input
        x_i = input
        x_0_expanded = tf.expand_dims(x_0, 1)
        pooled_feature_map_list = []
        for i, hk in enumerate(self._hidden_feature_sizes):
            # Pairwise products between every row of x_i and every row of x_0.
            x_i_expanded = tf.expand_dims(x_i, 2)
            intermediate_tensor = tf.multiply(x_0_expanded, x_i_expanded)

            # Repeat the interaction tensor once per output feature map.
            intermediate_tensor_expanded = tf.expand_dims(intermediate_tensor, 1)
            intermediate_tensor_expanded = tf.tile(intermediate_tensor_expanded,
                                                   [1, hk, 1, 1, 1])

            feature_map_elementwise = tf.multiply(
                intermediate_tensor_expanded,
                tf.expand_dims(tf.expand_dims(self.kernel_list[i], -1), 0))
            # Weighted sum over both interaction axes.
            feature_map = tf.reduce_sum(
                tf.reduce_sum(feature_map_elementwise, axis=3), axis=2)

            feature_map = tf.add(
                feature_map,
                tf.expand_dims(
                    tf.expand_dims(self.bias_list[i], axis=-1), axis=0))
            feature_map = tf.nn.relu(feature_map)

            x_i = feature_map
            # Sum-pool each feature map over the embedding axis.
            pooled_feature_map_list.append(tf.reduce_sum(feature_map, axis=-1))
        return tf.concat(
            pooled_feature_map_list, axis=-1)  # shape = (b, h1 + ... + hk)

    def get_config(self):
        """Returns the base layer config extended with this layer's options.

        Returning a dict (rather than None) keeps the layer usable with Keras
        config-based serialization.
        """
        config = {
            'hidden_feature_sizes':
                self._hidden_feature_sizes,
            'kernel_regularizer':
                tf.keras.regularizers.serialize(self._kernel_regularizer),
            'bias_regularizer':
                tf.keras.regularizers.serialize(self._bias_regularizer),
        }
        base_config = super(CIN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
413
+
414
+
415
+ def _clone_initializer(initializer):
416
+ return initializer.__class__.from_config(initializer.get_config())