easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
File without changes
@@ -0,0 +1,73 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.layers import multihead_attention
8
+ from easy_rec.python.model.rank_model import RankModel
9
+
10
+ from easy_rec.python.protos.autoint_pb2 import AutoInt as AutoIntConfig # NOQA
11
+
12
+ if tf.__version__ >= '2.0':
13
+ tf = tf.compat.v1
14
+
15
+
16
class AutoInt(RankModel):
  """Rank model that stacks multi-head self-attention over feature embeddings.

  The features of the 'all' group are reshaped into one embedding per field,
  passed through ``interacting_layer_num`` attention layers, flattened, and
  projected to ``self._num_class`` outputs by a dense layer.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and derive field counts and attention sizes.

    Args:
      model_config: model config whose ``model`` oneof must be 'autoint'.
      feature_configs: iterable of feature configs; the ``embedding_dim`` of
        each one is read here.
      features: input features, forwarded to the base class.
      labels: labels, forwarded to the base class.
      is_training: training flag, forwarded to the base class.
    """
    super(AutoInt, self).__init__(model_config, feature_configs, features,
                                  labels, is_training)
    model_kind = self._model_config.WhichOneof('model')
    assert model_kind == 'autoint', 'invalid model config: %s' % model_kind

    self._features, _ = self._input_layer(self._feature_dict, 'all')

    first_group = self._model_config.feature_groups[0]
    self._feature_num = len(first_group.feature_names)
    self._seq_key_num = 0
    # Each sequence attention map adds its history sequences to the field
    # count and its keys to the separate key count.
    for seq_fea in first_group.sequence_features:
      for seq_att in seq_fea.seq_att_map:
        self._feature_num += len(seq_att.hist_seq)
        self._seq_key_num += len(seq_att.key)

    # From here on self._model_config is the AutoInt-specific sub-config.
    self._model_config = self._model_config.autoint
    assert isinstance(self._model_config, AutoIntConfig)

    fea_emb_dim_list = [cfg.embedding_dim for cfg in feature_configs]
    dims_consistent = (
        len(set(fea_emb_dim_list)) == 1 and
        len(fea_emb_dim_list) == self._feature_num)
    assert dims_consistent, \
        'AutoInt requires that all feature dimensions must be consistent.'

    self._d_model = fea_emb_dim_list[0]
    self._head_num = self._model_config.multi_head_num
    self._head_size = self._model_config.multi_head_size

  def build_predict_graph(self):
    """Build the attention stack and output layer.

    Returns:
      self._prediction_dict after the dense output has been added to it.
    """
    logging.info('feature_num: {0}'.format(self._feature_num))

    # One row per field (plain features plus sequence keys).
    field_count = self._feature_num + self._seq_key_num
    hidden = tf.reshape(
        self._features, shape=[-1, field_count, self._d_model])

    for layer_idx in range(self._model_config.interacting_layer_num):
      layer_name = 'multi_head_self_attention_layer_%d' % layer_idx
      self_attention = multihead_attention.MultiHeadAttention(
          head_num=self._head_num,
          head_size=self._head_size,
          l2_reg=self._l2_reg,
          use_res=True,
          name=layer_name)
      hidden = self_attention(hidden)

    # Flatten the per-field outputs into a single vector per example.
    flat_dim = hidden.shape[1] * hidden.shape[2]
    hidden = tf.reshape(hidden, shape=[-1, flat_dim])

    logits = tf.layers.dense(hidden, self._num_class, name='output')
    self._add_to_prediction_dict(logits)
    return self._prediction_dict
