easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/layers/keras/multi_task.py
@@ -0,0 +1,125 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.attention import Attention
+from easy_rec.python.layers.keras.blocks import MLP
+from easy_rec.python.layers.utils import Parameter
+from easy_rec.python.protos import seq_encoder_pb2
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class MMoE(Layer):
+  """Multi-gate Mixture-of-Experts model."""
+
+  def __init__(self, params, name='MMoE', reuse=None, **kwargs):
+    super(MMoE, self).__init__(name=name, **kwargs)
+    params.check_required(['num_expert', 'num_task'])
+    self._reuse = reuse
+    self._num_expert = params.num_expert
+    self._num_task = params.num_task
+    if params.has_field('expert_mlp'):
+      expert_params = Parameter.make_from_pb(params.expert_mlp)
+      expert_params.l2_regularizer = params.l2_regularizer
+      self._has_experts = True
+      self._experts = [
+          MLP(expert_params, 'expert_%d' % i, reuse=reuse)
+          for i in range(self._num_expert)
+      ]
+    else:
+      self._has_experts = False
+
+    self._gates = []
+    for task_id in range(self._num_task):
+      dense = Dense(
+          self._num_expert,
+          activation='softmax',
+          name='gate_%d' % task_id,
+          kernel_regularizer=params.l2_regularizer)
+      self._gates.append(dense)
+
+  def call(self, inputs, training=None, **kwargs):
+    if self._num_expert == 0:
+      logging.warning('num_expert of MMoE layer `%s` is 0' % self.name)
+      return inputs
+    if self._has_experts:
+      expert_fea_list = [
+          expert(inputs, training=training) for expert in self._experts
+      ]
+    else:
+      expert_fea_list = inputs
+    experts_fea = tf.stack(expert_fea_list, axis=1)
+    # when built-in MLP experts are not used, the gate input is the last (extra) element of inputs
+    gate_input = inputs if self._has_experts else inputs[self._num_expert]
+    task_input_list = []
+    for task_id in range(self._num_task):
+      gate = self._gates[task_id](gate_input)
+      gate = tf.expand_dims(gate, -1)
+      task_input = tf.multiply(experts_fea, gate)
+      task_input = tf.reduce_sum(task_input, axis=1)
+      task_input_list.append(task_input)
+    return task_input_list
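The gate/expert combination in MMoE.call above is just a softmax-weighted sum of expert outputs per task. A minimal NumPy sketch of that combination step, with made-up shapes and without the EasyRec Parameter/protobuf plumbing:

import numpy as np

def softmax(z, axis=-1):
  e = np.exp(z - z.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

batch, num_expert, num_task, dim = 4, 3, 2, 8
# expert outputs stacked along axis=1, as in MMoE.call: [B, E, D]
experts_fea = np.random.randn(batch, num_expert, dim)
# one softmax gate per task over the experts: [B, E]
gate_logits = [np.random.randn(batch, num_expert) for _ in range(num_task)]

task_inputs = []
for task_id in range(num_task):
  gate = softmax(gate_logits[task_id])[..., None]        # [B, E, 1]
  task_inputs.append((experts_fea * gate).sum(axis=1))   # [B, D] per task

assert task_inputs[0].shape == (batch, dim)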
+
+
+class AITMTower(Layer):
+  """Adaptive Information Transfer Multi-task (AITM) Tower."""
+
+  def __init__(self, params, name='AITMTower', reuse=None, **kwargs):
+    super(AITMTower, self).__init__(name=name, **kwargs)
+    self.project_dim = params.get_or_default('project_dim', None)
+    self.stop_gradient = params.get_or_default('stop_gradient', True)
+    self.transfer = None
+    if params.has_field('transfer_mlp'):
+      mlp_cfg = Parameter.make_from_pb(params.transfer_mlp)
+      mlp_cfg.l2_regularizer = params.l2_regularizer
+      self.transfer = MLP(mlp_cfg, name='transfer')
+    self.queries = []
+    self.keys = []
+    self.values = []
+    self.attention = None
+
+  def build(self, input_shape):
+    if not isinstance(input_shape, (tuple, list)):
+      super(AITMTower, self).build(input_shape)
+      return
+    dim = self.project_dim if self.project_dim else int(input_shape[0][-1])
+    for i in range(len(input_shape)):
+      self.queries.append(Dense(dim, name='query_%d' % i))
+      self.keys.append(Dense(dim, name='key_%d' % i))
+      self.values.append(Dense(dim, name='value_%d' % i))
+    attn_cfg = seq_encoder_pb2.Attention()
+    attn_cfg.scale_by_dim = True
+    attn_params = Parameter.make_from_pb(attn_cfg)
+    self.attention = Attention(attn_params)
+    super(AITMTower, self).build(input_shape)
+
+  def call(self, inputs, training=None, **kwargs):
+    if not isinstance(inputs, (tuple, list)):
+      return inputs
+
+    queries = []
+    keys = []
+    values = []
+    for i, tower in enumerate(inputs):
+      if i == 0:  # current tower
+        queries.append(self.queries[i](tower))
+        keys.append(self.keys[i](tower))
+        values.append(self.values[i](tower))
+      else:
+        dep = tf.stop_gradient(tower) if self.stop_gradient else tower
+        if self.transfer is not None:
+          dep = self.transfer(dep, training=training)
+        queries.append(self.queries[i](dep))
+        keys.append(self.keys[i](dep))
+        values.append(self.values[i](dep))
+    query = tf.stack(queries, axis=1)
+    key = tf.stack(keys, axis=1)
+    value = tf.stack(values, axis=1)
+    attn = self.attention([query, value, key])
+    return attn[:, 0, :]
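For orientation, AITMTower attends from the current task's tower (index 0 in `inputs`) over all towers, where upstream towers are optionally stop-gradient'd and passed through the transfer MLP first, and only the slot-0 output is kept. A simplified TF2-eager sketch of the attention step (the transfer MLP and stop_gradient are omitted, and a plain scaled dot-product attention stands in for easy_rec's Attention layer, which is fed [query, value, key]):

import tensorflow as tf  # assumes TF2 eager for the sketch

batch, dim, num_towers = 4, 16, 3
towers = [tf.random.normal([batch, dim]) for _ in range(num_towers)]

def make_proj():
  # one Dense projection per tower, as in AITMTower.build
  return [tf.keras.layers.Dense(dim) for _ in range(num_towers)]

q_layers, k_layers, v_layers = make_proj(), make_proj(), make_proj()

query = tf.stack([q_layers[i](t) for i, t in enumerate(towers)], axis=1)  # [B, T, D]
key = tf.stack([k_layers[i](t) for i, t in enumerate(towers)], axis=1)    # [B, T, D]
value = tf.stack([v_layers[i](t) for i, t in enumerate(towers)], axis=1)  # [B, T, D]

# scaled dot-product attention as a stand-in for the Attention layer
scores = tf.matmul(query, key, transpose_b=True) / tf.sqrt(float(dim))  # [B, T, T]
attn = tf.matmul(tf.nn.softmax(scores, axis=-1), value)                 # [B, T, D]
fused = attn[:, 0, :]  # keep only the current task's slot, as in AITMTower.call
print(fused.shape)     # (4, 16)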
easy_rec/python/layers/keras/numerical_embedding.py
@@ -0,0 +1,376 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import math
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.compat.array_ops import repeat
+from easy_rec.python.utils.activation import get_activation
+from easy_rec.python.utils.tf_utils import get_ps_num_from_tf_config
+
+curr_dir, _ = os.path.split(__file__)
+parent_dir = os.path.dirname(curr_dir)
+ops_idr = os.path.dirname(parent_dir)
+ops_dir = os.path.join(ops_idr, 'ops')
+if 'PAI' in tf.__version__:
+  ops_dir = os.path.join(ops_dir, '1.12_pai')
+elif tf.__version__.startswith('1.12'):
+  ops_dir = os.path.join(ops_dir, '1.12')
+elif tf.__version__.startswith('1.15'):
+  if 'IS_ON_PAI' in os.environ:
+    ops_dir = os.path.join(ops_dir, 'DeepRec')
+  else:
+    ops_dir = os.path.join(ops_dir, '1.15')
+elif tf.__version__.startswith('2.12'):
+  ops_dir = os.path.join(ops_dir, '2.12')
+
+logging.info('ops_dir is %s' % ops_dir)
+custom_op_path = os.path.join(ops_dir, 'libcustom_ops.so')
+try:
+  custom_ops = tf.load_op_library(custom_op_path)
+  logging.info('load custom op from %s succeed' % custom_op_path)
+except Exception as ex:
+  logging.warning('load custom op from %s failed: %s' %
+                  (custom_op_path, str(ex)))
+  custom_ops = None
+
+
+class NLinear(Layer):
+  """N linear layers for N token (feature) embeddings.
+
+  To understand this module, let's revise `tf.layers.dense`. When `tf.layers.dense` is
+  applied to three-dimensional inputs of the shape
+  ``(batch_size, n_tokens, d_embedding)``, then the same linear transformation is
+  applied to each of ``n_tokens`` token (feature) embeddings.
+
+  By contrast, `NLinear` allocates one linear layer per token (``n_tokens`` layers in total).
+  One such layer can be represented as ``tf.layers.dense(d_in, d_out)``.
+  So, the i-th linear transformation is applied to the i-th token embedding, as
+  illustrated in the following pseudocode::
+
+      layers = [tf.layers.dense(d_in, d_out) for _ in range(n_tokens)]
+      x = tf.random.normal([batch_size, n_tokens, d_in])
+      result = tf.stack([layers[i](x[:, i]) for i in range(n_tokens)], 1)
+
+  Examples:
+      .. testcode::
+
+          batch_size = 2
+          n_features = 3
+          d_embedding_in = 4
+          d_embedding_out = 5
+          x = tf.random.normal([batch_size, n_features, d_embedding_in])
+          m = NLinear(n_features, d_embedding_in, d_embedding_out)
+          assert m(x).shape == (batch_size, n_features, d_embedding_out)
+  """
+
+  def __init__(self,
+               n_tokens,
+               d_in,
+               d_out,
+               bias=True,
+               name='nd_linear',
+               **kwargs):
+    """Init with input shapes.
+
+    Args:
+      n_tokens: the number of tokens (features)
+      d_in: the input dimension
+      d_out: the output dimension
+      bias: indicates if the underlying linear layers have biases
+      name: layer name
+    """
+    super(NLinear, self).__init__(name=name, **kwargs)
+    self.weight = self.add_weight(
+        'weights', [1, n_tokens, d_in, d_out], dtype=tf.float32)
+    if bias:
+      initializer = tf.constant_initializer(0.0)
+      self.bias = self.add_weight(
+          'bias', [1, n_tokens, d_out],
+          dtype=tf.float32,
+          initializer=initializer)
+    else:
+      self.bias = None
+
+  def call(self, x, **kwargs):
+    if x.shape.ndims != 3:
+      raise ValueError(
+          'The input must have three dimensions (batch_size, n_tokens, d_embedding)'
+      )
+    if x.shape[2] != self.weight.shape[2]:
+      raise ValueError('invalid input embedding dimension %d, expect %d' %
+                       (int(x.shape[2]), int(self.weight.shape[2])))
+
+    x = x[..., None] * self.weight  # [B, N, D, D_out]
+    x = tf.reduce_sum(x, axis=-2)  # [B, N, D_out]
+    if self.bias is not None:
+      x = x + self.bias
+    return x
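The broadcast multiply-and-reduce in NLinear.call is equivalent to one matmul per token; a small NumPy check with assumed shapes:

import numpy as np

batch, n_tokens, d_in, d_out = 2, 3, 4, 5
x = np.random.randn(batch, n_tokens, d_in)
weight = np.random.randn(1, n_tokens, d_in, d_out)  # mirrors NLinear's weight
bias = np.zeros((1, n_tokens, d_out))

# NLinear.call: broadcast multiply, then sum over the d_in axis
y1 = (x[..., None] * weight).sum(axis=-2) + bias      # [B, N, d_out]
# equivalent per-token matmul written as an einsum
y2 = np.einsum('bnd,ndo->bno', x, weight[0]) + bias   # [B, N, d_out]

assert np.allclose(y1, y2)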
+
+
+class PeriodicEmbedding(Layer):
+  """Periodic embeddings for numerical features described in [1].
+
+  References:
+    * [1] Yury Gorishniy, Ivan Rubachev, Artem Babenko,
+      "On Embeddings for Numerical Features in Tabular Deep Learning", 2022
+      https://arxiv.org/pdf/2203.05556.pdf
+
+  Attributes:
+    embedding_dim: the embedding size, must be an even positive integer.
+    sigma: the scale of the weight initialization.
+      **This is a super important parameter which significantly affects performance**.
+      Its optimal value can be dramatically different for different datasets, so
+      no "default value" can exist for this parameter, and it must be tuned for
+      each dataset. In the original paper, during hyperparameter tuning, this
+      parameter was sampled from the distribution ``LogUniform[1e-2, 1e2]``.
+      A similar grid would be ``[1e-2, 1e-1, 1e0, 1e1, 1e2]``.
+      If possible, add more intermediate values to this grid.
+    output_3d_tensor: whether to output a 3d tensor
+    output_tensor_list: whether to output the list of embedding
+  """
+
+  def __init__(self, params, name='periodic_embedding', reuse=None, **kwargs):
+    super(PeriodicEmbedding, self).__init__(name=name, **kwargs)
+    self.reuse = reuse
+    params.check_required(['embedding_dim', 'sigma'])
+    self.embedding_dim = int(params.embedding_dim)
+    if self.embedding_dim % 2:
+      raise ValueError('embedding_dim must be even')
+    sigma = params.sigma
+    self.initializer = tf.random_normal_initializer(stddev=sigma)
+    self.add_linear_layer = params.get_or_default('add_linear_layer', True)
+    self.linear_activation = params.get_or_default('linear_activation', 'relu')
+    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+
+  def build(self, input_shape):
+    if input_shape.ndims != 2:
+      raise ValueError('inputs of PeriodicEmbedding must have 2 dimensions.')
+    self.num_features = int(input_shape[-1])
+    num_ps = get_ps_num_from_tf_config()
+    partitioner = None
+    if num_ps > 0:
+      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+    emb_dim = self.embedding_dim // 2
+    self.coef = self.add_weight(
+        'coefficients',
+        shape=[1, self.num_features, emb_dim],
+        partitioner=partitioner,
+        initializer=self.initializer)
+    if self.add_linear_layer:
+      self.linear = NLinear(
+          self.num_features,
+          self.embedding_dim,
+          self.embedding_dim,
+          name='nd_linear')
+    super(PeriodicEmbedding, self).build(input_shape)
+
+  def call(self, inputs, **kwargs):
+    features = inputs[..., None]  # [B, N, 1]
+    v = 2 * math.pi * self.coef * features  # [B, N, E]
+    emb = tf.concat([tf.sin(v), tf.cos(v)], axis=-1)  # [B, N, 2E]
+
+    dim = self.embedding_dim
+    if self.add_linear_layer:
+      emb = self.linear(emb)
+      act = get_activation(self.linear_activation)
+      if callable(act):
+        emb = act(emb)
+    output = tf.reshape(emb, [-1, self.num_features * dim])
+
+    if self.output_tensor_list:
+      return output, tf.unstack(emb, axis=1)
+    if self.output_3d_tensor:
+      return output, emb
+    return output
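Numerically, PeriodicEmbedding maps each scalar feature x to [sin(2*pi*c*x), cos(2*pi*c*x)] with coefficients c drawn from N(0, sigma^2), optionally followed by the per-feature NLinear above. A NumPy sketch of the embedding step, with assumed sizes and the linear layer omitted:

import numpy as np

batch, num_features, emb_dim, sigma = 4, 3, 8, 1.0
x = np.random.rand(batch, num_features)                  # raw numeric features
coef = np.random.normal(0.0, sigma, (1, num_features, emb_dim // 2))

v = 2 * np.pi * coef * x[..., None]                      # [B, N, emb_dim/2]
emb = np.concatenate([np.sin(v), np.cos(v)], axis=-1)    # [B, N, emb_dim]
flat = emb.reshape(batch, num_features * emb_dim)        # flattened output, as in call()
print(flat.shape)  # (4, 24)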
+
+
+class AutoDisEmbedding(Layer):
+  """An Embedding Learning Framework for Numerical Features in CTR Prediction.
+
+  Refer: https://arxiv.org/pdf/2012.08986v2.pdf
+  """
+
+  def __init__(self, params, name='auto_dis_embedding', reuse=None, **kwargs):
+    super(AutoDisEmbedding, self).__init__(name=name, **kwargs)
+    self.reuse = reuse
+    params.check_required(['embedding_dim', 'num_bins', 'temperature'])
+    self.emb_dim = int(params.embedding_dim)
+    self.num_bins = int(params.num_bins)
+    self.temperature = params.temperature
+    self.keep_prob = params.get_or_default('keep_prob', 0.8)
+    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+
+  def build(self, input_shape):
+    if input_shape.ndims != 2:
+      raise ValueError('inputs of AutoDisEmbedding must have 2 dimensions.')
+    self.num_features = int(input_shape[-1])
+    num_ps = get_ps_num_from_tf_config()
+    partitioner = None
+    if num_ps > 0:
+      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+    self.meta_emb = self.add_weight(
+        'meta_embedding',
+        shape=[self.num_features, self.num_bins, self.emb_dim],
+        partitioner=partitioner)
+    self.proj_w = self.add_weight(
+        'project_w',
+        shape=[1, self.num_features, self.num_bins],
+        partitioner=partitioner)
+    self.proj_mat = self.add_weight(
+        'project_mat',
+        shape=[self.num_features, self.num_bins, self.num_bins],
+        partitioner=partitioner)
+    super(AutoDisEmbedding, self).build(input_shape)
+
+  def call(self, inputs, **kwargs):
+    x = tf.expand_dims(inputs, axis=-1)  # [B, N, 1]
+    hidden = tf.nn.leaky_relu(self.proj_w * x)  # [B, N, num_bin]
+    # matmul in older tf (1.12) does not support broadcasting, so use einsum instead
+    # y = tf.matmul(mat, hidden[..., None])  # [B, N, num_bin, 1]
+    # y = tf.squeeze(y, axis=3)  # [B, N, num_bin]
+    y = tf.einsum('nik,bnk->bni', self.proj_mat, hidden)  # [B, N, num_bin]
+
+    # keep_prob (float): weight of the skip connection that adds `hidden` back in
+    alpha = self.keep_prob
+    x_bar = y + alpha * hidden  # [B, N, num_bin]
+    x_hat = tf.nn.softmax(x_bar / self.temperature)  # [B, N, num_bin]
+
+    # emb = tf.matmul(x_hat[:, :, None, :], meta_emb)  # [B, N, 1, D]
+    # emb = tf.squeeze(emb, axis=2)  # [B, N, D]
+    emb = tf.einsum('bnk,nkd->bnd', x_hat, self.meta_emb)
+    output = tf.reshape(emb, [-1, self.emb_dim * self.num_features])  # [B, N*D]
+
+    if self.output_tensor_list:
+      return output, tf.unstack(emb, axis=1)
+    if self.output_3d_tensor:
+      return output, emb
+    return output
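AutoDis scores each scalar against a set of learnable bins, softmaxes the temperature-scaled scores, and takes a convex combination of per-bin meta-embeddings. A NumPy sketch that mirrors the call method above, with assumed sizes and random weights:

import numpy as np

def softmax(z, axis=-1):
  e = np.exp(z - z.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

batch, n, num_bins, emb_dim, temperature, keep_prob = 4, 3, 5, 8, 0.5, 0.8
x = np.random.rand(batch, n, 1)
proj_w = np.random.randn(1, n, num_bins)
proj_mat = np.random.randn(n, num_bins, num_bins)
meta_emb = np.random.randn(n, num_bins, emb_dim)

z = proj_w * x                                         # [B, N, K]
hidden = np.maximum(z, 0.2 * z)                        # leaky_relu (alpha=0.2)
y = np.einsum('nik,bnk->bni', proj_mat, hidden)        # [B, N, K]
x_hat = softmax((y + keep_prob * hidden) / temperature)  # soft bin assignment
emb = np.einsum('bnk,nkd->bnd', x_hat, meta_emb)       # [B, N, D]
print(emb.reshape(batch, -1).shape)  # (4, 24)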
+
+
+class NaryDisEmbedding(Layer):
+  """Numerical Feature Representation with Hybrid N-ary Encoding, CIKM 2022.
+
+  Refer: https://dl.acm.org/doi/pdf/10.1145/3511808.3557090
+  """
+
+  def __init__(self, params, name='nary_dis_embedding', reuse=None, **kwargs):
+    super(NaryDisEmbedding, self).__init__(name=name, **kwargs)
+    self.reuse = reuse
+    self.nary_carry = custom_ops.nary_carry
+    params.check_required(['embedding_dim', 'carries'])
+    self.emb_dim = int(params.embedding_dim)
+    self.carries = params.get_or_default('carries', [2, 9])
+    self.num_replicas = params.get_or_default('num_replicas', 1)
+    assert self.num_replicas >= 1, 'num replicas must be >= 1'
+    self.lengths = list(map(self.max_length, self.carries))
+    self.vocab_size = int(sum(self.lengths))
+    self.multiplier = params.get_or_default('multiplier', 1.0)
+    self.intra_ary_pooling = params.get_or_default('intra_ary_pooling', 'sum')
+    self.output_3d_tensor = params.get_or_default('output_3d_tensor', False)
+    self.output_tensor_list = params.get_or_default('output_tensor_list', False)
+    logging.info(
+        '{} carries: {}, lengths: {}, vocab_size: {}, intra_ary: {}, replicas: {}, multiplier: {}'
+        .format(self.name, ','.join(map(str, self.carries)),
+                ','.join(map(str, self.lengths)), self.vocab_size,
+                self.intra_ary_pooling, self.num_replicas, self.multiplier))
+
+  @staticmethod
+  def max_length(carry):
+    bits = math.log(4294967295, carry)
+    return (math.floor(bits) + 1) * carry
+
+  def build(self, input_shape):
+    assert isinstance(input_shape,
+                      tf.TensorShape), 'NaryDisEmbedding only takes 1 input'
+    self.num_features = int(input_shape[-1])
+    logging.info('%s has %d input features', self.name, self.num_features)
+    vocab_size = self.num_features * self.vocab_size
+    emb_dim = self.emb_dim * self.num_replicas
+    num_ps = get_ps_num_from_tf_config()
+    partitioner = None
+    if num_ps > 0:
+      partitioner = tf.fixed_size_partitioner(num_shards=num_ps)
+    self.embedding_table = self.add_weight(
+        'embed_table', shape=[vocab_size, emb_dim], partitioner=partitioner)
+    super(NaryDisEmbedding, self).build(input_shape)
+
+  def call(self, inputs, **kwargs):
+    if inputs.shape.ndims != 2:
+      raise ValueError('inputs of NaryDisEmbedding must have 2 dimensions.')
+    if self.multiplier != 1.0:
+      inputs *= self.multiplier
+    inputs = tf.to_int32(inputs)
+    offset, emb_indices, emb_splits = 0, [], []
+    with ops.device('/CPU:0'):
+      for carry, length in zip(self.carries, self.lengths):
+        values, splits = self.nary_carry(inputs, carry=carry, offset=offset)
+        offset += length
+        emb_indices.append(values)
+        emb_splits.append(splits)
+    indices = tf.concat(emb_indices, axis=0)
+    splits = tf.concat(emb_splits, axis=0)
+    # embedding shape: [B*N*C, D]
+    embedding = tf.nn.embedding_lookup(self.embedding_table, indices)
+
+    total_length = tf.size(splits)
+    if self.intra_ary_pooling == 'sum':
+      if tf.__version__ >= '2.0':
+        segment_ids = tf.repeat(tf.range(total_length), repeats=splits)
+      else:
+        segment_ids = repeat(tf.range(total_length), repeats=splits)
+      embedding = tf.math.segment_sum(embedding, segment_ids)
+    elif self.intra_ary_pooling == 'mean':
+      if tf.__version__ >= '2.0':
+        segment_ids = tf.repeat(tf.range(total_length), repeats=splits)
+      else:
+        segment_ids = repeat(tf.range(total_length), repeats=splits)
+      embedding = tf.math.segment_mean(embedding, segment_ids)
+    else:
+      raise ValueError('Unsupported intra ary pooling method %s' %
+                       self.intra_ary_pooling)
+    # B: batch size
+    # N: num features
+    # C: num carries
+    # D: embedding dimension
+    # R: num replicas
+    # shape of embedding: [B*N*C, R*D]
+    N = self.num_features
+    C = len(self.carries)
+    D = self.emb_dim
+    if self.num_replicas == 1:
+      embedding = tf.reshape(embedding, [C, -1, D])  # [C, B*N, D]
+      embedding = tf.transpose(embedding, perm=[1, 0, 2])  # [B*N, C, D]
+      embedding = tf.reshape(embedding, [-1, C * D])  # [B*N, C*D]
+      output = tf.reshape(embedding, [-1, N * C * D])  # [B, N*C*D]
+      if self.output_tensor_list:
+        return output, tf.split(embedding, N)  # [B, C*D] * N
+      if self.output_3d_tensor:
+        embedding = tf.reshape(embedding, [-1, N, C * D])  # [B, N, C*D]
+        return output, embedding
+      return output
+
+    # self.num_replicas > 1:
+    replicas = tf.split(embedding, self.num_replicas, axis=1)
+    outputs = []
+    outputs2 = []
+    for replica in replicas:
+      # shape of replica: [B*N*C, D]
+      embedding = tf.reshape(replica, [C, -1, D])  # [C, B*N, D]
+      embedding = tf.transpose(embedding, perm=[1, 0, 2])  # [B*N, C, D]
+      embedding = tf.reshape(embedding, [-1, C * D])  # [B*N, C*D]
+      output = tf.reshape(embedding, [-1, N * C * D])  # [B, N*C*D]
+      outputs.append(output)
+      if self.output_tensor_list:
+        embedding = tf.split(embedding, N)  # [B, C*D] * N
+        outputs2.append(embedding)
+      elif self.output_3d_tensor:
+        embedding = tf.reshape(embedding, [-1, N, C * D])  # [B, N, C*D]
+        outputs2.append(embedding)
+    return outputs + outputs2
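The `nary_carry` op itself only ships as a compiled library (see the ops/*.so files in the file list above), but the encoding it implements is base-`carry` digit decomposition with one vocabulary slot per (position, digit) pair, which is what `max_length` sizes the vocabulary for. A rough pure-Python sketch of that idea; the id layout here is an assumption for illustration, not the op's exact output, which also returns `splits` for the segment pooling:

def nary_digits(value, carry):
  """Base-`carry` digits of a non-negative int, least significant first."""
  digits = [value % carry]
  while value >= carry:
    value //= carry
    digits.append(value % carry)
  return digits

def nary_ids(value, carry, offset=0):
  """Map each digit to a vocabulary id: position * carry + digit + offset (assumed layout)."""
  return [offset + pos * carry + d for pos, d in enumerate(nary_digits(value, carry))]

# e.g. encode the value 25 in binary and in base 9
print(nary_ids(25, carry=2))  # [1, 2, 4, 7, 9]
print(nary_ids(25, carry=9))  # [7, 11]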
easy_rec/python/layers/keras/ppnet.py
@@ -0,0 +1,194 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Convenience blocks for building models."""
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.layers.keras.activation import activation_layer
+from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class GateNN(tf.keras.layers.Layer):
+
+  def __init__(self,
+               params,
+               output_units=None,
+               name='gate_nn',
+               reuse=None,
+               **kwargs):
+    super(GateNN, self).__init__(name=name, **kwargs)
+    output_dim = output_units if output_units is not None else params.output_dim
+    hidden_dim = params.get_or_default('hidden_dim', output_dim)
+    initializer = params.get_or_default('initializer', 'he_uniform')
+    do_batch_norm = params.get_or_default('use_bn', False)
+    activation = params.get_or_default('activation', 'relu')
+    dropout_rate = params.get_or_default('dropout_rate', 0.0)
+
+    self._sub_layers = []
+    dense = tf.keras.layers.Dense(
+        units=hidden_dim,
+        use_bias=not do_batch_norm,
+        kernel_initializer=initializer)
+    self._sub_layers.append(dense)
+
+    if do_batch_norm:
+      bn = tf.keras.layers.BatchNormalization(trainable=True)
+      self._sub_layers.append(bn)
+
+    act_layer = activation_layer(activation)
+    self._sub_layers.append(act_layer)
+
+    if 0.0 < dropout_rate < 1.0:
+      dropout = tf.keras.layers.Dropout(dropout_rate)
+      self._sub_layers.append(dropout)
+    elif dropout_rate >= 1.0:
+      raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+    dense = tf.keras.layers.Dense(
+        units=output_dim,
+        activation='sigmoid',
+        use_bias=not do_batch_norm,
+        kernel_initializer=initializer,
+        name='weight')
+    self._sub_layers.append(dense)
+    self._sub_layers.append(lambda x: x * 2)
+
+  def call(self, x, training=None, **kwargs):
+    """Performs the forward computation of the block."""
+    for layer in self._sub_layers:
+      cls = layer.__class__.__name__
+      if cls in ('Dropout', 'BatchNormalization', 'Dice'):
+        x = layer(x, training=training)
+        if cls in ('BatchNormalization', 'Dice') and training:
+          add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+      else:
+        x = layer(x)
+    return x
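The trailing `lambda x: x * 2` above rescales the sigmoid so the gate lies in (0, 2), with 1 meaning pass-through; PPNet multiplies hidden activations by this gate. A minimal TF2-eager sketch of such a gate applied to a hidden activation (a plain Sequential stands in for GateNN and its config object):

import tensorflow as tf  # assumes TF2 eager for the sketch

hidden_dim = 16
gate_nn = tf.keras.Sequential([
    tf.keras.layers.Dense(hidden_dim, activation='relu'),
    tf.keras.layers.Dense(hidden_dim, activation='sigmoid'),
    tf.keras.layers.Lambda(lambda t: t * 2),   # gate in (0, 2), same trick as GateNN
])

x = tf.random.normal([4, hidden_dim])          # hidden activations of an MLP layer
gate_input = tf.random.normal([4, 8])          # personalized prior features
x_gated = x * gate_nn(gate_input)              # element-wise re-scaling of the hidden layer
print(x_gated.shape)  # (4, 16)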
+
+
+class PPNet(tf.keras.layers.Layer):
+  """PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information.
+
+  Attributes:
+    units: Sequential list of layer sizes.
+    use_bias: Whether to include a bias term.
+    activation: Type of activation to use on all except the last layer.
+    final_activation: Type of activation to use on last layer.
+    **kwargs: Extra args passed to the Keras Layer base class.
+  """
+
+  def __init__(self, params, name='ppnet', reuse=None, **kwargs):
+    super(PPNet, self).__init__(name=name, **kwargs)
+    params.check_required('mlp')
+    self.full_gate_input = params.get_or_default('full_gate_input', True)
+    mode = params.get_or_default('mode', 'lazy')
+    gate_params = params.gate_params
+    params = params.mlp
+    params.check_required('hidden_units')
+    use_bn = params.get_or_default('use_bn', True)
+    use_final_bn = params.get_or_default('use_final_bn', True)
+    use_bias = params.get_or_default('use_bias', False)
+    use_final_bias = params.get_or_default('use_final_bias', False)
+    dropout_rate = list(params.get_or_default('dropout_ratio', []))
+    activation = params.get_or_default('activation', 'relu')
+    initializer = params.get_or_default('initializer', 'he_uniform')
+    final_activation = params.get_or_default('final_activation', None)
+    use_bn_after_act = params.get_or_default('use_bn_after_activation', False)
+    units = list(params.hidden_units)
+    logging.info(
+        'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
+        ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r' %
+        (name, units, dropout_rate, activation, use_bn, use_final_bn,
+         final_activation, use_bias, initializer, use_bn_after_act))
+    assert len(units) > 0, 'MLP(%s) takes at least one hidden unit' % name
+    self.reuse = reuse
+
+    num_dropout = len(dropout_rate)
+    self._sub_layers = []
+
+    if mode != 'lazy':
+      self._sub_layers.append(GateNN(gate_params, None, 'gate_0'))
+    for i, num_units in enumerate(units[:-1]):
+      name = 'layer_%d' % i
+      drop_rate = dropout_rate[i] if i < num_dropout else 0.0
+      self.add_rich_layer(num_units, use_bn, drop_rate, activation, initializer,
+                          use_bias, use_bn_after_act, name,
+                          params.l2_regularizer)
+      self._sub_layers.append(
+          GateNN(gate_params, num_units, 'gate_%d' % (i + 1)))
+
+    n = len(units) - 1
+    drop_rate = dropout_rate[n] if num_dropout > n else 0.0
+    name = 'layer_%d' % n
+    self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation,
+                        initializer, use_final_bias, use_bn_after_act, name,
+                        params.l2_regularizer)
+    if mode == 'lazy':
+      self._sub_layers.append(
+          GateNN(gate_params, units[-1], 'gate_%d' % (n + 1)))
+
+  def add_rich_layer(self,
+                     num_units,
+                     use_bn,
+                     dropout_rate,
+                     activation,
+                     initializer,
+                     use_bias,
+                     use_bn_after_activation,
+                     name,
+                     l2_reg=None):
+    act_layer = activation_layer(activation, name='%s/act' % name)
+    if use_bn and not use_bn_after_activation:
+      dense = tf.keras.layers.Dense(
+          units=num_units,
+          use_bias=use_bias,
+          kernel_initializer=initializer,
+          kernel_regularizer=l2_reg,
+          name='%s/dense' % name)
+      self._sub_layers.append(dense)
+      bn = tf.keras.layers.BatchNormalization(
+          name='%s/bn' % name, trainable=True)
+      self._sub_layers.append(bn)
+      self._sub_layers.append(act_layer)
+    else:
+      dense = tf.keras.layers.Dense(
+          num_units,
+          use_bias=use_bias,
+          kernel_initializer=initializer,
+          kernel_regularizer=l2_reg,
+          name='%s/dense' % name)
+      self._sub_layers.append(dense)
+      self._sub_layers.append(act_layer)
+      if use_bn and use_bn_after_activation:
+        bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
+        self._sub_layers.append(bn)
+
+    if 0.0 < dropout_rate < 1.0:
+      dropout = tf.keras.layers.Dropout(dropout_rate, name='%s/dropout' % name)
+      self._sub_layers.append(dropout)
+    elif dropout_rate >= 1.0:
+      raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+  def call(self, inputs, training=None, **kwargs):
+    """Performs the forward computation of the block."""
+    x, gate_input = inputs
+    if self.full_gate_input:
+      with tf.name_scope(self.name):
+        gate_input = tf.concat([tf.stop_gradient(x), gate_input], axis=-1)
+
+    for layer in self._sub_layers:
+      cls = layer.__class__.__name__
+      if cls == 'GateNN':
+        gate = layer(gate_input)
+        x *= gate
+      elif cls in ('Dropout', 'BatchNormalization', 'Dice'):
+        x = layer(x, training=training)
+        if cls in ('BatchNormalization', 'Dice') and training:
+          add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+      else:
+        x = layer(x)
+    return x
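Putting it together: PPNet.call takes `inputs = (x, gate_input)`, optionally concatenates a stop-gradient copy of x into the gate input, and multiplies each hidden layer's output by a GateNN gate ('lazy' mode shown). A simplified TF2-eager sketch with assumed layer sizes, not the EasyRec config-driven implementation:

import tensorflow as tf  # assumes TF2 eager; easy_rec's Parameter/protobuf config omitted

batch = 4
x = tf.random.normal([batch, 32])            # backbone features
gate_input = tf.random.normal([batch, 8])    # personalized prior (e.g. user/item id embeddings)

# full_gate_input=True: the gate also sees a stop-gradient copy of x
gate_in = tf.concat([tf.stop_gradient(x), gate_input], axis=-1)

def gate_nn(units):
  # stand-in for GateNN: sigmoid gate rescaled to (0, 2)
  return tf.keras.Sequential([
      tf.keras.layers.Dense(units, activation='relu'),
      tf.keras.layers.Dense(units, activation='sigmoid'),
      tf.keras.layers.Lambda(lambda t: t * 2),
  ])

hidden_units = [64, 32]                      # assumed mlp.hidden_units
h = x
for units in hidden_units:                   # 'lazy' mode: gate applied after each hidden layer
  h = tf.keras.layers.Dense(units, activation='relu')(h)
  h = h * gate_nn(units)(gate_in)
print(h.shape)  # (4, 32)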