easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package's registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,138 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.model.match_model import MatchModel
7
+ from easy_rec.python.protos.dat_pb2 import DAT as DATConfig
8
+ from easy_rec.python.protos.loss_pb2 import LossType
9
+ from easy_rec.python.utils.proto_util import copy_obj
10
+
11
# Use the TF1-style graph API (tf.compat.v1) when running under TF 2.x.
# Compare the numeric major version: a plain string comparison is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class DAT(MatchModel):
  """Dual Augmented Two-tower Model.

  Each tower concatenates a regular feature group ('user' / 'item') with an
  id-augmentation feature group ('user_id_augment' / 'item_id_augment'),
  feeds the result through a DNN plus a final dense projection, and the two
  normalized tower embeddings are scored with ``self.sim`` divided by the
  configured temperature. Only list-wise training with
  SOFTMAX_CROSS_ENTROPY is supported.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DAT, self).__init__(model_config, feature_configs, features, labels,
                              is_training)
    model_type = self._model_config.WhichOneof('model')
    assert model_type == 'dat', 'invalid model config: %s' % model_type

    # All four feature groups are mandatory; report the first missing one.
    group_names = set(
        group.group_name for group in self._model_config.feature_groups)
    for required_group in ('user', 'item', 'user_id_augment',
                           'item_id_augment'):
      assert required_group in group_names, \
          '%s feature group not found' % required_group

    self._model_config = self._model_config.dat
    assert isinstance(self._model_config, DATConfig)

    input_layer = self._input_layer
    feature_dict = self._feature_dict

    # NOTE(review): copy_obj presumably returns an independent copy, so the
    # hidden_units.pop() in build_predict_graph does not mutate
    # self._model_config -- confirm against proto_util.copy_obj.
    self.user_tower = copy_obj(self._model_config.user_tower)
    self.user_deep_feature, _ = input_layer(feature_dict, 'user')
    self.user_augmented_vec, _ = input_layer(feature_dict, 'user_id_augment')

    self.item_tower = copy_obj(self._model_config.item_tower)
    self.item_deep_feature, _ = input_layer(feature_dict, 'item')
    self.item_augmented_vec, _ = input_layer(feature_dict, 'item_id_augment')

    self._user_tower_emb = None
    self._item_tower_emb = None

  def _build_tower_emb(self, tower_config, deep_feature, augmented_vec,
                       scope_name):
    """Build one tower: DNN over [deep_feature, augmented_vec] + projection.

    The last entry of ``tower_config.dnn.hidden_units`` is popped (this
    mutates ``tower_config``, so calling this twice on the same config drops
    another layer). The DNN is built from the remaining config, and the popped
    width is used for a separate dense layer named
    '<scope_name>/dnn_<k>', where k is the original layer count minus one.

    Args:
      tower_config: tower proto holding a ``dnn`` sub-config.
      deep_feature: output of the regular feature group.
      augmented_vec: output of the id-augmentation feature group.
      scope_name: DNN name, also used as the dense-layer name prefix.

    Returns:
      The un-normalized tower embedding tensor.
    """
    hidden_units = tower_config.dnn.hidden_units
    last_layer_idx = len(hidden_units) - 1
    output_units = hidden_units.pop()
    tower_dnn = dnn.DNN(tower_config.dnn, self._l2_reg, scope_name,
                        self._is_training)
    tower_input = tf.concat([deep_feature, augmented_vec], axis=-1)
    return tf.layers.dense(
        inputs=tower_dnn(tower_input),
        units=output_units,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn_%d' % (scope_name, last_layer_idx))

  def build_predict_graph(self):
    """Build both towers, the similarity logits and exported embeddings.

    Raises:
      ValueError: in point-wise mode, or when the loss is CLASSIFICATION.
    """
    user_tower_emb = self._build_tower_emb(self.user_tower,
                                           self.user_deep_feature,
                                           self.user_augmented_vec,
                                           'user_dnn')
    item_tower_emb = self._build_tower_emb(self.item_tower,
                                           self.item_deep_feature,
                                           self.item_augmented_vec,
                                           'item_dnn')
    user_tower_emb = self.norm(user_tower_emb)
    item_tower_emb = self.norm(item_tower_emb)

    temperature = self._model_config.temperature
    y_pred = self.sim(user_tower_emb, item_tower_emb) / temperature

    if self._is_point_wise:
      raise ValueError('Currently DAT model only supports list wise mode.')
    if self._loss_type == LossType.CLASSIFICATION:
      raise ValueError(
          'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')

    predictions = self._prediction_dict
    if self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      masked_logits = self._mask_in_batch(y_pred)
      predictions['logits'] = masked_logits
      predictions['probs'] = tf.nn.softmax(masked_logits)
    else:
      predictions['y'] = y_pred

    predictions['user_tower_emb'] = user_tower_emb
    predictions['item_tower_emb'] = item_tower_emb
    # Comma-joined string form of each embedding.
    for emb_key, emb in (('user_emb', user_tower_emb),
                         ('item_emb', item_tower_emb)):
      predictions[emb_key] = tf.reduce_join(
          tf.as_string(emb), axis=-1, separator=',')

    # Gradient-stopped tower outputs next to the raw augmented vectors;
    # presumably consumed by an augmentation (mimic) loss defined elsewhere
    # -- not visible here, confirm.
    predictions['augmented_p_u'] = tf.stop_gradient(user_tower_emb)
    predictions['augmented_p_i'] = tf.stop_gradient(item_tower_emb)
    predictions['augmented_a_u'] = self.user_augmented_vec
    predictions['augmented_a_i'] = self.item_augmented_vec
    return predictions

  def get_outputs(self):
    """Return exported output names, adjusting logits/probs for export.

    Raises:
      ValueError: for CLASSIFICATION or any loss other than
        SOFTMAX_CROSS_ENTROPY.
    """
    if self._loss_type == LossType.CLASSIFICATION:
      raise ValueError(
          'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')
    if self._loss_type != LossType.SOFTMAX_CROSS_ENTROPY:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))

    # NOTE(review): squeeze(axis=-1) only succeeds when the last logits
    # dimension is 1; presumably the export graph scores one item per user
    # -- confirm.
    squeezed_logits = tf.squeeze(self._prediction_dict['logits'], axis=-1)
    self._prediction_dict['logits'] = squeezed_logits
    self._prediction_dict['probs'] = tf.nn.sigmoid(squeezed_logits)
    return [
        'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
        'item_tower_emb', 'augmented_p_u', 'augmented_p_i', 'augmented_a_u',
        'augmented_a_i'
    ]

  def build_output_dict(self):
    """Delegate to the base implementation unchanged."""
    return super(DAT, self).build_output_dict()
@@ -0,0 +1,116 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import cmbf
6
+ from easy_rec.python.layers import dnn
7
+ from easy_rec.python.layers import mmoe
8
+ from easy_rec.python.layers import uniter
9
+ from easy_rec.python.model.multi_task_model import MultiTaskModel
10
+ from easy_rec.python.protos.dbmtl_pb2 import DBMTL as DBMTLConfig
11
+
12
# Use the TF1-style graph API (tf.compat.v1) when running under TF 2.x.
# Compare the numeric major version: a plain string comparison is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
14
+
15
+
16
class DBMTL(MultiTaskModel):
  """Multi-task model with chained per-task "relation" (bayes) networks.

  A shared bottom representation is produced by the backbone, CMBF, UNITER,
  a bottom DNN, or the raw 'all' feature group. It is optionally split into
  per-task inputs by an MMOE layer, passed through an optional per-task DNN,
  and finally each task's relation DNN consumes its own tower feature
  concatenated with the relation features of the towers named in
  ``relation_tower_names``.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DBMTL, self).__init__(model_config, feature_configs, features, labels,
                                is_training)
    model_type = self._model_config.WhichOneof('model')
    assert model_type == 'dbmtl', 'invalid model config: %s' % model_type
    self._model_config = self._model_config.dbmtl
    assert isinstance(self._model_config, DBMTLConfig)

    dbmtl_config = self._model_config
    if dbmtl_config.HasField('bottom_cmbf'):
      self._cmbf_layer = cmbf.CMBF(model_config, feature_configs, features,
                                   dbmtl_config.bottom_cmbf,
                                   self._input_layer)
    elif dbmtl_config.HasField('bottom_uniter'):
      self._uniter_layer = uniter.Uniter(model_config, feature_configs,
                                         features,
                                         dbmtl_config.bottom_uniter,
                                         self._input_layer)
    elif self.has_backbone:
      # NOTE(review): this rejects a configured backbone, yet
      # build_predict_graph first consults self.backbone -- confirm whether
      # backbone configs are meant to reach this constructor.
      assert False, 'invalid code branch'
    else:
      self._features, self._feature_list = self._input_layer(
          self._feature_dict, 'all')
    self._init_towers(self._model_config.task_towers)

  def _build_bottom_feature(self):
    """Return the shared representation that feeds every task."""
    backbone_output = self.backbone
    if backbone_output is not None:
      return backbone_output

    dbmtl_config = self._model_config
    if dbmtl_config.HasField('bottom_cmbf'):
      return self._cmbf_layer(self._is_training, l2_reg=self._l2_reg)
    if dbmtl_config.HasField('bottom_uniter'):
      return self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
    if dbmtl_config.HasField('bottom_dnn'):
      shared_dnn = dnn.DNN(
          dbmtl_config.bottom_dnn,
          self._l2_reg,
          name='bottom_dnn',
          is_training=self._is_training)
      return shared_dnn(self._features)
    return self._features

  def build_predict_graph(self):
    """Build the bottom, optional MMOE, task towers and relation networks.

    Raises:
      KeyError: if a tower lists a relation tower that has not been built
        yet (relation towers must appear earlier in ``task_towers``).
    """
    bottom_fea = self._build_bottom_feature()

    if self._model_config.HasField('expert_dnn'):
      mmoe_layer = mmoe.MMOE(
          self._model_config.expert_dnn,
          l2_reg=self._l2_reg,
          num_task=self._task_num,
          num_expert=self._model_config.num_expert)
      per_task_inputs = mmoe_layer(bottom_fea)
    else:
      # Without experts every task sees the same bottom feature.
      per_task_inputs = [bottom_fea] * self._task_num

    # Task-specific feature extraction.
    tower_features = {}
    for task_idx, tower_config in enumerate(self._model_config.task_towers):
      tower_name = tower_config.tower_name
      if not tower_config.HasField('dnn'):
        tower_features[tower_name] = per_task_inputs[task_idx]
        continue
      tower_dnn = dnn.DNN(
          tower_config.dnn,
          self._l2_reg,
          name=tower_name + '/dnn',
          is_training=self._is_training)
      tower_features[tower_name] = tower_dnn(per_task_inputs[task_idx])

    # Bayes network: each relation DNN also sees its upstream towers' output.
    tower_outputs = {}
    relation_features = {}
    for tower_config in self._model_config.task_towers:
      tower_name = tower_config.tower_name
      relation_dnn = dnn.DNN(
          tower_config.relation_dnn,
          self._l2_reg,
          name=tower_name + '/relation_dnn',
          is_training=self._is_training)
      relation_inputs = [tower_features[tower_name]] + [
          relation_features[upstream]
          for upstream in tower_config.relation_tower_names
      ]
      relation_input = tf.concat(
          relation_inputs, axis=-1, name=tower_name + '/relation_input')
      relation_fea = relation_dnn(relation_input)
      relation_features[tower_name] = relation_fea

      tower_outputs[tower_name] = tf.layers.dense(
          relation_fea,
          tower_config.num_class,
          kernel_regularizer=self._l2_reg,
          name=tower_name + '/output')

    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict
