easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,192 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from tensorflow.python.keras.layers import Dense
8
+ from tensorflow.python.keras.layers import Dropout
9
+ from tensorflow.python.keras.layers import Embedding
10
+ from tensorflow.python.keras.layers import Layer
11
+
12
+ from easy_rec.python.layers.keras import MultiHeadAttention
13
+ from easy_rec.python.layers.keras.layer_norm import LayerNormalization
14
+ from easy_rec.python.layers.utils import Parameter
15
+ from easy_rec.python.protos import seq_encoder_pb2
16
+
17
+
18
class TransformerBlock(Layer):
  """A single transformer encoder block.

  Purpose: Combines attention and feed-forward layers with residual connections and normalization.
  Components: Multi-head attention, feed-forward network, dropout, and layer normalization.
  Output: Enhanced representation after applying attention and feed-forward layers.
  """

  def __init__(self, params, name='transformer_block', reuse=None, **kwargs):
    """Create the sub-layers of the block.

    Args:
      params: parameter object. `hidden_size` and `num_attention_heads` are
        read as attributes; `attention_probs_dropout_prob` (default 0.0),
        `hidden_dropout_prob` (default 0.1), `intermediate_size` (default
        `hidden_size`) and `hidden_act` (default 'relu') are read through
        `get_or_default`.
      name: name of the layer.
      reuse: not used by this block.
      **kwargs: forwarded to the Keras `Layer` constructor.
    """
    super(TransformerBlock, self).__init__(name=name, **kwargs)
    d_model = params.hidden_size
    num_heads = params.num_attention_heads

    # Configure the attention layer through its protobuf message.
    mha_cfg = seq_encoder_pb2.MultiHeadAttention()
    mha_cfg.num_heads = num_heads
    # Integer division: any remainder of hidden_size / num_heads is dropped.
    mha_cfg.key_dim = d_model // num_heads
    mha_cfg.dropout = params.get_or_default('attention_probs_dropout_prob', 0.0)
    mha_cfg.return_attention_scores = False
    args = Parameter.make_from_pb(mha_cfg)
    self.mha = MultiHeadAttention(args, 'multi_head_attn')

    dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
    ffn_units = params.get_or_default('intermediate_size', d_model)
    ffn_act = params.get_or_default('hidden_act', 'relu')
    self.ffn_dense1 = Dense(ffn_units, activation=ffn_act)
    # Project back to d_model so the residual addition in `call` lines up.
    self.ffn_dense2 = Dense(d_model)

    # Compare the major version numerically; a plain string comparison
    # would misorder versions such as '10.0' and '2.0'.
    if int(tf.__version__.split('.')[0]) >= 2:
      self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
      self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    else:
      self.layer_norm1 = LayerNormalization(epsilon=1e-6)
      self.layer_norm2 = LayerNormalization(epsilon=1e-6)
    self.dropout1 = Dropout(dropout_rate)
    self.dropout2 = Dropout(dropout_rate)

  def call(self, inputs, training=None, **kwargs):
    """Apply self-attention and the feed-forward network.

    Args:
      inputs: a pair `(x, mask)`; `x` is used as query, key and value of the
        attention layer and `mask` is passed to it unchanged.
      training: forwarded to the attention layer and to both dropout layers.
      **kwargs: ignored.

    Returns:
      The output of the second layer normalization.
    """
    x, mask = inputs
    attn_output = self.mha([x, x, x], mask=mask, training=training)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layer_norm1(x + attn_output)  # residual around attention
    ffn_mid = self.ffn_dense1(out1)
    ffn_output = self.ffn_dense2(ffn_mid)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layer_norm2(out1 + ffn_output)  # residual around the FFN
    return out2
