easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,192 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import six
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
8
+ from easy_rec.python.utils.activation import get_activation
9
+
10
+ if tf.__version__ >= '2.0':
11
+ tf = tf.compat.v1
12
+
13
+
14
def highway(x,
            size=None,
            activation=None,
            num_layers=1,
            scope='highway',
            dropout=0.0,
            init_gate_bias=-1.0,
            reuse=None):
  """Stack of highway layers: x <- H(x) * T(x) + x * (1 - T(x)).

  Args:
    x: input tensor; the dense layers act on its last dimension.
    size: width of every highway layer. If None, the last dimension of `x`
      is used and `x` is fed in unchanged; otherwise `x` is first projected
      to `size` by a dense layer named 'input_projection'.
    activation: activation of the transform branch H, either a callable or a
      name resolved through `get_activation`.
    num_layers: number of stacked highway layers.
    scope: variable scope holding all variables of the block.
    dropout: if > 0, dropout rate applied to the transform branch H. It is
      applied unconditionally (this function has no training flag).
    init_gate_bias: initial bias of the sigmoid gate T; a negative value
      starts the gate below 0.5, i.e. favouring the carried input `x`.
    reuse: reuse flag forwarded to the variable scope and every dense layer.

  Returns:
    Tensor with the leading dimensions of `x` and last dimension `size`.
  """
  if isinstance(activation, six.string_types):
    activation = get_activation(activation)
  # `reuse` must be passed by keyword: the second positional parameter of
  # tf.variable_scope is `default_name`, so passing it positionally silently
  # left the scope-level reuse flag unset.
  with tf.variable_scope(scope, reuse=reuse):
    if size is None:
      size = x.shape.as_list()[-1]
    else:
      x = tf.layers.dense(x, size, name='input_projection', reuse=reuse)

    initializer = tf.constant_initializer(init_gate_bias)
    for i in range(num_layers):
      # Gate T in (0, 1): how much of the transformed signal replaces x.
      gate = tf.layers.dense(
          x,
          size,
          activation=tf.sigmoid,
          bias_initializer=initializer,
          name='gate_%d' % i,
          reuse=reuse)
      transformed = tf.layers.dense(
          x, size, activation=activation, name='activation_%d' % i, reuse=reuse)
      if dropout > 0.0:
        # The second positional argument of the v1 tf.nn.dropout is keep_prob.
        transformed = tf.nn.dropout(transformed, 1.0 - dropout)
      x = transformed * gate + x * (1.0 - gate)
    return x
45
+
46
+
47
def text_cnn(x,
             filter_sizes=(3, 4, 5),
             num_filters=(128, 64, 64),
             scope_name='textcnn',
             reuse=False):
  """Multi-width 1-D convolutions followed by max-over-time pooling.

  Args:
    x: sequence tensor, (batch_size, step_dim, embed_dim) per the original
      note; it is convolved with `tf.layers.conv1d` and pooled over axis 1.
    filter_sizes: kernel width of each convolution branch.
    num_filters: number of output channels of each branch; must have the same
      length as `filter_sizes`.
    scope_name: prefix of the per-branch variable scopes, which are named
      '<scope_name>_<filter_size>'.
    reuse: reuse flag for the variable scopes and the conv layers.

  Returns:
    Tensor of shape (batch_size, sum(num_filters)): the pooled outputs of all
    branches concatenated along the last axis.
  """
  assert len(filter_sizes) == len(num_filters)
  initializer = tf.variance_scaling_initializer()
  pooled_outputs = []
  for filter_size, num_filter in zip(filter_sizes, num_filters):
    # NOTE: the scope name contains only the filter size, so repeated values
    # in `filter_sizes` share one scope. The names are kept unchanged so that
    # existing checkpoints still load.
    scope_name_i = scope_name + '_' + str(filter_size)
    with tf.variable_scope(scope_name_i, reuse=reuse):
      # padding='same' keeps the time axis length:
      # conv shape: (batch_size, seq_len, num_filter)
      conv = tf.layers.conv1d(
          x,
          filters=int(num_filter),
          kernel_size=int(filter_size),
          activation=tf.nn.relu,
          name='conv_layer',
          reuse=reuse,
          kernel_initializer=initializer,
          padding='same')
      # max pooling over time, shape: (batch_size, num_filter)
      pool = tf.reduce_max(conv, axis=1)
      pooled_outputs.append(pool)
  # shape: (batch_size, sum(num_filters))
  pool_flat = tf.concat(pooled_outputs, 1)
  return pool_flat