@@ -0,0 +1,70 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.layers import dnn
7
+ from easy_rec.python.model.rank_model import RankModel
8
+
9
+ from easy_rec.python.protos.dcn_pb2 import DCN as DCNConfig # NOQA
10
+
11
# Use the TF1-style graph API (tf.compat.v1) when running under TF 2.x.
# Compare the numeric major version: a plain string comparison is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class DCN(RankModel):
  """Deep & Cross Network ranking model.

  A deep DNN tower and an explicit cross tower both read the 'all' feature
  group; their outputs are concatenated, passed through a final DNN and
  projected to ``num_class`` logits.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DCN, self).__init__(model_config, feature_configs, features, labels,
                              is_training)
    model_type = self._model_config.WhichOneof('model')
    assert model_type == 'dcn', 'invalid model config: %s' % model_type
    self._model_config = self._model_config.dcn
    assert isinstance(self._model_config, DCNConfig)

    self._features, _ = self._input_layer(self._feature_dict, 'all')

  def _cross_net(self, tensor, num_cross_layers):
    """Apply ``num_cross_layers`` cross layers to ``tensor``.

    Each layer l computes x_{l+1} = x0 * (x_l . w_l) + b_l + x_l, where
    w_l and b_l are vectors sized to the last dimension of ``tensor`` and
    created with tf.get_variable as 'cross_layer_<l>_w' / '_b'.
    The last dimension must be statically known.
    """
    x0 = tensor
    feature_dim = tensor.shape[-1]
    x_l = x0
    for layer_idx in range(num_cross_layers):
      var_prefix = 'cross_layer_%s' % layer_idx
      weight = tf.get_variable(
          name=var_prefix + '_w', dtype=tf.float32, shape=(feature_dim,))
      bias = tf.get_variable(
          name=var_prefix + '_b', dtype=tf.float32, shape=(feature_dim,))
      # Per-example scalar projection x_l . w_l, kept as shape (B, 1).
      projection = tf.reduce_sum(x_l * weight, axis=1, keepdims=True)
      x_l = tf.math.add(tf.math.add(x0 * projection, bias), x_l)
    return x_l

  def build_predict_graph(self):
    """Build deep tower, cross tower, final DNN and the output logits."""
    deep_dnn = dnn.DNN(self._model_config.deep_tower.dnn, self._l2_reg, 'dnn',
                       self._is_training)
    deep_output = deep_dnn(self._features)

    cross_output = self._cross_net(self._features,
                                   self._model_config.cross_tower.cross_num)

    merged = tf.concat([deep_output, cross_output], axis=1)
    final_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                        'final_dnn', self._is_training)
    logits = tf.layers.dense(final_dnn(merged), self._num_class, name='output')

    self._add_to_prediction_dict(logits)
    return self._prediction_dict
@@ -0,0 +1,106 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers import fm
7
+ from easy_rec.python.model.rank_model import RankModel
8
+ from easy_rec.python.protos.deepfm_pb2 import DeepFM as DeepFMConfig
9
+
10
+ if tf.__version__ >= '2.0':
11
+ tf = tf.compat.v1
12
+
13
+
14
class DeepFM(RankModel):
  """DeepFM ranking model combining wide, FM and deep parts.

  Inputs come from three feature groups:
    * 'wide': summed into a single column (`wide_fea`);
    * 'deep': dense input of the DNN. Its per-feature outputs also feed the
      FM unless an optional 'fm' group is configured;
    * 'fm' (optional): when present, its per-feature outputs feed the FM
      instead.
  With a non-empty `final_dnn`, the three parts are concatenated and fed
  through it. Otherwise each part is projected to `num_class` and the
  results are summed.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DeepFM, self).__init__(model_config, feature_configs, features,
                                 labels, is_training)
    assert self._model_config.WhichOneof('model') == 'deepfm', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.deepfm
    assert isinstance(self._model_config, DeepFMConfig)

    # backward compatibility
    if self._model_config.HasField('wide_regularization'):
      tf.logging.warning(
          'wide_regularization is deprecated, please use l2_regularization')

    self._wide_features, _ = self._input_layer(self._feature_dict, 'wide')
    self._deep_features, self._fm_features = self._input_layer(
        self._feature_dict, 'deep')
    # A dedicated 'fm' group, if configured, overrides the FM inputs taken
    # from the 'deep' group. Use the public `has_group` API rather than the
    # input layer's private `_feature_groups` attribute.
    if self._input_layer.has_group('fm'):
      _, self._fm_features = self._input_layer(self._feature_dict, 'fm')

  def build_input_layer(self, model_config, feature_configs):
    """Record `wide_output_dim` before delegating to the base input layer.

    NOTE(review): presumably the base `build_input_layer` reads
    `self._wide_output_dim` to size the 'wide' group outputs. The comment on
    the original override says it exists "to support wide_output_dim".

    Without a `final_dnn` the wide part is added directly to the logits, so
    `wide_output_dim` is required to equal `num_class`.
    """
    has_final = len(model_config.deepfm.final_dnn.hidden_units) > 0
    if not has_final:
      assert model_config.deepfm.wide_output_dim == model_config.num_class, (
          'deepfm.wide_output_dim[%s] must equal num_class[%s] when '
          'final_dnn is empty' %
          (model_config.deepfm.wide_output_dim, model_config.num_class))
    self._wide_output_dim = model_config.deepfm.wide_output_dim
    super(DeepFM, self).build_input_layer(model_config, feature_configs)

  def build_predict_graph(self):
    """Build the wide, FM and deep parts and combine them into logits.

    Also stores the raw FM output in `self._fm_outputs`, which
    `build_feature_output_dict` reads.

    Returns:
      The prediction dict filled in by `_add_to_prediction_dict`.
    """
    # Wide
    wide_fea = tf.reduce_sum(
        self._wide_features, axis=1, keepdims=True, name='wide_feature')

    # FM
    fm_fea = fm.FM(name='fm_feature')(self._fm_features)
    self._fm_outputs = fm_fea

    # Deep
    deep_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'deep_feature',
                         self._is_training)
    deep_fea = deep_layer(self._deep_features)

    # Final
    if len(self._model_config.final_dnn.hidden_units) > 0:
      all_fea = tf.concat([wide_fea, fm_fea, deep_fea], axis=1)
      final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                                'final_dnn', self._is_training)
      all_fea = final_dnn_layer(all_fea)
      output = tf.layers.dense(
          all_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='output')
    else:
      # No final DNN: project each part to num_class and sum the logits.
      if self._num_class > 1:
        fm_fea = tf.layers.dense(
            fm_fea,
            self._num_class,
            kernel_regularizer=self._l2_reg,
            name='fm_logits')
      else:
        fm_fea = tf.reduce_sum(fm_fea, 1, keepdims=True)
      deep_fea = tf.layers.dense(
          deep_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='deep_logits')
      output = wide_fea + fm_fea + deep_fea

    self._add_to_prediction_dict(output)

    return self._prediction_dict

  def build_feature_output_dict(self):
    """Extend the base feature outputs with comma-joined intermediate tensors.

    Adds 'wide_features', 'deep_features' and 'fm_outputs'. `_fm_outputs` is
    only assigned in `build_predict_graph`, so that method must run first.
    """
    outputs = super(DeepFM, self).build_feature_output_dict()
    outputs.update({
        'wide_features':
            tf.reduce_join(
                tf.as_string(self._wide_features), axis=-1, separator=','),
        'deep_features':
            tf.reduce_join(
                tf.as_string(self._deep_features), axis=-1, separator=','),
        'fm_outputs':
            tf.reduce_join(
                tf.as_string(self._fm_outputs), axis=-1, separator=',')
    })
    return outputs