@@ -0,0 +1,47 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import cmbf
6
+ from easy_rec.python.layers import dnn
7
+ from easy_rec.python.model.rank_model import RankModel
8
+
9
+ from easy_rec.python.protos.cmbf_pb2 import CMBF as CMBFConfig # NOQA
10
+
11
+ if tf.__version__ >= '2.0':
12
+ tf = tf.compat.v1
13
+
14
+
15
class CMBF(RankModel):
  """CMBF: Cross-Modal-Based Fusion Recommendation Algorithm.

  Follows the original CMBF model almost exactly; the paper is available at
  https://www.mdpi.com/1424-8220/21/16/5275
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and create the CMBF fusion layer.

    Args:
      model_config: model config whose ``model`` oneof must be 'cmbf'.
      feature_configs: feature configs, forwarded to the base class and to
        the CMBF layer.
      features: input features, forwarded to the base class and to the CMBF
        layer.
      labels: labels, forwarded to the base class.
      is_training: training flag, forwarded to the base class.
    """
    super(CMBF, self).__init__(model_config, feature_configs, features, labels,
                               is_training)
    model_kind = self._model_config.WhichOneof('model')
    assert model_kind == 'cmbf', ('invalid model config: %s' % model_kind)

    cmbf_config = self._model_config.cmbf
    self._cmbf_layer = cmbf.CMBF(model_config, feature_configs, features,
                                 cmbf_config.config, self._input_layer)
    # From here on self._model_config is the CMBF-specific sub-config.
    self._model_config = cmbf_config

  def build_predict_graph(self):
    """Run the fusion layer, the final DNN and the output projection.

    Returns:
      self._prediction_dict after the dense output has been added to it.
    """
    fused = self._cmbf_layer(self._is_training, l2_reg=self._l2_reg)
    top_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg, 'final_dnn',
                      self._is_training)
    top_fea = top_dnn(fused)

    logits = tf.layers.dense(top_fea, self._num_class, name='output')
    self._add_to_prediction_dict(logits)
    return self._prediction_dict