77
+
78
+
79
def layer_norm(input_tensor, name=None, reuse=None):
  """Layer-normalize `input_tensor` over its last axis only.

  Both the normalization statistics and the learned scale/offset parameters
  use the last axis (begin_norm_axis = begin_params_axis = -1).

  Args:
    input_tensor: tensor to normalize.
    name: scope passed to the underlying layer_norm.
    reuse: reuse flag forwarded to the underlying layer_norm.

  Returns:
    The output of the compat `layer_norm` for `input_tensor`.
  """
  last_axis = -1
  norm_kwargs = dict(
      begin_norm_axis=last_axis,
      begin_params_axis=last_axis,
      reuse=reuse,
      scope=name)
  return tf_layer_norm(inputs=input_tensor, **norm_kwargs)
87
+
88
+
89
class EnhancedInputLayer(object):
  """Enhance the raw input layer.

  Wraps an input-layer callable for one feature group. The wrapper can apply
  batch/layer normalization, dropout and feature-level dropout, and returns
  the features in the layout selected by the config.
  """

  def __init__(self, input_layer, feature_dict, group_name, reuse=None):
    """Store the construction arguments; the inputs are built lazily.

    Args:
      input_layer: callable invoked on build as
        `input_layer(feature_dict, group_name, is_combine=...)`.
      feature_dict: features passed through to `input_layer`.
      group_name: feature group to build; also used to name this layer
        ('input_<group_name>') and its layer-norm scopes.
      reuse: reuse flag forwarded to the normalization layers.
    """
    self._group_name = group_name
    self.name = 'input_' + self._group_name
    self._input_layer = input_layer
    self._feature_dict = feature_dict
    self._reuse = reuse
    self.built = False

  def __call__(self, config, is_training, **kwargs):
    """Build on first use, then return the (processed) inputs.

    Args:
      config: input-layer config; see `build`, `reset` and `call` for the
        fields read.
      is_training: whether the graph is built for training.
      **kwargs: unused here.

    Returns:
      The (seq_features, seq_len, target_features) triple when
      `config.output_seq_and_normal_feature` is set, otherwise the result of
      `call`.

    Raises:
      ValueError: if both batch norm and layer norm are requested.
    """
    if not self.built:
      self.build(config, is_training)

    if config.output_seq_and_normal_feature:
      return self.inputs

    if config.do_batch_norm and config.do_layer_norm:
      raise ValueError(
          'can not do batch norm and layer norm for input layer at the same time'
      )
    with tf.name_scope(self.name):
      return self.call(config, is_training)

  def build(self, config, training):
    """Materialize the group's raw inputs and create the dropout helpers.

    In sequence mode (`config.output_seq_and_normal_feature`), `self.inputs`
    becomes (seq_features, seq_len, target_features). The sequence and target
    features are concatenated on the last axis when
    `config.concat_seq_feature` is set. Otherwise `self.inputs` is what
    `input_layer` returns with `is_combine=True`, which `call` unpacks as
    (features, feature_list).
    """
    self.built = True
    combine = not config.output_seq_and_normal_feature
    self.inputs = self._input_layer(
        self._feature_dict, self._group_name, is_combine=combine)
    if config.output_seq_and_normal_feature:
      seq_feature_and_len, _, target_features = self.inputs
      seq_features = [seq_fea for seq_fea, _ in seq_feature_and_len]
      # Check before indexing: with no sequence features, the `[0]` lookup
      # below raised IndexError first, so this check could never fire.
      assert len(seq_features) > 0, (
          '[%s] sequence feature is empty' % self.name)
      # seq_len is taken from the first sequence feature (presumably shared by
      # all sequence features of the group -- confirm with the input layer).
      seq_len = seq_feature_and_len[0][1]
      if config.concat_seq_feature:
        if target_features:
          target_features = tf.concat(target_features, axis=-1)
        else:
          target_features = None
        seq_features = tf.concat(seq_features, axis=-1)
      self.inputs = seq_features, seq_len, target_features
    self.reset(config, training)

  def reset(self, config, training):
    """Create the dropout helpers requested by `config`.

    `self.dropout` is created only when 0 < dropout_rate < 1. `self.bern` is
    created only when training with 0 < feature_dropout_rate < 1. `call` uses
    each helper under the same condition.
    """
    if 0.0 < config.dropout_rate < 1.0:
      self.dropout = tf.keras.layers.Dropout(rate=config.dropout_rate)

    if training and 0.0 < config.feature_dropout_rate < 1.0:
      keep_prob = 1.0 - config.feature_dropout_rate
      self.bern = tf.distributions.Bernoulli(probs=keep_prob, dtype=tf.float32)

  def call(self, config, training):
    """Normalize / drop out the built inputs and format the output.

    Returns (the config flags are checked in this order):
      only_output_feature_list: the list of per-feature tensors;
      only_output_3d_tensor: the per-feature tensors stacked on axis 1;
      output_2d_tensor_and_feature_list: (features, feature_list);
      otherwise: the combined `features` tensor.
    """
    features, feature_list = self.inputs
    # Work on a copy: `self.inputs` is built once and shared by every call.
    # Writing into it in place made each later call re-apply normalization
    # and dropout to already-processed tensors.
    feature_list = list(feature_list)
    num_features = len(feature_list)

    do_ln = config.do_layer_norm
    do_bn = config.do_batch_norm
    do_feature_dropout = training and 0.0 < config.feature_dropout_rate < 1.0
    if do_feature_dropout:
      keep_prob = 1.0 - config.feature_dropout_rate
      # One Bernoulli(keep_prob) draw per feature: mask[i] keeps or drops the
      # whole i-th feature.
      mask = self.bern.sample(num_features)
    elif do_bn:
      features = tf.layers.batch_normalization(
          features, training=training, reuse=self._reuse)
    elif do_ln:
      features = layer_norm(
          features, name=self._group_name + '_features', reuse=self._reuse)

    output_feature_list = (
        config.output_2d_tensor_and_feature_list or
        config.only_output_feature_list or config.only_output_3d_tensor)
    rate = config.dropout_rate
    do_dropout = 0.0 < rate < 1.0
    if do_feature_dropout or do_ln or do_bn or do_dropout:
      for i, fea in enumerate(feature_list):
        if do_bn:
          fea = tf.layers.batch_normalization(
              fea, training=training, reuse=self._reuse)
        elif do_ln:
          ln_name = self._group_name + 'f_%d' % i
          fea = layer_norm(fea, name=ln_name, reuse=self._reuse)
        if do_dropout and output_feature_list:
          fea = self.dropout.call(fea, training=training)
        if do_feature_dropout:
          # Inverted dropout: rescale kept features by 1 / keep_prob.
          fea = tf.div(fea, keep_prob) * mask[i]
        feature_list[i] = fea
      if do_feature_dropout:
        # NOTE(review): in this case `dropout_rate` is applied only per
        # feature and only when a feature-list output is requested; with a
        # plain 2-D output it is never applied. Confirm this is intended.
        features = tf.concat(feature_list, axis=-1)

    if do_dropout and not do_feature_dropout:
      features = self.dropout.call(features, training=training)
    # Squeeze a leading axis of static size 1 from a 3-D tensor. as_list()
    # yields None for an unknown dimension, where int() on it raised TypeError.
    if features.shape.ndims == 3 and features.shape.as_list()[0] == 1:
      features = tf.squeeze(features, axis=0)

    if config.only_output_feature_list:
      return feature_list
    if config.only_output_3d_tensor:
      return tf.stack(feature_list, axis=1)
    if config.output_2d_tensor_and_feature_list:
      return features, feature_list
    return features