61
+
62
+
63
+ # Positional Encoding, https://www.tensorflow.org/text/tutorials/transformer
64
def positional_encoding(length, depth):
  """Build a fixed sinusoidal position table.

  Args:
    length: number of positions (rows) to generate.
    depth: width of each position vector (columns of the result).

  Returns:
    A float32 tensor of shape `(length, depth)`. The leading columns hold the
    sine terms and the trailing columns hold the cosine terms.
  """
  width = int(depth)
  half_depth = depth / 2
  positions = np.arange(length)[:, np.newaxis]  # (length, 1)
  depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, half)
  angle_rates = 1 / (10000**depths)  # (1, half)
  angle_rads = positions * angle_rates  # (length, half)
  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1)
  # For an odd `depth`, np.arange(depth / 2) yields ceil(depth / 2) columns,
  # so the concatenation is one column too wide. Trim it so the table can be
  # added to a `depth`-wide embedding; for an even `depth` this is a no-op.
  pos_encoding = pos_encoding[:, :width]
  return tf.cast(pos_encoding, dtype=tf.float32)
73
+
74
+
75
class PositionalEmbedding(Layer):
  """Token embedding scaled by sqrt(d_model) plus a sinusoidal position table."""

  def __init__(self, vocab_size, d_model, max_position, name='pos_embedding'):
    """Create the embedding table and precompute the position encodings.

    Args:
      vocab_size: number of rows of the token embedding table.
      d_model: embedding width, also the width of the position encodings.
      max_position: number of positions that are precomputed.
      name: name of the layer.
    """
    super(PositionalEmbedding, self).__init__(name=name)
    self.d_model = d_model
    self.embedding = Embedding(vocab_size, d_model)
    self.pos_encoding = positional_encoding(
        length=max_position, depth=d_model)

  def call(self, x, training=None):
    """Embed token ids and add the encodings of the first `seq_len` positions.

    Args:
      x: token ids; the size of axis 1 is taken as the sequence length.
      training: unused.

    Returns:
      The scaled token embeddings plus the position encodings.
    """
    seq_len = tf.shape(x)[1]
    embedded = self.embedding(x)
    # The scale sets the relative magnitude of embedding vs. position terms.
    embedded = embedded * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    return embedded + self.pos_encoding[tf.newaxis, :seq_len, :]
90
+
91
+
92
class TransformerEncoder(Layer):
  """A stack of transformer blocks on top of a positional token embedding.

  Purpose: Encodes the input sequence into a set of embeddings.
  Components: Embedding layer, positional encoding, and a stack of transformer blocks.
  Output: Encoded representation of the input sequence, either for every
  position or for position 0 only (see `output_all_token_embeddings`).
  """

  def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
    """Build the embedding, dropout and transformer block sub-layers.

    Args:
      params: parameter object. `hidden_size` and `vocab_size` are read as
        attributes; `hidden_dropout_prob` (default 0.1),
        `max_position_embeddings` (default 512), `num_hidden_layers`
        (default 1) and `output_all_token_embeddings` (default True) are read
        through `get_or_default`. It is also handed to every block.
      name: name of the layer.
      reuse: not used by this layer.
      **kwargs: forwarded to the Keras `Layer` constructor.
    """
    super(TransformerEncoder, self).__init__(name=name, **kwargs)
    hidden_dim = params.hidden_size
    drop_prob = params.get_or_default('hidden_dropout_prob', 0.1)
    position_limit = params.get_or_default('max_position_embeddings', 512)
    block_count = params.get_or_default('num_hidden_layers', 1)
    token_count = params.vocab_size
    logging.info('vocab size of TransformerEncoder(%s) is %d', name,
                 token_count)
    self._vocab_size = token_count
    self._max_position = position_limit

    self.output_all = params.get_or_default('output_all_token_embeddings', True)
    self.pos_encoding = PositionalEmbedding(token_count, hidden_dim,
                                            position_limit)
    self.dropout = Dropout(drop_prob)
    blocks = []
    for layer_idx in range(block_count):
      blocks.append(TransformerBlock(params, 'layer_%d' % layer_idx))
    self.enc_layers = blocks

  @property
  def vocab_size(self):
    """Vocabulary size the encoder was built with."""
    return self._vocab_size

  @property
  def max_position(self):
    """Number of positions covered by the position encodings."""
    return self._max_position

  def call(self, inputs, training=None, **kwargs):
    """Encode a batch of token-id sequences.

    Args:
      inputs: a pair `(x, mask)`; `x` holds token ids of shape
        (batch, seq_len) and `mask` is handed to every block.
      training: forwarded to the dropout layer and to the blocks.
      **kwargs: ignored.

    Returns:
      All token embeddings `(batch_size, seq_len, d_model)` when
      `output_all` is set, otherwise only the embedding at position 0.
    """
    token_ids, mask = inputs
    hidden = self.pos_encoding(token_ids)  # (batch_size, seq_len, d_model)
    hidden = self.dropout(hidden, training=training)
    for layer in self.enc_layers:
      hidden = layer([hidden, mask], training)
    if self.output_all:
      return hidden
    return hidden[:, 0, :]
