easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/layers/keras/layer_norm.py
@@ -0,0 +1,364 @@
+"""Layer Normalization layer."""
+import tensorflow as tf
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.layers import Layer
+
+
+def validate_axis(axis, input_shape):
+  """Validates an axis value and returns its standardized form.
+
+  Args:
+    axis: Value to validate. Can be an integer or a list/tuple of integers.
+      Integers may be negative.
+    input_shape: Reference input shape that the axis/axes refer to.
+
+  Returns:
+    Normalized form of `axis`, i.e. a list with all-positive values.
+  """
+  input_shape = tf.TensorShape(input_shape)
+  rank = input_shape.ndims
+  if not rank:
+    raise ValueError(
+        'Input has undefined rank. Received: input_shape={input_shape}'.format(
+            input_shape=input_shape))
+
+  # Convert axis to list and resolve negatives
+  if isinstance(axis, int):
+    axis = [axis]
+  else:
+    axis = list(axis)
+  for idx, x in enumerate(axis):
+    if x < 0:
+      axis[idx] = rank + x
+
+  # Validate axes
+  for x in axis:
+    if x < 0 or x >= rank:
+      raise ValueError('Invalid value for `axis` argument. '
+                       'Expected 0 <= axis < inputs.rank (with '
+                       'inputs.rank={rank}). Received: axis={axis}'.format(
+                           rank=rank, axis=tuple(axis)))
+  if len(axis) != len(set(axis)):
+    raise ValueError('Duplicate axis: {axis}'.format(axis=tuple(axis)))
+  return axis
+
+
+class LayerNormalization(Layer):
+  """Layer normalization layer (Ba et al., 2016).
+
+  Normalize the activations of the previous layer for each given example in a
+  batch independently, rather than across a batch like Batch Normalization,
+  i.e. apply a transformation that maintains the mean activation within each
+  example close to 0 and the activation standard deviation close to 1.
+
+  Given a tensor `inputs`, moments are calculated and normalization
+  is performed across the axes specified in `axis`.
+
+  Example:
+  >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
+  >>> print(data)
+  tf.Tensor(
+  [[ 0. 10.]
+   [20. 30.]
+   [40. 50.]
+   [60. 70.]
+   [80. 90.]], shape=(5, 2), dtype=float32)
+
+  >>> layer = tf.keras.layers.LayerNormalization(axis=1)
+  >>> output = layer(data)
+  >>> print(output)
+  tf.Tensor(
+  [[-1. 1.]
+   [-1. 1.]
+   [-1. 1.]
+   [-1. 1.]
+   [-1. 1.]], shape=(5, 2), dtype=float32)
+
+  Notice that with Layer Normalization the normalization happens across the
+  axes *within* each example, rather than across different examples in the
+  batch.
+
+  If `scale` or `center` are enabled, the layer will scale the normalized
+  outputs by broadcasting them with a trainable variable `gamma`, and center
+  the outputs by broadcasting with a trainable variable `beta`. `gamma` will
+  default to a ones tensor and `beta` will default to a zeros tensor, so that
+  centering and scaling are no-ops before training has begun.
+
+  So, with scaling and centering enabled, the normalization equations
+  are as follows:
+
+  Let the intermediate activations for a mini-batch be the `inputs`.
+
+  For each sample `x_i` in `inputs` with `k` features, we compute the mean and
+  variance of the sample:
+
+  ```python
+  mean_i = sum(x_i[j] for j in range(k)) / k
+  var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
+  ```
+
+  and then compute a normalized `x_i_normalized`, including a small factor
+  `epsilon` for numerical stability.
+
+  ```python
+  x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
+  ```
+
+  Finally, `x_i_normalized` is linearly transformed by `gamma` and `beta`,
+  which are learned parameters:
+
+  ```python
+  output_i = x_i_normalized * gamma + beta
+  ```
+
+  `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
+  this part of the inputs' shape must be fully defined.
+
+  For example:
+  >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
+  >>> layer.build([5, 20, 30, 40])
+  >>> print(layer.beta.shape)
+  (20, 30, 40)
+  >>> print(layer.gamma.shape)
+  (20, 30, 40)
+
+  Note that other implementations of layer normalization may choose to define
+  `gamma` and `beta` over a separate set of axes from the axes being
+  normalized across. For example, Group Normalization
+  ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
+  corresponds to a Layer Normalization that normalizes across height, width,
+  and channel and has `gamma` and `beta` span only the channel dimension.
+  So, this Layer Normalization implementation will not match a Group
+  Normalization layer with group size set to 1.
+
+  Args:
+    axis: Integer or List/Tuple. The axis or axes to normalize across.
+      Typically, this is the features axis/axes. The left-out axes are
+      typically the batch axis/axes. `-1` is the last dimension in the
+      input. Defaults to `-1`.
+    epsilon: Small float added to variance to avoid dividing by zero.
+      Defaults to 1e-3.
+    center: If True, add offset of `beta` to normalized tensor. If False,
+      `beta` is ignored. Defaults to `True`.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
+      When the next layer is linear (also e.g. `nn.relu`), this can be
+      disabled since the scaling will be done by the next layer.
+      Defaults to `True`.
+    beta_initializer: Initializer for the beta weight. Defaults to zeros.
+    gamma_initializer: Initializer for the gamma weight. Defaults to ones.
+    beta_regularizer: Optional regularizer for the beta weight. None by
+      default.
+    gamma_regularizer: Optional regularizer for the gamma weight. None by
+      default.
+    beta_constraint: Optional constraint for the beta weight. None by default.
+    gamma_constraint: Optional constraint for the gamma weight. None by
+      default.
+
+  Input shape:
+    Arbitrary. Use the keyword argument `input_shape` (tuple of
+    integers, does not include the samples axis) when using this layer as the
+    first layer in a model.
+
+  Output shape:
+    Same shape as input.
+
+  Reference:
+    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
+  """
+
+  def __init__(self,
+               axis=-1,
+               epsilon=1e-3,
+               center=True,
+               scale=True,
+               beta_initializer='zeros',
+               gamma_initializer='ones',
+               beta_regularizer=None,
+               gamma_regularizer=None,
+               beta_constraint=None,
+               gamma_constraint=None,
+               **kwargs):
+    super(LayerNormalization, self).__init__(**kwargs)
+    if isinstance(axis, (list, tuple)):
+      self.axis = list(axis)
+    elif isinstance(axis, int):
+      self.axis = axis
+    else:
+      raise TypeError('Expected an int or a list/tuple of ints for the '
+                      "argument 'axis', but received: %r" % axis)
+
+    self.epsilon = epsilon
+    self.center = center
+    self.scale = scale
+    self.beta_initializer = initializers.get(beta_initializer)
+    self.gamma_initializer = initializers.get(gamma_initializer)
+    self.beta_regularizer = regularizers.get(beta_regularizer)
+    self.gamma_regularizer = regularizers.get(gamma_regularizer)
+    self.beta_constraint = constraints.get(beta_constraint)
+    self.gamma_constraint = constraints.get(gamma_constraint)
+
+    self.supports_masking = True
+
+    # Indicates whether a faster fused implementation can be used. This will
+    # be set to True or False in build().
+    self._fused = None
+
+  def _fused_can_be_used(self, ndims):
+    """Returns False if the fused implementation cannot be used.
+
+    Check if the axis is contiguous and can be collapsed into the last axis.
+    The self.axis is assumed to have no duplicates.
+    """
+    if not tf.test.is_gpu_available():
+      return False
+    axis = sorted(self.axis)
+    can_use_fused = False
+
+    if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
+      can_use_fused = True
+
+    # fused_batch_norm will silently raise epsilon to be at least 1.001e-5,
+    # so we cannot use the fused version if epsilon is below that value.
+    # Also, the variable dtype must be float32, as fused_batch_norm only
+    # supports float32 variables.
+    if self.epsilon < 1.001e-5 or self.dtype != 'float32':
+      can_use_fused = False
+
+    return can_use_fused
+
+  def build(self, input_shape):
+    self.axis = validate_axis(self.axis, input_shape)
+    input_shape = tf.TensorShape(input_shape)
+    rank = input_shape.ndims
+
+    param_shape = [input_shape[dim] for dim in self.axis]
+    if self.scale:
+      self.gamma = self.add_weight(
+          name='gamma',
+          shape=param_shape,
+          initializer=self.gamma_initializer,
+          regularizer=self.gamma_regularizer,
+          constraint=self.gamma_constraint,
+          trainable=True,
+      )
+    else:
+      self.gamma = None
+
+    if self.center:
+      self.beta = self.add_weight(
+          name='beta',
+          shape=param_shape,
+          initializer=self.beta_initializer,
+          regularizer=self.beta_regularizer,
+          constraint=self.beta_constraint,
+          trainable=True,
+      )
+    else:
+      self.beta = None
+
+    self._fused = self._fused_can_be_used(rank)
+    super(LayerNormalization,
+          self).build(input_shape)  # Be sure to call this somewhere!
+
+  def call(self, inputs):
+    # Compute the axes along which to reduce the mean / variance
+    input_shape = inputs.shape
+    ndims = len(input_shape)
+
+    # Broadcasting is only necessary for norm when the axis is not just
+    # the last dimension
+    broadcast_shape = [1] * ndims
+    for dim in self.axis:
+      broadcast_shape[dim] = input_shape.dims[dim].value
+
+    def _broadcast(v):
+      if (v is not None and len(v.shape) != ndims and
+          self.axis != [ndims - 1]):
+        return tf.reshape(v, broadcast_shape)
+      return v
+
+    if not self._fused:
+      input_dtype = inputs.dtype
+      if (input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32'):
+        # If mixed precision is used, cast inputs to float32 so that
+        # this is at least as numerically stable as the fused version.
+        inputs = tf.cast(inputs, 'float32')
+
+      # Calculate the moments on the last axis (layer activations).
+      mean, variance = tf.nn.moments(inputs, self.axis, keep_dims=True)
+
+      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+      # Compute layer normalization using the batch_normalization
+      # function.
+      outputs = tf.nn.batch_normalization(
+          inputs,
+          mean,
+          variance,
+          offset=offset,
+          scale=scale,
+          variance_epsilon=self.epsilon,
+      )
+      outputs = tf.cast(outputs, input_dtype)
+    else:
+      # Collapse dims before self.axis, and dims in self.axis
+
+      axis = sorted(self.axis)
+      tensor_shape = tf.shape(inputs)
+      pre_dim = tf.reduce_prod(tensor_shape[:axis[0]])
+      in_dim = tf.reduce_prod(tensor_shape[axis[0]:])
+      squeezed_shape = [1, pre_dim, in_dim, 1]
+      # This fused operation requires reshaped inputs to be NCHW.
+      data_format = 'NCHW'
+
+      inputs = tf.reshape(inputs, squeezed_shape)
+
+      # self.gamma and self.beta have the wrong shape for
+      # fused_batch_norm, so we cannot pass them as the scale and offset
+      # parameters. Therefore, we create two constant tensors in correct
+      # shapes for fused_batch_norm and later construct a separate
+      # calculation on the scale and offset.
+      scale = tf.ones(tf.convert_to_tensor([pre_dim]), dtype=self.dtype)
+      offset = tf.zeros(tf.convert_to_tensor([pre_dim]), dtype=self.dtype)
+
+      # Compute layer normalization using the fused_batch_norm function.
+      outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
+          inputs,
+          scale=scale,
+          offset=offset,
+          epsilon=self.epsilon,
+          data_format=data_format,
+      )
+
+      outputs = tf.reshape(outputs, tensor_shape)
+
+      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+      if scale is not None:
+        outputs = outputs * tf.cast(scale, outputs.dtype)
+      if offset is not None:
+        outputs = outputs + tf.cast(offset, outputs.dtype)
+
+    # If some components of the shape got lost due to adjustments, fix that.
+    outputs.set_shape(input_shape)
+    return outputs
+
+  def compute_output_shape(self, input_shape):
+    return input_shape
+
+  def get_config(self):
+    config = {
+        'axis': self.axis,
+        'epsilon': self.epsilon,
+        'center': self.center,
+        'scale': self.scale,
+        'beta_initializer': initializers.serialize(self.beta_initializer),
+        'gamma_initializer': initializers.serialize(self.gamma_initializer),
+        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
+        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
+        'beta_constraint': constraints.serialize(self.beta_constraint),
+        'gamma_constraint': constraints.serialize(self.gamma_constraint),
+    }
+    base_config = super(LayerNormalization, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
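Usage note (not part of the diff): this class mirrors `tf.keras.layers.LayerNormalization`; the next hunk (mask_net.py) falls back to it when `tf.__version__ < '2.0'`. Below is a minimal, illustrative sketch of exercising the layer standalone, assuming a TF 1.x graph/session environment, since the implementation relies on TF1-era APIs such as `tf.nn.moments(..., keep_dims=True)` and `tf.test.is_gpu_available()`.

```python
# Minimal usage sketch (illustrative only, not from the package).
import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

from easy_rec.python.layers.keras.layer_norm import LayerNormalization

with tf.Graph().as_default(), tf.Session() as sess:
  data = tf.constant(np.arange(10).reshape(5, 2) * 10.0, dtype=tf.float32)
  layer = LayerNormalization(axis=1)  # normalize each row independently
  output = layer(data)
  sess.run(tf.global_variables_initializer())  # initializes gamma / beta
  print(sess.run(output))  # each row comes out approximately as [-1., 1.]
```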
easy_rec/python/layers/keras/mask_net.py
@@ -0,0 +1,166 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Activation
+from tensorflow.python.keras.layers import Dense
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers.keras.blocks import MLP
+from easy_rec.python.layers.keras.layer_norm import LayerNormalization
+from easy_rec.python.layers.utils import Parameter
+
+
+class MaskBlock(Layer):
+  """MaskBlock used in MaskNet.
+
+  Args:
+    projection_dim: projection dimension to reduce the computational cost.
+      Default is `None`, in which case a full (`input_dim` by
+      `aggregation_size`) matrix W is used. If enabled, a low-rank matrix
+      W = U*V will be used, where U is of size `input_dim` by
+      `projection_dim` and V is of size `projection_dim` by
+      `aggregation_size`. `projection_dim` needs to be smaller than
+      `aggregation_size`/2 to improve the model efficiency. In practice,
+      we've observed that `projection_dim` = d/4 consistently preserved the
+      accuracy of a full-rank version.
+  """
+
+  def __init__(self, params, name='mask_block', reuse=None, **kwargs):
+    super(MaskBlock, self).__init__(name=name, **kwargs)
+    self.config = params.get_pb_config()
+    self.l2_reg = params.l2_regularizer
+    self._projection_dim = params.get_or_default('projection_dim', None)
+    self.reuse = reuse
+    self.final_relu = Activation('relu', name='relu')
+
+  def build(self, input_shape):
+    if type(input_shape) in (tuple, list):
+      assert len(input_shape) >= 2, 'MaskBlock must have at least two inputs'
+      input_dim = int(input_shape[0][-1])
+      mask_input_dim = int(input_shape[1][-1])
+    else:
+      input_dim, mask_input_dim = input_shape[-1], input_shape[-1]
+    if self.config.HasField('reduction_factor'):
+      aggregation_size = int(mask_input_dim * self.config.reduction_factor)
+    elif self.config.HasField('aggregation_size'):
+      aggregation_size = self.config.aggregation_size
+    else:
+      raise ValueError(
+          'Need one of reduction factor or aggregation size for MaskBlock.')
+
+    self.aggr_layer = Dense(
+        aggregation_size,
+        activation='relu',
+        kernel_initializer='he_uniform',
+        kernel_regularizer=self.l2_reg,
+        name='aggregation')
+    self.weight_layer = Dense(input_dim, name='weights')
+    if self._projection_dim is not None:
+      logging.info('%s project dim is %d', self.name, self._projection_dim)
+      self.project_layer = Dense(
+          self._projection_dim,
+          kernel_regularizer=self.l2_reg,
+          use_bias=False,
+          name='project')
+    if self.config.input_layer_norm:
+      # It is recommended to apply layer norm before calling MaskBlock;
+      # otherwise every call has to apply LN to the input again.
+      if tf.__version__ >= '2.0':
+        self.input_layer_norm = tf.keras.layers.LayerNormalization(
+            name='input_ln')
+      else:
+        self.input_layer_norm = LayerNormalization(name='input_ln')
+
+    if self.config.HasField('output_size'):
+      self.output_layer = Dense(
+          self.config.output_size, use_bias=False, name='output')
+    if tf.__version__ >= '2.0':
+      self.output_layer_norm = tf.keras.layers.LayerNormalization(
+          name='output_ln')
+    else:
+      self.output_layer_norm = LayerNormalization(name='output_ln')
+    super(MaskBlock, self).build(input_shape)
+
+  def call(self, inputs, training=None, **kwargs):
+    if type(inputs) in (tuple, list):
+      net, mask_input = inputs[:2]
+    else:
+      net, mask_input = inputs, inputs
+
+    if self.config.input_layer_norm:
+      net = self.input_layer_norm(net)
+
+    if self._projection_dim is None:
+      aggr = self.aggr_layer(mask_input)
+    else:
+      u = self.project_layer(mask_input)
+      aggr = self.aggr_layer(u)
+
+    weights = self.weight_layer(aggr)
+    masked_net = net * weights
+
+    if not self.config.HasField('output_size'):
+      return masked_net
+
+    hidden = self.output_layer(masked_net)
+    ln_hidden = self.output_layer_norm(hidden)
+    return self.final_relu(ln_hidden)
+
+
+class MaskNet(Layer):
+  """MaskNet: Introducing Feature-Wise Multiplication to CTR Ranking Models by Instance-Guided Mask.
+
+  Refer: https://arxiv.org/pdf/2102.07619.pdf
+  """
+
+  def __init__(self, params, name='mask_net', reuse=None, **kwargs):
+    super(MaskNet, self).__init__(name=name, **kwargs)
+    self.reuse = reuse
+    self.params = params
+    self.config = params.get_pb_config()
+    if self.config.HasField('mlp'):
+      p = Parameter.make_from_pb(self.config.mlp)
+      p.l2_regularizer = params.l2_regularizer
+      self.mlp = MLP(p, name='mlp', reuse=reuse)
+    else:
+      self.mlp = None
+
+    self.mask_layers = []
+    for i, block_conf in enumerate(self.config.mask_blocks):
+      params = Parameter.make_from_pb(block_conf)
+      params.l2_regularizer = self.params.l2_regularizer
+      mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
+      self.mask_layers.append(mask_layer)
+
+    if self.config.input_layer_norm:
+      if tf.__version__ >= '2.0':
+        self.input_layer_norm = tf.keras.layers.LayerNormalization(
+            name='input_ln')
+      else:
+        self.input_layer_norm = LayerNormalization(name='input_ln')
+
+  def call(self, inputs, training=None, **kwargs):
+    if self.config.input_layer_norm:
+      inputs = self.input_layer_norm(inputs)
+
+    if self.config.use_parallel:
+      mask_outputs = [
+          mask_layer((inputs, inputs)) for mask_layer in self.mask_layers
+      ]
+      all_mask_outputs = tf.concat(mask_outputs, axis=1)
+      if self.mlp is not None:
+        output = self.mlp(all_mask_outputs, training=training)
+      else:
+        output = all_mask_outputs
+      return output
+    else:
+      net = inputs
+      for i, _ in enumerate(self.config.mask_blocks):
+        mask_layer = self.mask_layers[i]
+        net = mask_layer((net, inputs))
+
+      if self.mlp is not None:
+        output = self.mlp(net, training=training)
+      else:
+        output = net
+      return output
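For readers skimming the diff: the core of `MaskBlock` is an instance-guided mask, i.e. a small two-layer bottleneck over `mask_input` that produces per-feature weights, which are then multiplied element-wise into `net`. The sketch below restates that computation in plain NumPy so it can be read without the package's `Parameter`/protobuf config objects; the weight matrices are random placeholders (biases omitted), not values taken from the layer.

```python
# Plain-NumPy restatement of the MaskBlock computation (illustrative only).
import numpy as np

rng = np.random.default_rng(0)
batch, input_dim, aggregation_size = 4, 16, 32

net = rng.normal(size=(batch, input_dim))    # features to be masked
mask_input = net                             # self-masking, as in MaskNet's parallel mode

# "aggregation" Dense (relu) followed by "weights" Dense (linear); biases omitted
w_aggr = rng.normal(size=(input_dim, aggregation_size))
w_mask = rng.normal(size=(aggregation_size, input_dim))

aggr = np.maximum(mask_input @ w_aggr, 0.0)  # relu(mask_input @ W1)
weights = aggr @ w_mask                      # instance-guided mask, shape (batch, input_dim)
masked_net = net * weights                   # feature-wise multiplication
print(masked_net.shape)                      # (4, 16)

# Low-rank variant (projection_dim enabled): W1 is factored as U @ V.
projection_dim = aggregation_size // 4
u = rng.normal(size=(input_dim, projection_dim))
v = rng.normal(size=(projection_dim, aggregation_size))
aggr_low_rank = np.maximum((mask_input @ u) @ v, 0.0)
```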