@@ -0,0 +1,182 @@
1
+ import tensorflow as tf
2
+
3
+ from easy_rec.python.core.metrics import metric_learning_average_precision_at_k
4
+ from easy_rec.python.core.metrics import metric_learning_recall_at_k
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers.common_layers import highway
7
+ from easy_rec.python.loss.circle_loss import circle_loss
8
+ from easy_rec.python.loss.multi_similarity import ms_loss
9
+ from easy_rec.python.model.easy_rec_model import EasyRecModel
10
+ from easy_rec.python.protos.loss_pb2 import LossType
11
+ from easy_rec.python.utils.activation import gelu
12
+ from easy_rec.python.utils.proto_util import copy_obj
13
+
14
+ from easy_rec.python.protos.collaborative_metric_learning_pb2 import CoMetricLearningI2I as MetricLearningI2IConfig # NOQA
15
+
16
+ if tf.__version__ >= '2.0':
17
+ tf = tf.compat.v1
18
+
19
+
20
class CoMetricLearningI2I(EasyRecModel):
  """Metric-learning model that produces one embedding per input sample.

  The embedding is trained with either circle loss or multi-similarity loss,
  selected by ``loss_type`` in the model config.
  """

  def __init__(
      self,
      model_config,  # pipeline.model_config
      feature_configs,  # pipeline.feature_configs
      features,  # same as model_fn input
      labels=None,
      is_training=False):
    """Validate the config and prepare inputs, loss settings and labels.

    Args:
      model_config: pipeline.model_config; its ``model`` oneof must be
        'metric_learning'.
      feature_configs: pipeline.feature_configs.
      features: same as the model_fn input.
      labels: optional dict of labels. When the config has ``session_id``,
        that entry is popped from the dict and kept as the session ids.
      is_training: training flag, forwarded to the base class.

    Raises:
      ValueError: if the loss type is neither CIRCLE_LOSS nor
        MULTI_SIMILARITY_LOSS.
    """
    super(CoMetricLearningI2I, self).__init__(model_config, feature_configs,
                                              features, labels, is_training)
    model = self._model_config.WhichOneof('model')
    assert model == 'metric_learning', 'invalid model config: %s' % model

    self._loss_type = self._model_config.loss_type
    loss_type_name = LossType.Name(self._loss_type).lower()

    # From here on self._model_config is the metric-learning sub-config.
    self._model_config = self._model_config.metric_learning
    assert isinstance(self._model_config, MetricLearningI2IConfig)

    # The loss configured in the sub-config must match the top-level loss_type.
    model_loss = self._model_config.WhichOneof('loss').lower()
    assert model_loss == loss_type_name, 'invalid loss type: %s' % model_loss

    if self._loss_type == LossType.CIRCLE_LOSS:
      self.loss = self._model_config.circle_loss
    elif self._loss_type == LossType.MULTI_SIMILARITY_LOSS:
      self.loss = self._model_config.multi_similarity_loss
    else:
      raise ValueError('unsupported loss type: %s' %
                       LossType.Name(self._loss_type))

    # NOTE(review): nesting below was reconstructed from an
    # indentation-stripped listing; confirm that input_features and dnn are
    # only prepared when there is no backbone.
    if not self.has_backbone:
      self._highway_features = {}
      self._highway_num = len(self._model_config.highway)
      for _id in range(self._highway_num):
        highway_cfg = self._model_config.highway[_id]
        highway_feature, _ = self._input_layer(self._feature_dict,
                                               highway_cfg.input)
        self._highway_features[highway_cfg.input] = highway_feature

      self.input_features = []
      if self._model_config.HasField('input'):
        input_feature, _ = self._input_layer(self._feature_dict,
                                             self._model_config.input)
        self.input_features.append(input_feature)

      # Work on a copy: build_predict_graph pops the last hidden unit.
      self.dnn = copy_obj(self._model_config.dnn)

    if self._labels is not None:
      if self._model_config.HasField('session_id'):
        self.session_ids = self._labels.pop(self._model_config.session_id)
      else:
        self.session_ids = None

      assert len(self._labels) > 0
      # The first remaining label is used as the training label.
      self.labels = list(self._labels.values())[0]

    if self._model_config.HasField('sample_id'):
      self.sample_id = self._model_config.sample_id
    else:
      self.sample_id = None

  def build_predict_graph(self):
    """Build the embedding tower and fill the prediction dict.

    Returns:
      self._prediction_dict with 'float_emb' and 'embedding' (comma-joined
      string form), plus 'norm_emb'/'norm_embedding' when
      ``output_l2_normalized_emb`` is set, and 'sample_id' when it is
      configured and present in the feature dict.
    """
    # Local import: the module-level import block may not provide logging.
    import logging

    if self.has_backbone:
      tower_emb = self.backbone
    else:
      for _id in range(self._highway_num):
        highway_cfg = self._model_config.highway[_id]
        highway_fea = tf.layers.batch_normalization(
            self._highway_features[highway_cfg.input],
            training=self._is_training,
            trainable=True,
            name='highway_%s_bn' % highway_cfg.input)
        highway_fea = highway(
            highway_fea,
            highway_cfg.emb_size,
            activation=gelu,
            scope='highway_%s' % _id)
        # Logged instead of printed so graph construction does not write to
        # stdout.
        logging.info('highway_fea: %s', highway_fea)
        self.input_features.append(highway_fea)

      feature = tf.concat(self.input_features, axis=1)

      # The last hidden layer is built separately as a plain dense layer;
      # it keeps the name the DNN would have given it.
      num_dnn_layer = len(self.dnn.hidden_units)
      last_hidden = self.dnn.hidden_units.pop()
      dnn_net = dnn.DNN(self.dnn, self._l2_reg, 'dnn', self._is_training)
      net_output = dnn_net(feature)
      tower_emb = tf.layers.dense(
          inputs=net_output,
          units=last_hidden,
          kernel_regularizer=self._l2_reg,
          name='dnn/dnn_%d' % (num_dnn_layer - 1))

    if self._model_config.output_l2_normalized_emb:
      norm_emb = tf.nn.l2_normalize(tower_emb, axis=-1)
      self._prediction_dict['norm_emb'] = norm_emb
      self._prediction_dict['norm_embedding'] = tf.reduce_join(
          tf.as_string(norm_emb), axis=-1, separator=',')

    self._prediction_dict['float_emb'] = tower_emb
    self._prediction_dict['embedding'] = tf.reduce_join(
        tf.as_string(tower_emb), axis=-1, separator=',')
    if self.sample_id is not None and self.sample_id in self._feature_dict:
      self._prediction_dict['sample_id'] = tf.identity(
          self._feature_dict[self.sample_id])
    return self._prediction_dict

  def build_loss_graph(self):
    """Add the configured metric-learning loss to the loss dict.

    Returns:
      self._loss_dict with 'circle_loss' or 'ms_loss' set.

    Raises:
      ValueError: if the loss type is not one of the two supported losses.
    """
    emb = self._prediction_dict['float_emb']
    emb_normed = self._model_config.output_l2_normalized_emb
    # Use the normalized embedding when one was produced.
    norm_emb = self._prediction_dict['norm_emb'] if emb_normed else emb
    if self._loss_type == LossType.CIRCLE_LOSS:
      self._loss_dict['circle_loss'] = circle_loss(
          norm_emb,
          self.labels,
          self.session_ids,
          self.loss.margin,
          self.loss.gamma,
          embed_normed=emb_normed)
    elif self._loss_type == LossType.MULTI_SIMILARITY_LOSS:
      self._loss_dict['ms_loss'] = ms_loss(
          norm_emb,
          self.labels,
          self.session_ids,
          self.loss.alpha,
          self.loss.beta,
          self.loss.lamb,
          self.loss.eps,
          embed_normed=emb_normed)
    else:
      raise ValueError('invalid loss type: %s' % LossType.Name(self._loss_type))

    return self._loss_dict

  def get_outputs(self):
    """Return the names of the prediction-dict entries exposed as outputs."""
    outputs = ['embedding', 'float_emb']
    if self.sample_id is not None and 'sample_id' in self._prediction_dict:
      outputs.append('sample_id')
    if self._model_config.output_l2_normalized_emb:
      outputs.append('norm_embedding')
      outputs.append('norm_emb')
    return outputs

  def build_metric_graph(self, eval_config):
    """Build recall@k and average-precision@k metrics on the embedding.

    Args:
      eval_config: eval config; only 'recall_at_topk' and 'precision_at_topk'
        entries of ``metrics_set`` are used.

    Returns:
      Dict of metrics; empty when neither metric kind is configured.
    """
    metric_dict = {}
    recall_at_k = []
    precision_at_k = []
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'recall_at_topk':
        recall_at_k.append(metric.recall_at_topk.topk)
      elif metric.WhichOneof('metric') == 'precision_at_topk':
        precision_at_k.append(metric.precision_at_topk.topk)

    emb = self._prediction_dict['float_emb']
    if len(recall_at_k) > 0:
      metric_dict.update(
          metric_learning_recall_at_k(recall_at_k, emb, self.labels,
                                      self.session_ids))
    if len(precision_at_k) > 0:
      metric_dict.update(
          metric_learning_average_precision_at_k(precision_at_k, emb,
                                                 self.labels, self.session_ids))
    return metric_dict