135
+
136
+
137
class TextEncoder(Layer):
  """Encode one or more text inputs with a transformer encoder.

  The texts are joined as `[CLS] text_1 [SEP] text_2 [SEP] ...`, split into
  tokens on `separator`, mapped to ids (vocabulary file or hashing) and fed
  to a `TransformerEncoder`.
  """

  def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
    """Create the tokenizer settings and the transformer encoder.

    Args:
      params: parameter object. `separator` (default ' '), `vocab_file`
        (default None) and `default_token_id` (default 0) are read through
        `get_or_default`; `params.transformer` configures the encoder.
      name: name of the layer.
      reuse: not used by this layer.
      **kwargs: forwarded to the Keras `Layer` constructor.
    """
    super(TextEncoder, self).__init__(name=name, **kwargs)
    self.separator = params.get_or_default('separator', ' ')
    self.cls_token = '[CLS]' + self.separator
    self.sep_token = self.separator + '[SEP]' + self.separator
    # Ask the encoder for a single embedding instead of one per token.
    # Note that this writes to the passed-in transformer config.
    params.transformer.output_all_token_embeddings = False
    trans_params = Parameter.make_from_pb(params.transformer)
    vocab_file = params.get_or_default('vocab_file', None)
    self.vocab = None
    self.default_token_id = params.get_or_default('default_token_id', 0)
    if vocab_file is not None:
      self.vocab = tf.feature_column.categorical_column_with_vocabulary_file(
          'tokens',
          vocabulary_file=vocab_file,
          default_value=self.default_token_id)
      logging.info('vocab file of TextEncoder(%s) is %s', name, vocab_file)
      # The vocabulary file determines the vocab size of the encoder.
      trans_params.vocab_size = self.vocab.vocabulary_size
    self.encoder = TransformerEncoder(trans_params, name='transformer')

  def call(self, inputs, training=None, **kwargs):
    """Tokenize the input texts and encode them.

    Args:
      inputs: a string tensor, or a tuple/list of string tensors that are
        concatenated per example.
      training: forwarded to the transformer encoder.
      **kwargs: ignored.

    Returns:
      The output of the transformer encoder.
    """
    # isinstance also accepts subclasses of tuple/list, which an exact type
    # comparison would wrap into a one-element list by mistake.
    if not isinstance(inputs, (tuple, list)):
      inputs = [inputs]
    # NOTE(review): squeeze without an axis removes every size-1 dimension,
    # including a batch dimension of size 1 - confirm against the callers.
    inputs = [tf.squeeze(text) for text in inputs]
    batch_size = tf.shape(inputs[0])
    cls = tf.fill(batch_size, self.cls_token)
    sep = tf.fill(batch_size, self.sep_token)
    sentences = [cls]
    for sentence in inputs:
      sentences.append(sentence)
      sentences.append(sep)
    text = tf.strings.join(sentences)
    tokens = tf.strings.split(text, self.separator)
    if self.vocab is not None:
      # Vocabulary lookup; out-of-vocabulary tokens get default_token_id.
      features = {'tokens': tokens}
      token_ids = self.vocab._transform_feature(features)
      token_ids = tf.sparse.to_dense(
          token_ids, default_value=self.default_token_id, name='token_ids')
      length = tf.shape(token_ids)[-1]
      # Keep at most max_position tokens per example.
      token_ids = tf.cond(
          tf.less_equal(length, self.encoder.max_position), lambda: token_ids,
          lambda: tf.slice(token_ids, [0, 0], [-1, self.encoder.max_position]))
      # Padding and unknown tokens share default_token_id; both are masked.
      mask = tf.not_equal(token_ids, self.default_token_id, name='mask')
    else:
      # No vocabulary file: pad with '' and hash tokens into vocab_size ids.
      tokens = tf.sparse.to_dense(tokens, default_value='')
      length = tf.shape(tokens)[-1]
      tokens = tf.cond(
          tf.less_equal(length, self.encoder.max_position), lambda: tokens,
          lambda: tf.slice(tokens, [0, 0], [-1, self.encoder.max_position]))
      token_ids = tf.string_to_hash_bucket_fast(
          tokens, self.encoder.vocab_size, name='token_ids')
      mask = tf.not_equal(tokens, '', name='mask')

    encoding = self.encoder([token_ids, mask], training=training)
    return encoding