@@ -0,0 +1,87 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.utils.activation import get_activation
8
+
9
+ if tf.__version__ >= '2.0':
10
+ tf = tf.compat.v1
11
+
12
+
13
class DNN:

  def __init__(self,
               dnn_config,
               l2_reg,
               name='dnn',
               is_training=False,
               last_layer_no_activation=False,
               last_layer_no_batch_norm=False):
    """Configure a stack of fully-connected layers.

    Args:
      dnn_config: instance of easy_rec.python.protos.dnn_pb2.DNN; provides
        hidden_units, dropout_ratio, use_bn and activation.
      l2_reg: kernel regularizer applied to every dense layer.
      name: name prefix of all layers, so that the parameters are separated
        from other DNNs.
      is_training: train phase or not; affects batch norm and dropout.
      last_layer_no_activation: if True, the last layer has no activation.
      last_layer_no_batch_norm: if True, the last layer has no batch norm.
    """
    self._config = dnn_config
    self._l2_reg = l2_reg
    self._name = name
    self._is_training = is_training
    act_name = self._config.activation
    logging.info('dnn activation function = %s', act_name)
    self.activation = get_activation(act_name, training=is_training)
    self._last_layer_no_activation = last_layer_no_activation
    self._last_layer_no_batch_norm = last_layer_no_batch_norm

  @property
  def hidden_units(self):
    """Width of each layer, taken from the config."""
    return self._config.hidden_units

  @property
  def dropout_ratio(self):
    """Per-layer dropout rates, taken from the config (may be empty)."""
    return self._config.dropout_ratio

  def __call__(self, deep_fea, hidden_layer_feature_output=False):
    """Feed `deep_fea` through the configured layers.

    Each layer is: dense (no activation) -> batch norm if `use_bn` -> the
    activation -> dropout (only when training and dropout_ratio is set).
    The batch norm and the activation can be skipped on the last layer.

    Args:
      deep_fea: input tensor.
      hidden_layer_feature_output: if True, return every layer's output.

    Returns:
      `deep_fea` unchanged when hidden_units is exactly [0] or empty.
      Otherwise, if `hidden_layer_feature_output` is set, a dict holding each
      layer's output under 'hidden_layer<i>' and the last one also under
      'hidden_layer_end'. Otherwise, the output of the last layer.
    """
    depth = len(self.hidden_units)
    if depth == 1 and self.hidden_units[0] == 0:
      return deep_fea

    layer_outputs = {}
    for idx, unit in enumerate(self.hidden_units):
      is_last = idx + 1 == depth
      prefix = '%s/dnn_%d' % (self._name, idx)
      deep_fea = tf.layers.dense(
          inputs=deep_fea,
          units=unit,
          kernel_regularizer=self._l2_reg,
          activation=None,
          name=prefix)
      apply_bn = self._config.use_bn and (
          not is_last or not self._last_layer_no_batch_norm)
      if apply_bn:
        deep_fea = tf.layers.batch_normalization(
            deep_fea,
            training=self._is_training,
            trainable=True,
            name=prefix + '/bn')
      if not (is_last and self._last_layer_no_activation):
        deep_fea = self.activation(deep_fea, name=prefix + '/act')
      ratios = self.dropout_ratio
      if len(ratios) > 0 and self._is_training:
        drop_rate = ratios[idx]
        assert drop_rate < 1, 'invalid dropout_ratio: %.3f' % drop_rate
        deep_fea = tf.nn.dropout(
            deep_fea,
            keep_prob=1 - drop_rate,
            name='%s/%d/dropout' % (self._name, idx))

      if hidden_layer_feature_output:
        layer_outputs['hidden_layer' + str(idx)] = deep_fea
        if is_last:
          layer_outputs['hidden_layer_end'] = deep_fea
          return layer_outputs
    # Reached when per-layer outputs were not requested, or hidden_units is
    # empty.
    return deep_fea