@@ -0,0 +1,323 @@
1
+ # easy_rec/python/model/custom_model.py
2
+ import os
3
+ import sys
4
+
5
+ import six
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.builders import loss_builder
9
+ from easy_rec.python.compat import regularizers
10
+ from easy_rec.python.feature_column.feature_column import FeatureColumnParser
11
+ from easy_rec.python.model.easy_rec_model import EasyRecModel
12
+ from easy_rec.python.protos.deepfm_pb2 import DeepFM as DeepFMConfig
13
+ # from easy_rec.python.protos.easy_rec_model_pb2 import LossType
14
+ from easy_rec.python.protos.loss_pb2 import LossType
15
+
16
# `tf.__version__` is a string; compare the parsed major version rather than
# relying on lexicographic ordering (under which e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
  # From here on `tf` refers to the TF1-compatible API surface.
  tf = tf.compat.v1
18
+
19
+
20
class MultiHeadAttention(tf.compat.v1.keras.layers.Layer):
  """Multi-head scaled dot-product attention built from four Dense layers.

  Queries, keys and values are each projected to `d_model` units, split into
  `num_heads` heads of size `d_model // num_heads`, attended per head, merged
  back to `d_model` and passed through a final Dense projection.

  NOTE(review): the layer overrides `__call__` directly instead of `call`, so
  the base Layer's `__call__` wrapper is not used. Left unchanged on purpose;
  switching may affect variable naming -- confirm before changing.
  """

  def __init__(self, num_heads, d_model):
    """Create the projection layers.

    Args:
      num_heads: number of attention heads.
      d_model: width of the projections; must be divisible by num_heads.
    """
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'

    # Per-head feature size.
    self.depth = d_model // num_heads

    self.wq = tf.compat.v1.keras.layers.Dense(d_model)
    self.wk = tf.compat.v1.keras.layers.Dense(d_model)
    self.wv = tf.compat.v1.keras.layers.Dense(d_model)

    self.dense = tf.compat.v1.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Reshape x to (batch_size, seq, num_heads, depth), then move heads to axis 1.

    The sequence dimension is inferred with -1 instead of being fixed to 15,
    so inputs with any number of positions are accepted.
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def __call__(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: query tensor; its first dimension is used as the batch size.
      k: key tensor.
      v: value tensor.
      mask: optional tensor added to the attention logits after scaling by
        -1e9; None disables masking.

    Returns:
      A tuple (output, attention_weights): output is the final Dense
      projection of the merged heads, attention_weights the softmax weights.
    """
    batch_size = tf.shape(q)[0]

    q = self.split_heads(self.wq(q), batch_size)
    k = self.split_heads(self.wk(k), batch_size)
    v = self.split_heads(self.wv(v), batch_size)

    scaled_attention, attention_weights = self.scaled_dot_product_attention(
        q, k, v, mask)

    # Move the head axis back next to depth, then merge heads into d_model.
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))

    output = self.dense(concat_attention)

    return output, attention_weights

  def scaled_dot_product_attention(self, q, k, v, mask):
    """Return (softmax(q @ k^T / sqrt(dk) [+ mask * -1e9]) @ v, weights).

    `dk` is the size of the last dimension of `k`.
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
      # Positions where mask is 1 receive a large negative logit.
      scaled_attention_logits += (mask * -1e9)

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    output = tf.matmul(attention_weights, v)

    return output, attention_weights
77
+
78
+
79
class CustomModel(EasyRecModel):
  """Custom model: attention over two feature groups plus a branching DNN.

  The constructor assembles `deep_input` from four feature groups;
  `build_predict_graph` feeds it through a branching dense tower and a sigmoid
  head stored under the 'label' prediction key.
  """

  # Fallback used when an instance is created without running __init__.
  _dropout_training = False

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Build the model inputs.

    Args:
      model_config: easy_rec.python.protos.easy_rec_model_pb2.EasyRecModel
        model_config.custom_model is instance of:
        easy_rec.python.protos.easy_rec_model_pb2.CustomModel
      feature_configs: a collection of
        easy_rec.python.protos.feature_config.FeatureConfig
      features: dict of feature tensors, which are described by
        easy_rec.python.protos.DatasetConfig.input_fields
      labels: dict of labels tensors, which are described by
        easy_rec.python.protos.DatasetConfig.label_fields
      is_training: forwarded to the base class; also decides whether the
        dropout layers of this model are active.
    """
    super(CustomModel, self).__init__(model_config, feature_configs, features,
                                      labels, is_training)
    # Dropout must be told explicitly whether the graph is built for training;
    # without the flag Keras falls back to its learning phase, which is not
    # set by this code.
    self._dropout_training = is_training
    self.drop_out_rate = 0.05

    self._raw_features, self._raw_feature_lst = self._input_layer(
        self._feature_dict, 'raw_feature')
    # With is_combine=False the input layer returns three values; only the
    # first is used. Each of its entries is indexed with [0] below --
    # presumably the embedding tensor of a sequence feature; TODO confirm.
    self._seq_features, _, _ = self._input_layer(
        self._feature_dict, 'seq_feature', is_combine=False)
    self._multi_head_1_features, _, _ = self._input_layer(
        self._feature_dict, 'multi_head_feature_1', is_combine=False)
    self._multi_head_2_features, _, _ = self._input_layer(
        self._feature_dict, 'multi_head_feature_2', is_combine=False)

    # Mean over axis 1 of every sequence feature, joined on the last axis.
    self._seq_features_concat = self._get_features_concat(
        self._get_seq_features_reduce(
            self._seq_features, reduce_type='mean', axis=1, keepdims=False),
        axis=-1)

    # For the attention groups axis 1 is kept (size 1) and the features are
    # stacked along it, so every feature becomes one attention position.
    self._multi_head_1_features_concat = self._get_features_concat(
        self._get_seq_features_reduce(
            self._multi_head_1_features,
            reduce_type='mean',
            axis=1,
            keepdims=True),
        axis=1)

    self._multi_head_2_features_concat = self._get_features_concat(
        self._get_seq_features_reduce(
            self._multi_head_2_features,
            reduce_type='mean',
            axis=1,
            keepdims=True),
        axis=1)

    self._multi_head_1_layer = MultiHeadAttention(4, 12)
    self._multi_head_2_layer = MultiHeadAttention(4, 12)

    # Self-attention: the same tensor serves as query, key and value, no mask.
    self._multi_head_1_output, _ = self._multi_head_1_layer(
        self._multi_head_1_features_concat, self._multi_head_1_features_concat,
        self._multi_head_1_features_concat, None)

    self._multi_head_2_output, _ = self._multi_head_2_layer(
        self._multi_head_2_features_concat, self._multi_head_2_features_concat,
        self._multi_head_2_features_concat, None)

    self._multi_head_1_output_end = self._get_seq_feature_reduce(
        self._multi_head_1_output, reduce_type='mean', axis=1, keepdims=False)
    self._multi_head_2_output_end = self._get_seq_feature_reduce(
        self._multi_head_2_output, reduce_type='mean', axis=1, keepdims=False)

    self.deep_input = self._get_features_concat([
        self._raw_features, self._seq_features_concat,
        self._multi_head_1_output_end, self._multi_head_2_output_end
    ],
                                                axis=-1)

  def _get_seq_features_reduce(self, seq_features, reduce_type, axis: int,
                               keepdims: bool):
    """Reduce element [0] of every entry in seq_features.

    Args:
      seq_features: iterable of indexable entries; entry[0] is the tensor
        that gets reduced.
      reduce_type: one of 'mean', 'sum', 'max'.
      axis: one of -1, 1, 2.
      keepdims: whether the reduced axis is kept with size 1.

    Returns:
      A list with one reduced tensor per entry.
    """
    assert reduce_type in ['mean', 'sum',
                           'max'], 'reduce_type must in mean | sum | max'
    assert axis in [-1, 1, 2], 'axis must in -1 | 1 | 2'
    return [
        self._get_seq_feature_reduce(feature[0], reduce_type, axis, keepdims)
        for feature in seq_features
    ]

  def _get_seq_feature_reduce(self, seq_feature, reduce_type, axis: int,
                              keepdims: bool):
    """Reduce a single tensor with tf.reduce_mean / reduce_sum / reduce_max.

    Args:
      seq_feature: tensor to reduce.
      reduce_type: one of 'mean', 'sum', 'max'.
      axis: one of -1, 1, 2.
      keepdims: whether the reduced axis is kept with size 1.

    Returns:
      The reduced tensor.
    """
    assert reduce_type in ['mean', 'sum',
                           'max'], 'reduce_type must in mean | sum | max'
    assert axis in [-1, 1, 2], 'axis must in -1 | 1 | 2'
    # A lookup table cannot silently fall through and return None the way the
    # former trailing `else: pass` could when asserts are disabled.
    reduce_fns = {
        'mean': tf.reduce_mean,
        'sum': tf.reduce_sum,
        'max': tf.reduce_max,
    }
    return reduce_fns[reduce_type](seq_feature, axis=axis, keepdims=keepdims)

  def _get_features_concat(self, features, axis):
    """Concatenate `features` along `axis` (one of -1, 1, 2)."""
    assert axis in [-1, 1, 2], 'axis must in -1 | 1 | 2'
    return tf.concat(features, axis=axis)

  def _apply_dropout(self, tensor):
    """Apply dropout at self.drop_out_rate; return tensor as is when rate is 0."""
    if self.drop_out_rate == 0:
      return tensor
    dropout = tf.keras.layers.Dropout(
        self.drop_out_rate, noise_shape=None, seed=None)
    return dropout(tensor, training=self._dropout_training)

  def _dense_stack(self, inputs, dnn_layers_list, name_prefix):
    """Chain one relu Dense layer per spec in dnn_layers_list.

    Each spec has the form '<anything>:<units>'; layer j is named
    '<name_prefix>_<j>'.
    """
    deep_layer = inputs
    for j, layer_spec in enumerate(dnn_layers_list):
      units = int(layer_spec.split(':')[1])
      deep_layer = tf.keras.layers.Dense(
          units=units, activation='relu',
          name=f'{name_prefix}_{j}')(
              deep_layer)
    return deep_layer

  def build_predict_graph(self):
    """Build the forward graph and store the sigmoid output under 'label'.

    Returns:
      self._prediction_dict with the 'label' entry set.
    """
    # Branching tower: one 64-unit layer, then four levels that each create
    # two branches per incoming tensor (32, 16, 8 and 4 units).
    dnn_1_list = self.get_layer_1(
        self.deep_input, '1:64', prefix='dnn_1_1', n=1)
    dnn_1_2_list = self.get_layer_n(
        dnn_1_list, '1:32', prefix='dnn_1_2', branch_num=2)
    dnn_1_3_list = self.get_layer_n(
        dnn_1_2_list, '1:16', prefix='dnn_1_3', branch_num=2)
    dnn_1_4_list = self.get_layer_n(
        dnn_1_3_list, '1:8', prefix='dnn_1_4', branch_num=2)
    dnn_1_5_list = self.get_layer_n(
        dnn_1_4_list, '1:4', prefix='dnn_1_5', branch_num=2)
    dnn_1_concat = tf.concat(dnn_1_5_list, axis=-1, name='dnn_1_concat')

    # Three dense + dropout stages named dnn_layer_2_1 .. dnn_layer_2_3.
    hidden = dnn_1_concat
    for idx, units in enumerate((32, 16, 8), start=1):
      hidden = tf.keras.layers.Dense(
          units=units, activation='relu', name='dnn_layer_2_%d' % idx)(
              hidden)
      hidden = self._apply_dropout(hidden)

    dnn_1_sig = tf.keras.layers.Dense(
        units=1, activation='sigmoid', name='dnn_1_sig')(
            hidden)

    self._prediction_dict['label'] = dnn_1_sig
    return self._prediction_dict

  def build_loss_graph(self):
    """Compute binary focal cross-entropy between the first label and 'label'.

    Returns:
      self._loss_dict with the 'custom_loss' entry set.
    """
    loss = tf.keras.losses.BinaryFocalCrossentropy(gamma=2, from_logits=False)
    label = list(self._labels.values())[0]

    self._loss_dict['custom_loss'] = loss(label, self._prediction_dict['label'])

    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Build the AUC metric for the 'label' prediction.

    Args:
      eval_config: object whose `metrics_set` entries expose
        `WhichOneof('metric')` and `auc.num_thresholds`.

    Returns:
      A dict with a single 'auc' entry produced by tf.metrics.auc.
    """
    # Use the thresholds of the configured AUC metric wherever it sits in the
    # set; fall back to 200 (the tf.metrics.auc default) when none is
    # configured instead of failing on an empty set.
    num_thresholds = 200
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'auc':
        num_thresholds = metric.auc.num_thresholds
        break

    metric_dict = {}
    metric_dict['auc'] = tf.metrics.auc(
        list(self._labels.values())[0],
        self._prediction_dict['label'],
        num_thresholds=num_thresholds)
    return metric_dict

  def get_outputs(self):
    """Return the prediction keys exposed by this model: ['label']."""
    return ['label']

  def get_layer_1(self, input, dnn_layers, prefix, n=2):
    """Build `n` independent dense stacks on top of `input`.

    Args:
      input: tensor fed to every stack.
      dnn_layers: comma-separated layer specs, each '<anything>:<units>'.
      prefix: used in the layer names 'dnn_layer_<prefix>_<i>_<j>'.
      n: number of stacks to build.

    Returns:
      A list with the output tensor of each stack.
    """
    dnn_layers_list = dnn_layers.split(',')
    return [
        self._dense_stack(input, dnn_layers_list, f'dnn_layer_{prefix}_{i}')
        for i in range(n)
    ]

  def get_layer_n(self, layer_output_list, dnn_layers, prefix, branch_num=2):
    """Build `branch_num` dense stacks (each followed by dropout) per input.

    Args:
      layer_output_list: list of input tensors.
      dnn_layers: comma-separated layer specs, each '<anything>:<units>'.
      prefix: used in the layer names
        'dnn_layer_<prefix>_<branch>_<i>_<j>'.
      branch_num: number of stacks created for every input tensor.

    Returns:
      A list of branch_num * len(layer_output_list) tensors, ordered by
      branch first and input index second.
    """
    output_list = []
    dnn_layers_list = dnn_layers.split(',')
    for branch in range(branch_num):
      for i, layer_output in enumerate(layer_output_list):
        deep_layer = self._dense_stack(layer_output, dnn_layers_list,
                                       f'dnn_layer_{prefix}_{branch}_{i}')
        output_list.append(self._apply_dropout(deep_layer))

    return output_list