@@ -0,0 +1,51 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
# Compare the major version numerically; a plain string comparison would
# misorder versions such as '10.0' and '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
7
+
8
+
9
class LayerNormalization(tf.layers.Layer):
  """Layer normalization for BTC format: supports L2(default) and L1 modes."""

  def __init__(self, hidden_size, params=None):
    """Store the normalization settings.

    Args:
      hidden_size: size of the learned scale and bias vectors.
      params: optional mapping with the keys `type` (default
        'layernorm_L2'; any other value selects the L1 mode) and `epsilon`
        (default 1e-6).
    """
    super(LayerNormalization, self).__init__()
    # Use None instead of a shared mutable `{}` as the default argument.
    if params is None:
      params = {}
    self.hidden_size = hidden_size
    self.norm_type = params.get('type', 'layernorm_L2')
    self.epsilon = params.get('epsilon', 1e-6)

  def build(self, _):
    """Create the scale (initialized to ones) and bias (zeros) variables."""
    self.scale = tf.get_variable(
        'layer_norm_scale', [self.hidden_size],
        initializer=tf.keras.initializers.Ones(),
        dtype=tf.float32)
    self.bias = tf.get_variable(
        'layer_norm_bias', [self.hidden_size],
        initializer=tf.keras.initializers.Zeros(),
        dtype=tf.float32)
    self.built = True

  def call(self, x):
    """Normalize `x` over its last axis, then apply scale and bias.

    Args:
      x: input tensor.

    Returns:
      The normalized tensor. In L2 mode the result is cast back to the dtype
      of `x`; in L1 mode only float16 inputs are cast back.
    """
    if self.norm_type == 'layernorm_L2':
      # L2 mode: divide the centered input by sqrt(variance + epsilon).
      epsilon = self.epsilon
      dtype = x.dtype
      x = tf.cast(x=x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
      norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
      result = norm_x * self.scale + self.bias
      return tf.cast(x=result, dtype=dtype)

    else:
      # L1 mode: divide the centered input by its mean absolute deviation.
      dtype = x.dtype
      if dtype == tf.float16:
        x = tf.cast(x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      x = x - mean
      variance = tf.reduce_mean(tf.abs(x), axis=[-1], keepdims=True)
      norm_x = tf.div(x, variance + self.epsilon)
      y = norm_x * self.scale + self.bias
      if dtype == tf.float16:
        # saturate_cast clamps to the float16 range instead of overflowing.
        y = tf.saturate_cast(y, dtype)
      return y
@@ -0,0 +1,83 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.layers import dnn
8
+
9
# Compare the major version numerically; a plain string comparison would
# misorder versions such as '10.0' and '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
11
+
12
+
13
class MMOE:
  """Multi-gate mixture of experts: shared expert DNNs mixed by one softmax gate per task."""

  def __init__(self,
               expert_dnn_config,
               l2_reg,
               num_task,
               num_expert=None,
               name='mmoe',
               is_training=False):
    """Initializes a `MMOE` Layer.

    Args:
      expert_dnn_config: a instance or a list of easy_rec.python.protos.dnn_pb2.DNN,
        if it is a list of configs, the param `num_expert` will be ignored,
        if it is a single config, the number of experts will be specified by num_expert.
      l2_reg: l2 regularizer.
      num_task: number of tasks
      num_expert: number of experts, default is the list length of expert_dnn_configs
      name: scope of the layer, so that the parameters could be separated from other layers
      is_training: train phase or not, impact batchnorm and dropout
    """
    if isinstance(expert_dnn_config, list):
      configs = expert_dnn_config
      expert_count = len(expert_dnn_config)
    else:
      assert num_expert is not None and num_expert > 0, \
          'param `num_expert` must be large than zero, when expert_dnn_config is not a list'
      # Every expert shares the same config object.
      configs = [expert_dnn_config] * num_expert
      expert_count = num_expert
    self._expert_dnn_configs = configs
    self._num_expert = expert_count
    logging.info('num_expert: {0}'.format(self._num_expert))

    self._name = name
    self._l2_reg = l2_reg
    self._num_task = num_task
    self._is_training = is_training

  @property
  def num_expert(self):
    """Number of experts."""
    return self._num_expert

  def gate(self, unit, deep_fea, name):
    """Compute softmax gate weights from the input features.

    Args:
      unit: number of gate outputs.
      deep_fea: input features of the gate.
      name: scope prefix of the dense layer.

    Returns:
      The dense layer output passed through a softmax over axis 1.
    """
    logits = tf.layers.dense(
        inputs=deep_fea,
        units=unit,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn' % name)
    return tf.nn.softmax(logits, axis=1)

  def __call__(self, deep_fea):
    """Run all experts and mix their outputs once per task.

    Args:
      deep_fea: input features shared by the experts and the gates.

    Returns:
      A list with one gated sum of the expert outputs per task.
    """
    expert_outputs = []
    for expert_id, config in enumerate(self._expert_dnn_configs):
      expert = dnn.DNN(
          config,
          self._l2_reg,
          name='%s/expert_%d' % (self._name, expert_id),
          is_training=self._is_training)
      expert_outputs.append(expert(deep_fea))
    # Experts are stacked on a new axis 1, the axis the gates weight.
    experts_fea = tf.stack(expert_outputs, axis=1)

    def _mix(task_id):
      weights = self.gate(
          self._num_expert, deep_fea, name='%s/gate_%d' % (self._name, task_id))
      weights = tf.expand_dims(weights, -1)
      return tf.reduce_sum(tf.multiply(experts_fea, weights), axis=1)

    return [_mix(task_id) for task_id in range(self._num_task)]
@@ -0,0 +1,162 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ if tf.__version__ >= '2.0':
6
+ tf = tf.compat.v1
7
+
8
+
9
class MultiHeadAttention:
  """Multi-head scaled dot-product attention built from `tf.layers.dense`."""

  def __init__(self, head_num, head_size, l2_reg, use_res=False, name=''):
    """Initializes a `MultiHeadAttention` Layer.

    Args:
      head_num: The number of heads
      head_size: The dimension of a head
      l2_reg: l2 regularizer
      use_res: Whether to use residual connections before output.
      name: scope of the MultiHeadAttention, so that the parameters could be
        separated from other MultiHeadAttention
    """
    self._head_num = head_num
    self._head_size = head_size
    self._l2_reg = l2_reg
    self._use_res = use_res
    self._name = name

  def _split_multihead_qkv(self, q, k, v):
    """Split multiple heads.

    Args:
      q: Query matrix of shape [bs, feature_num, head_num * head_size].
      k: Key matrix of shape [bs, feature_num, head_num * head_size].
      v: Value matrix of shape [bs, feature_num, head_num * head_size].

    Returns:
      q: Query matrix of shape [bs, head_num, feature_num, head_size].
      k: Key matrix of shape [bs, head_num, feature_num, head_size].
      v: Value matrix of shape [bs, head_num, feature_num, head_size].
    """
    reshaped_q = tf.reshape(
        q, shape=[-1, q.shape[1], self._head_num, self._head_size])
    q = tf.transpose(reshaped_q, perm=[0, 2, 1, 3])
    reshaped_k = tf.reshape(
        k, shape=[-1, k.shape[1], self._head_num, self._head_size])
    k = tf.transpose(reshaped_k, perm=[0, 2, 1, 3])
    reshaped_v = tf.reshape(
        v, shape=[-1, v.shape[1], self._head_num, self._head_size])
    v = tf.transpose(reshaped_v, perm=[0, 2, 1, 3])
    return q, k, v

  def _scaled_dot_product_attention(self, q, k, v):
    """Calculate scaled dot product attention by q, k and v.

    Args:
      q: Query matrix of shape [bs, head_num, feature_num, head_size].
      k: Key matrix of shape [bs, head_num, feature_num, head_size].
      v: Value matrix of shape [bs, head_num, feature_num, head_size].

    Returns:
      out: Attention output, softmax(q * k^T / sqrt(head_size)) * v, of shape
        [bs, head_num, feature_num, head_size].
    """
    # Divide the logits by sqrt(head_size). The previous code divided by
    # head_size ** -0.5, which multiplied the logits by sqrt(head_size)
    # instead of shrinking them.
    scale = float(self._head_size)**0.5
    product = tf.linalg.matmul(a=q, b=k, transpose_b=True) / scale
    # tf.nn.softmax normalizes over the last axis by default.
    weights = tf.nn.softmax(product)
    out = tf.linalg.matmul(weights, v)
    return out

  def _compute_qkv(self, q, k, v):
    """Calculate q, k and v matrices.

    Args:
      q: Query matrix of shape [bs, feature_num, d_model].
      k: Key matrix of shape [bs, feature_num, d_model].
      v: Value matrix of shape [bs, feature_num, d_model].

    Returns:
      q: Query matrix of shape [bs, feature_num, head_size * n_head].
      k: Key matrix of shape [bs, feature_num, head_size * n_head].
      v: Value matrix of shape [bs, feature_num, head_size * n_head].
    """
    # Bias-free linear projections, one per role, each in its own scope.
    q = tf.layers.dense(
        q,
        self._head_num * self._head_size,
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/%s/dnn' % (self._name, 'query'))
    k = tf.layers.dense(
        k,
        self._head_num * self._head_size,
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/%s/dnn' % (self._name, 'key'))
    v = tf.layers.dense(
        v,
        self._head_num * self._head_size,
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/%s/dnn' % (self._name, 'value'))
    return q, k, v

  def _combine_heads(self, multi_head_tensor):
    """Combine the results of multiple heads.

    Args:
      multi_head_tensor: Result matrix of shape
        [bs, head_num, feature_num, head_size].

    Returns:
      out: Result matrix of shape [bs, feature_num, head_num * head_size].
    """
    x = tf.transpose(multi_head_tensor, perm=[0, 2, 1, 3])
    out = tf.reshape(x, shape=[-1, x.shape[1], x.shape[2] * x.shape[3]])
    return out

  def _multi_head_attention(self, attention_input):
    """Build multiple heads attention layer.

    Args:
      attention_input: The input of interacting layer. Either a single tensor
        of shape [bs, feature_num, d_model] used as query, key and value, a
        list with one such tensor, or a list of three tensors
        [query, key, value].

    Returns:
      out: The output of multi head attention layer, has a shape of
        [bs, feature_num, head_num * head_size].

    Raises:
      AssertionError: if `attention_input` is a list whose length is neither
        1 nor 3.
    """
    if isinstance(attention_input, list):
      assert len(attention_input) == 3 or len(attention_input) == 1, \
          'If the input of multi_head_attention is a list, the length must be 1 or 3.'

      if len(attention_input) == 3:
        ori_q, ori_k, ori_v = attention_input
      else:
        # A single-element list means self-attention on that element.
        ori_q = ori_k = ori_v = attention_input[0]
    else:
      ori_q = ori_k = ori_v = attention_input

    q, k, v = self._compute_qkv(ori_q, ori_k, ori_v)
    q, k, v = self._split_multihead_qkv(q, k, v)
    multi_head_tensor = self._scaled_dot_product_attention(q, k, v)
    out = self._combine_heads(multi_head_tensor)

    if not self._use_res:
      return out

    # Residual branch: project the raw value input to the output width so it
    # can be added to the attention output before the ReLU.
    W_0_x = tf.layers.dense(
        ori_v,
        out.shape[2],
        use_bias=False,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn' % (self._name))
    return tf.nn.relu(out + W_0_x)

  def __call__(self, deep_fea):
    """Apply the attention layer to `deep_fea` and return its output."""
    deep_fea = self._multi_head_attention(deep_fea)
    return deep_fea