@@ -0,0 +1,25 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.feature_column.feature_group import FeatureGroup
7
+
8
+
9
class EmbedInputLayer(object):
  """Collects and concatenates the features of one configured feature group."""

  def __init__(self, feature_groups_config, dump_dir=None):
    """Index the feature groups by name.

    Args:
      feature_groups_config: iterable of feature group configs; each needs a
        `group_name` and is wrapped in a `FeatureGroup`.
      dump_dir: stored as `self._dump_dir`; not read by this class.
    """
    self._feature_groups = {
        x.group_name: FeatureGroup(x) for x in feature_groups_config
    }
    self._dump_dir = dump_dir

  def __call__(self, features, group_name):
    """Look up the features of `group_name` and concatenate them.

    Args:
      features: mapping indexed by feature name (presumably name -> tensor).
      group_name: name of a configured feature group.

    Returns:
      A tuple of (the group's tensors concatenated along axis 1, the list of
      those tensors in the group's feature order).

    Raises:
      AssertionError: if `group_name` is not a configured group.
    """
    # The message has two placeholders, so both values are supplied as a
    # tuple. Previously only the joined names were passed, and a failed check
    # raised "TypeError: not enough arguments for format string" instead.
    assert group_name in self._feature_groups, (
        'invalid group_name[%s], list: %s' %
        (group_name, ','.join(self._feature_groups)))
    feature_group = self._feature_groups[group_name]
    group_features = [
        features[feature_name] for feature_name in feature_group.feature_names
    ]
    return tf.concat(group_features, axis=1), group_features
@@ -0,0 +1,26 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ if tf.__version__ >= '2.0':
7
+ tf = tf.compat.v1
8
+
9
+
10
class FM:

  def __init__(self, name='fm'):
    """Initializes a `FM` Layer.

    Args:
      name: name scope of the FM ops.
    """
    self._name = name

  def __call__(self, fm_fea):
    """Second-order factorization-machine interaction term.

    Uses the identity
      sum_{i<j} v_i * v_j = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2),
    evaluated element-wise after stacking the features on axis 1.

    Args:
      fm_fea: list of per-feature tensors to stack along axis 1; presumably
        each is (batch_size, embed_dim) -- confirm with callers.

    Returns:
      The pairwise-interaction tensor with the stacked feature axis (axis 1)
      reduced away.
    """
    with tf.name_scope(self._name):
      stacked = tf.stack(fm_fea, axis=1)
      square_of_sum = tf.square(tf.reduce_sum(stacked, 1))
      sum_of_square = tf.reduce_sum(tf.square(stacked), 1)
      return 0.5 * tf.subtract(square_of_sum, sum_of_square)