@@ -0,0 +1,73 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.layers import dnn
8
+ from easy_rec.python.model.rank_model import RankModel
9
+
10
+ from easy_rec.python.protos.dlrm_pb2 import DLRM as DLRMConfig # NOQA
11
+
12
+ if tf.__version__ >= '2.0':
13
+ tf = tf.compat.v1
14
+
15
+
16
class DLRM(RankModel):
  """Implements Deep Learning Recommendation Model for Personalization and Recommendation Systems (Facebook).

  The dense group goes through a bottom DNN (`bot_dnn`). The result is
  combined with the sparse group outputs by `arch_interaction_op`:
    * 'cat': plain concatenation;
    * 'dot': pairwise dot products between all of them (upper triangle,
      optionally including self-products), concatenated with the sparse
      outputs and, optionally, the bottom-DNN output.
  A top DNN (`top_dnn`) and a dense 'output' layer then produce
  `num_class` logits.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DLRM, self).__init__(model_config, feature_configs, features, labels,
                               is_training)
    assert model_config.WhichOneof('model') == 'dlrm', \
        'invalid model config: %s' % model_config.WhichOneof('model')
    self._model_config = model_config.dlrm
    assert isinstance(self._model_config, DLRMConfig)
    assert self._input_layer.has_group(
        'sparse'), 'sparse group is not specified'
    # Per-feature outputs of the 'sparse' group (a list, concatenated below).
    _, self._sparse_features = self._input_layer(self._feature_dict, 'sparse')
    assert self._input_layer.has_group('dense'), 'dense group is not specified'
    # Concatenated output of the 'dense' group, fed to the bottom DNN.
    self._dense_feature, _ = self._input_layer(self._feature_dict, 'dense')

  def _dot_interaction(self, dense_fea):
    """Build the 'dot' interaction features from `dense_fea` and sparse ones.

    Every feature vector (bottom-DNN output first, then each sparse output)
    is stacked along a new axis. All pairwise dot products are computed, and
    the upper triangle is kept (the diagonal too when
    `arch_interaction_itself` is set).

    Returns:
      The upper triangle concatenated with the sparse outputs and, when
      `arch_with_dense_feature` is set, `dense_fea`.
    """
    dense_dim = dense_fea.get_shape()[1]
    sparse_dim = self._sparse_features[0].get_shape()[1]
    # %s (not %d) so a statically unknown dimension still yields the message.
    assert dense_dim == sparse_dim, \
        'bot_dnn last hidden[%s] != sparse feature embedding_dim[%s]' % (
            dense_dim, sparse_dim)

    fea_list = [dense_fea] + self._sparse_features
    # The feature count is known statically from the Python list.
    num_fea = len(fea_list)
    stacked = tf.concat([x[:, None, :] for x in fea_list], axis=1)
    interaction = tf.einsum('bne,bme->bnm', stacked, stacked)
    offset = 0 if self._model_config.arch_interaction_itself else 1
    upper_tri = tf.concat(
        [interaction[:, i, (i + offset):num_fea] for i in range(num_fea)],
        axis=1)

    concat_feas = [upper_tri] + self._sparse_features
    if self._model_config.arch_with_dense_feature:
      concat_feas.append(dense_fea)
    return tf.concat(concat_feas, axis=1)

  def build_predict_graph(self):
    """Build bottom DNN, feature interaction, top DNN and output logits.

    Returns:
      The prediction dict filled in by `_add_to_prediction_dict`.

    Raises:
      ValueError: if `arch_interaction_op` is neither 'cat' nor 'dot'.
    """
    bot_dnn = dnn.DNN(self._model_config.bot_dnn, self._l2_reg, 'bot_dnn',
                      self._is_training)
    dense_fea = bot_dnn(self._dense_feature)

    interaction_op = self._model_config.arch_interaction_op
    logging.info('arch_interaction_op = %s', interaction_op)
    if interaction_op == 'cat':
      all_fea = tf.concat([dense_fea] + self._sparse_features, axis=1)
    elif interaction_op == 'dot':
      all_fea = self._dot_interaction(dense_fea)
    else:
      # Previously fell through and left `all_fea` unbound.
      raise ValueError('unsupported arch_interaction_op: %s, '
                       'expected "cat" or "dot"' % interaction_op)

    top_dnn = dnn.DNN(self._model_config.top_dnn, self._l2_reg, 'top_dnn',
                      self._is_training)
    all_fea = top_dnn(all_fea)
    # Emit num_class logits like the other rank models. The former hard-coded
    # 1 unit broke multi-class configs; with num_class == 1 the variables are
    # unchanged.
    logits = tf.layers.dense(
        all_fea,
        self._num_class,
        kernel_regularizer=self._l2_reg,
        name='output')

    self._add_to_prediction_dict(logits)

    return self._prediction_dict