easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package's registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,1299 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import glob
5
+ import logging
6
+ import os
7
+ import threading
8
+ import time
9
+ import unittest
10
+
11
+ import numpy as np
12
+ import six
13
+ import tensorflow as tf
14
+ from distutils.version import LooseVersion
15
+ from tensorflow.python.platform import gfile
16
+
17
+ from easy_rec.python.main import predict
18
+ from easy_rec.python.utils import config_util
19
+ from easy_rec.python.utils import constant
20
+ from easy_rec.python.utils import estimator_utils
21
+ from easy_rec.python.utils import test_utils
22
+
23
+ try:
24
+ import graphlearn as gl
25
+ except Exception:
26
+ gl = None
27
+
28
+ try:
29
+ import horovod as hvd
30
+ except Exception:
31
+ hvd = None
32
+
33
+ try:
34
+ from sparse_operation_kit import experiment as sok
35
+ except Exception:
36
+ sok = None
37
+
38
tf_version = tf.__version__
# On TF 2.x, rebind `tf` to the TF1-compatible API namespace.
# NOTE(review): this is a plain (lexicographic) string comparison of versions.
if tf_version >= '2.0':
  tf = tf.compat.v1
41
+
42
+
43
+ class TrainEvalTest(tf.test.TestCase):
44
+
45
def setUp(self):
  """Log the test being run, reset the success flag and get a temp dir."""
  logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
  self._success = True
  self._test_dir = test_utils.get_tmp_dir()
  logging.info('test dir: %s', self._test_dir)
50
+
51
def tearDown(self):
  """Reset the GPU selection; remove the test dir only if self._success."""
  test_utils.set_gpu_id(None)
  # Failed runs keep their directory around for inspection.
  if not self._success:
    return
  test_utils.clean_up(self._test_dir)
55
+
56
def test_deepfm_with_lookup_feature(self):
  """Single train/eval run of the deepfm_lookup config must succeed."""
  config_path = 'samples/model_config/deepfm_lookup.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
60
+
61
def test_deepfm_with_combo_feature(self):
  """Single train/eval run of the deepfm combo avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
65
+
66
def test_deepfm_with_combo_v2_feature(self):
  """Single train/eval run of the deepfm combo_v2 avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_combo_v2_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
71
+
72
def test_deepfm_with_combo_v3_feature(self):
  """Single train/eval run of the deepfm combo_v3 avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_combo_v3_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
77
+
78
def test_deepfm_freeze_gradient(self):
  """Single train/eval run of the deepfm_freeze_gradient config must succeed."""
  config_path = 'samples/model_config/deepfm_freeze_gradient.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
82
+
83
def test_deepfm_with_vocab_list(self):
  """Single train/eval run of the deepfm vocab_list avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_vocab_list_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
88
+
89
def test_deepfm_with_multi_class(self):
  """Single train/eval run of the deepfm multi_cls avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
94
+
95
def test_wide_and_deep_no_final(self):
  """Single train/eval run of the wide_and_deep no_final config must succeed."""
  config_path = (
      'samples/model_config/wide_and_deep_no_final_on_avazau_ctr.config')
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
100
+
101
def test_wide_and_deep(self):
  """Single train/eval run of the wide_and_deep avazu config must succeed."""
  config_path = 'samples/model_config/wide_and_deep_on_avazau_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
106
+
107
def test_wide_and_deep_backbone(self):
  """Single train/eval run of the wide_and_deep backbone config must succeed."""
  config_path = 'samples/model_config/wide_and_deep_backbone_on_avazau.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
112
+
113
def test_dlrm(self):
  """Single train/eval run of the dlrm_on_taobao config must succeed."""
  self._success = test_utils.test_single_train_eval(
      'samples/model_config/dlrm_on_taobao.config', self._test_dir)
  # Without this assertion a failed train/eval would still report a pass.
  self.assertTrue(self._success)
116
+
117
def test_dlrm_backbone(self):
  """Single train/eval run of the dlrm_backbone_on_taobao config must succeed."""
  self._success = test_utils.test_single_train_eval(
      'samples/model_config/dlrm_backbone_on_taobao.config', self._test_dir)
  # Without this assertion a failed train/eval would still report a pass.
  self.assertTrue(self._success)
120
+
121
def test_adamw_optimizer(self):
  """Single train/eval run of the deepfm adamw avazu config must succeed."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_adamw_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
126
+
127
def test_momentumw_optimizer(self):
  """Single train/eval run of the deepfm momentumw avazu config must succeed."""
  config_path = (
      'samples/model_config/deepfm_combo_on_avazu_momentumw_ctr.config')
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
132
+
133
def test_deepfm_with_param_edit(self):
  """Override model_dir and wide_output_dim through hyperparam_str.

  After training, the pipeline.config read back from the new model_dir
  must reflect both overrides.
  """
  model_dir = os.path.join(self._test_dir, 'train_new')
  self._success = test_utils.test_single_train_eval(
      'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
      self._test_dir,
      hyperparam_str='{"model_dir":"%s", '
      '"model_config.deepfm.wide_output_dim": 32}' % model_dir)
  self.assertTrue(self._success)
  config_path = os.path.join(model_dir, 'pipeline.config')
  pipeline_config = config_util.get_configs_from_pipeline_file(config_path)
  # assertEqual reports both values on mismatch, unlike assertTrue(a == b).
  self.assertEqual(pipeline_config.model_dir, model_dir)
  self.assertEqual(pipeline_config.model_config.deepfm.wide_output_dim, 32)
145
+
146
def test_multi_tower(self):
  """Single train/eval run of the multi_tower_on_taobao config must succeed."""
  config_path = 'samples/model_config/multi_tower_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
150
+
151
def test_multi_tower_backbone(self):
  """Single train/eval run of the multi_tower backbone config must succeed."""
  config_path = 'samples/model_config/multi_tower_backbone_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
156
+
157
def test_multi_tower_gauc(self):
  """Single train/eval run of the multi_tower gauc config must succeed."""
  config_path = 'samples/model_config/multi_tower_on_taobao_gauc.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
162
+
163
def test_multi_tower_session_auc(self):
  """Single train/eval run of the multi_tower session_auc config must succeed."""
  config_path = 'samples/model_config/multi_tower_on_taobao_session_auc.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
168
+
169
def test_multi_tower_save_checkpoint_secs(self):
  """Check the spacing of checkpoints written by the save_secs config.

  Checkpoint .meta mtimes are compared against a 20s interval with
  atol=20, i.e. any gap in [0, 40] seconds is accepted.
  """
  self._success = test_utils.test_single_train_eval(
      'samples/model_config/multi_tower_save_secs_on_taobao.config',
      self._test_dir,
      total_steps=100)
  # Fail fast on a failed run instead of listing a possibly missing dir.
  self.assertTrue(self._success)
  ckpt_dir = os.path.join(self._test_dir, 'train')
  ckpts_times = [
      os.path.getmtime(os.path.join(ckpt_dir, filepath))
      for filepath in os.listdir(ckpt_dir)
      if filepath.startswith('model.ckpt') and filepath.endswith('meta')
  ]
  # Drop the latest checkpoint: presumably it is written when training
  # ends rather than on the timer, so its gap is not a timer interval.
  ckpts_times = np.array(sorted(ckpts_times)[:-1])
  diffs = ckpts_times[1:] - ckpts_times[:-1]
  logging.info('nearby ckpts_times diff = %s' % list(diffs))
  self.assertAllClose(diffs, [20] * (len(ckpts_times) - 1), atol=20)
188
+
189
def test_keep_ckpt_max(self):
  """After 500 steps with the keep-3 config, exactly 3 checkpoints remain."""

  def _post_check_func(pipeline_config):
    ckpt_prefix = os.path.join(pipeline_config.model_dir, 'model.ckpt-*.meta')
    ckpts = gfile.Glob(ckpt_prefix)
    assert len(ckpts) == 3, 'invalid number of checkpoints: %d' % len(ckpts)
    # Return an explicit success flag like the other post-check callbacks in
    # this file. NOTE(review): test_single_train_eval appears to use this
    # return value as its result; an implicit None would be falsy.
    return True

  self._success = test_utils.test_single_train_eval(
      'samples/model_config/multi_tower_ckpt_keep_3_on_taobao.config',
      self._test_dir,
      total_steps=500,
      post_check_func=_post_check_func)
  # Previously absent: a failed train/eval went unreported.
  self.assertTrue(self._success)
201
+
202
def test_multi_tower_with_best_exporter(self):
  """Train with the best-exporter config; at most 2 best ckpts/exports kept."""

  def _post_check_func(pipeline_config):
    model_dir = pipeline_config.model_dir
    best_ckpt_paths = gfile.Glob(
        os.path.join(model_dir, 'best_ckpt/model.ckpt-*.meta'))
    assert len(best_ckpt_paths) <= 2, \
        'too many best ckpts: %s' % str(best_ckpt_paths)
    best_export_paths = gfile.Glob(os.path.join(model_dir, 'export/best/*'))
    assert len(best_export_paths) <= 2, \
        'too many best exports: %s' % str(best_export_paths)
    return True

  config_path = 'samples/model_config/multi_tower_best_export_on_taobao.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      total_steps=800,
      post_check_func=_post_check_func,
      timeout=3000)
  self.assertTrue(self._success)
222
+
223
def test_latest_ckpt(self):
  """latest_checkpoint finds model.ckpt-500 with or without a trailing '/'."""
  for ckpt_dir in ('data/test/latest_ckpt_test',
                   'data/test/latest_ckpt_test/'):
    latest = estimator_utils.latest_checkpoint(ckpt_dir)
    assert latest.endswith('model.ckpt-500')
228
+
229
def test_latest_ckpt_v2(self):
  """After training, estimator_utils.latest_checkpoint must agree with tf's."""

  def _post_check_func(pipeline_config):
    model_dir = pipeline_config.model_dir
    logging.info('model_dir: %s' % model_dir)
    logging.info('latest_checkpoint: %s' %
                 estimator_utils.latest_checkpoint(model_dir))
    tf_latest = tf.train.latest_checkpoint(model_dir)
    return tf_latest == estimator_utils.latest_checkpoint(model_dir)

  config_path = 'samples/model_config/taobao_fg.config'
  self._success = test_utils.test_single_train_eval(
      config_path, self._test_dir, post_check_func=_post_check_func)
  self.assertTrue(self._success)
243
+
244
def test_oss_stop_signal(self):
  """Writing an OSS_STOP_SIGNAL file into the train dir stops training early.

  A watcher thread waits for a checkpoint whose version exceeds 30, then
  writes the signal file. The final checkpoint version must be below the
  configured total_steps (1000).
  """
  train_dir = os.path.join(self._test_dir, 'train/')
  # Set once training returns, so the watcher cannot outlive the test even
  # if the run fails before reaching step 30.
  train_done = threading.Event()

  def _watch_func():
    while not train_done.is_set():
      tmp_ckpt = estimator_utils.latest_checkpoint(train_dir)
      if tmp_ckpt is not None:
        version = estimator_utils.get_ckpt_version(tmp_ckpt)
        if version > 30:
          stop_file = os.path.join(train_dir, 'OSS_STOP_SIGNAL')
          with open(stop_file, 'w') as fout:
            fout.write('OSS_STOP_SIGNAL')
          return
      # Acts as a 1s poll sleep that wakes immediately on shutdown.
      train_done.wait(1)

  watch_th = threading.Thread(target=_watch_func)
  watch_th.start()
  try:
    self._success = test_utils.test_distributed_train_eval(
        'samples/model_config/taobao_fg_signal_stop.config',
        self._test_dir,
        total_steps=1000)
  finally:
    # Always stop and join the watcher, even if training raised.
    train_done.set()
    watch_th.join()
  self.assertTrue(self._success)
  final_ckpt = estimator_utils.latest_checkpoint(train_dir)
  ckpt_version = estimator_utils.get_ckpt_version(final_ckpt)
  logging.info('final ckpt version = %d' % ckpt_version)
  self._success = ckpt_version < 1000
  self.assertLess(ckpt_version, 1000)
273
+
274
def test_dead_line_stop_signal(self):
  """The dead_line_stop config must end training before total_steps."""
  total_steps = 1000
  train_dir = os.path.join(self._test_dir, 'train/')
  self._success = test_utils.test_distributed_train_eval(
      'samples/model_config/dead_line_stop.config',
      self._test_dir,
      total_steps=total_steps)
  self.assertTrue(self._success)
  ckpt_version = estimator_utils.get_ckpt_version(
      estimator_utils.latest_checkpoint(train_dir))
  logging.info('final ckpt version = %d' % ckpt_version)
  self._success = ckpt_version < total_steps
  assert self._success
286
+
287
def test_fine_tune_latest_ckpt_path(self):
  """A checkpoint *directory* given as fine_tune_checkpoint is saved as a path.

  Passing 'data/test/mt_ckpt' must yield 'data/test/mt_ckpt/model.ckpt-100'
  in the pipeline.config written to model_dir.
  """
  expected_ckpt = 'data/test/mt_ckpt/model.ckpt-100'

  def _post_check_func(pipeline_config):
    model_dir = pipeline_config.model_dir
    logging.info('model_dir: %s' % model_dir)
    saved_config = config_util.get_configs_from_pipeline_file(
        os.path.join(model_dir, 'pipeline.config'), False)
    fine_tune_ckpt = saved_config.train_config.fine_tune_checkpoint
    logging.info('fine_tune_checkpoint: %s' % fine_tune_ckpt)
    return fine_tune_ckpt == expected_ckpt

  self._success = test_utils.test_single_train_eval(
      'samples/model_config/multi_tower_on_taobao.config',
      self._test_dir,
      fine_tune_checkpoint='data/test/mt_ckpt',
      post_check_func=_post_check_func)
  self.assertTrue(self._success)
304
+
305
def test_fine_tune_ckpt(self):
  """Train once, then run a second train/eval fine-tuned from that result."""

  def _post_check_func(pipeline_config):
    fine_tune_dir = os.path.join(self._test_dir, 'fine_tune')
    # Read the latest ckpt from the first run before model_dir is redirected.
    latest_ckpt = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
    pipeline_config.train_config.fine_tune_checkpoint = latest_ckpt
    pipeline_config.model_dir = os.path.join(fine_tune_dir, 'ckpt')
    return test_utils.test_single_train_eval(pipeline_config, fine_tune_dir)

  config_path = 'samples/model_config/taobao_fg.config'
  self._success = test_utils.test_single_train_eval(
      config_path, self._test_dir, post_check_func=_post_check_func)
  self.assertTrue(self._success)
319
+
320
def test_multi_tower_multi_value_export(self):
  """Single train/eval run of the multi_value_export config must succeed."""
  config_path = (
      'samples/model_config/multi_tower_multi_value_export_on_taobao.config')
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
325
+
326
def test_multi_tower_fg_input(self):
  """Single train/eval run of the taobao_fg .config file must succeed."""
  config_path = 'samples/model_config/taobao_fg.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
330
+
331
def test_multi_tower_fg_json_config(self):
  """Single train/eval run of the taobao_fg .json config must succeed."""
  config_path = 'samples/model_config/taobao_fg.json'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
335
+
336
def test_fm(self):
  """Single train/eval run of the fm_on_taobao config must succeed."""
  config_path = 'samples/model_config/fm_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
340
+
341
def test_place_embed_on_cpu(self):
  """Run the FM taobao config with place_embedding_on_cpu='True' in the env."""
  # Save and restore the variable so the flag does not leak into tests
  # that run later in the same process.
  old_value = os.environ.get('place_embedding_on_cpu')
  os.environ['place_embedding_on_cpu'] = 'True'
  try:
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/fm_on_taobao.config', self._test_dir)
  finally:
    if old_value is None:
      os.environ.pop('place_embedding_on_cpu', None)
    else:
      os.environ['place_embedding_on_cpu'] = old_value
  self.assertTrue(self._success)
346
+
347
def test_din(self):
  """Single train/eval run of the din_on_taobao config must succeed."""
  config_path = 'samples/model_config/din_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path,
                                                    self._test_dir)
  self.assertTrue(self._success)
351
+
352
+ def test_din_backbone(self):
353
+ self._success = test_utils.test_single_train_eval(
354
+ 'samples/model_config/din_backbone_on_taobao.config', self._test_dir)
355
+ self.assertTrue(self._success)
356
+
357
+ def test_bst(self):
358
+ self._success = test_utils.test_single_train_eval(
359
+ 'samples/model_config/bst_on_taobao.config', self._test_dir)
360
+ self.assertTrue(self._success)
361
+
362
+ def test_bst_backbone(self):
363
+ self._success = test_utils.test_single_train_eval(
364
+ 'samples/model_config/bst_backbone_on_taobao.config', self._test_dir)
365
+ self.assertTrue(self._success)
366
+
367
+ def test_cl4srec(self):
368
+ self._success = test_utils.test_single_train_eval(
369
+ 'samples/model_config/cl4srec_on_taobao.config', self._test_dir)
370
+ self.assertTrue(self._success)
371
+
372
+ def test_dcn(self):
373
+ self._success = test_utils.test_single_train_eval(
374
+ 'samples/model_config/dcn_on_taobao.config', self._test_dir)
375
+ self.assertTrue(self._success)
376
+
377
+ def test_ziln_loss(self):
378
+ self._success = test_utils.test_single_train_eval(
379
+ 'samples/model_config/mlp_on_taobao_with_ziln_loss.config',
380
+ self._test_dir)
381
+ self.assertTrue(self._success)
382
+
383
+ def test_fibinet(self):
384
+ self._success = test_utils.test_single_train_eval(
385
+ 'samples/model_config/fibinet_on_taobao.config', self._test_dir)
386
+ self.assertTrue(self._success)
387
+
388
+ def test_masknet(self):
389
+ self._success = test_utils.test_single_train_eval(
390
+ 'samples/model_config/masknet_on_taobao.config', self._test_dir)
391
+ self.assertTrue(self._success)
392
+
393
+ def test_dcn_backbone(self):
394
+ self._success = test_utils.test_single_train_eval(
395
+ 'samples/model_config/dcn_backbone_on_taobao.config', self._test_dir)
396
+ self.assertTrue(self._success)
397
+
398
+ def test_dcn_with_f1(self):
399
+ self._success = test_utils.test_single_train_eval(
400
+ 'samples/model_config/dcn_f1_on_taobao.config', self._test_dir)
401
+ self.assertTrue(self._success)
402
+
403
+ def test_autoint(self):
404
+ self._success = test_utils.test_single_train_eval(
405
+ 'samples/model_config/autoint_on_taobao.config', self._test_dir)
406
+ self.assertTrue(self._success)
407
+
408
+ def test_uniter(self):
409
+ self._success = test_utils.test_single_train_eval(
410
+ 'samples/model_config/uniter_on_movielens.config', self._test_dir)
411
+ self.assertTrue(self._success)
412
+
413
+ def test_highway(self):
414
+ self._success = test_utils.test_single_train_eval(
415
+ 'samples/model_config/highway_on_movielens.config', self._test_dir)
416
+ self.assertTrue(self._success)
417
+
418
+ # @unittest.skipIf(
419
+ # LooseVersion(tf.__version__) >= LooseVersion('2.0.0'),
420
+ # 'has no CustomOp when tf version == 2.4')
421
+ # def test_custom_op(self):
422
+ # self._success = test_utils.test_single_train_eval(
423
+ # 'samples/model_config/cl4srec_on_taobao_with_custom_op.config',
424
+ # self._test_dir)
425
+ # self.assertTrue(self._success)
426
+
427
def test_cdn(self):
  """CDN on Taobao must train and evaluate."""
  config_path = 'samples/model_config/cdn_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_ppnet(self):
  """PPNet on Taobao must train and evaluate."""
  config_path = 'samples/model_config/ppnet_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_uniter_only_text_feature(self):
  """UNITER with text features only must train and evaluate."""
  config_path = 'samples/model_config/uniter_on_movielens_only_text_feature.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_uniter_only_image_feature(self):
  """UNITER with image features only must train and evaluate."""
  config_path = 'samples/model_config/uniter_on_movielens_only_image_feature.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_cmbf(self):
  """CMBF on MovieLens must train and evaluate."""
  config_path = 'samples/model_config/cmbf_on_movielens.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_cmbf_with_multi_loss(self):
  """CMBF with several losses must train and evaluate."""
  config_path = 'samples/model_config/cmbf_with_multi_loss.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_cmbf_has_other_feature(self):
  """CMBF with extra non-text/non-image features must train and evaluate."""
  config_path = 'samples/model_config/cmbf_on_movielens_has_other_feature.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_cmbf_only_text_feature(self):
  """CMBF with text features only must train and evaluate."""
  config_path = 'samples/model_config/cmbf_on_movielens_only_text_feature.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_cmbf_only_image_feature(self):
  """CMBF with image features only must train and evaluate."""
  config_path = 'samples/model_config/cmbf_on_movielens_only_image_feature.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dssm(self):
  """DSSM on Taobao must train and evaluate."""
  config_path = 'samples/model_config/dssm_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dropoutnet(self):
  """DropoutNet on Taobao must train and evaluate."""
  config_path = 'samples/model_config/dropoutnet_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_metric_learning(self):
  """Metric-learning model on Taobao must train and evaluate."""
  config_path = 'samples/model_config/metric_learning_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)
491
+
492
@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler(self):
  """DSSM with the graphlearn negative sampler."""
  config_path = 'samples/model_config/dssm_neg_sampler_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_v2(self):
  """DSSM with the v2 graphlearn negative sampler."""
  config_path = 'samples/model_config/dssm_neg_sampler_v2_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_hard_neg_sampler(self):
  """DSSM with the graphlearn hard-negative sampler."""
  config_path = 'samples/model_config/dssm_hard_neg_sampler_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_hard_neg_regular_sampler(self):
  """DSSM with the regular hard-negative sampler."""
  config_path = 'samples/model_config/dssm_hard_neg_sampler_regular_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_hard_neg_sampler_v2(self):
  """DSSM with the v2 hard-negative sampler."""
  config_path = 'samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dssm_no_norm(self):
  """DSSM scored by plain inner product (no normalization)."""
  config_path = 'samples/model_config/dssm_inner_prod_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)
531
+
532
+ # def test_dssm_with_regression(self):
533
+ # self._success = test_utils.test_single_train_eval(
534
+ # 'samples/model_config/dssm_reg_on_taobao.config', self._test_dir)
535
+ # self.assertTrue(self._success)
536
+
537
def _test_kd(self, config0, config1):
  """Run a knowledge-distillation round trip.

  1. Train/eval a teacher from ``config0`` in ``self._test_dir``.
  2. In a subprocess, predict with the teacher over its train and eval
     inputs. Write each input line with the teacher's logits appended as an
     extra trailing comma-separated column to ``train_kd`` / ``eval_kd``
     under ``self._test_dir``.
  3. Train/eval a student from ``config1`` with its input paths pointed at
     those files, saving the patched config as ``kd_pipeline.config``.

  Args:
    config0: path of the teacher pipeline config.
    config1: path of the student (KD) pipeline config.
  """
  self._success = test_utils.test_single_train_eval(config0, self._test_dir)
  self.assertTrue(self._success)
  config_path = os.path.join(self._test_dir, 'pipeline.config')
  pipeline_config = config_util.get_configs_from_pipeline_file(config_path)

  train_path = os.path.join(self._test_dir, 'train_kd')
  eval_path = os.path.join(self._test_dir, 'eval_kd')

  @test_utils.RunAsSubprocess
  def _gen_kd_data(train_path, eval_path):

    # Defined inside the subprocess entry so only one callable crosses the
    # RunAsSubprocess boundary, exactly as before.
    def _append_teacher_logits(input_path, output_path):
      pred_result = predict(config_path, None, input_path)
      with gfile.GFile(input_path, 'r') as fin:
        with gfile.GFile(output_path, 'w') as fout:
          for line, pred in zip(fin, pred_result):
            logits = pred['logits']
            if isinstance(logits, np.ndarray):
              # NOTE(review): array logits are joined with no separator, so
              # per-class values run together in one column; confirm the
              # format the KD student config expects before changing it.
              pred_logits = ''.join([str(x) for x in logits])
            else:
              pred_logits = str(logits)
            fout.write(line.strip() + ',' + pred_logits + '\n')

    _append_teacher_logits(pipeline_config.train_input_path, train_path)
    _append_teacher_logits(pipeline_config.eval_input_path, eval_path)

  _gen_kd_data(train_path, eval_path)
  pipeline_config = config_util.get_configs_from_pipeline_file(config1)
  pipeline_config.train_input_path = train_path
  pipeline_config.eval_input_path = eval_path
  config_util.save_pipeline_config(pipeline_config, self._test_dir,
                                   'kd_pipeline.config')
  self._success = test_utils.test_single_train_eval(
      os.path.join(self._test_dir, 'kd_pipeline.config'),
      os.path.join(self._test_dir, 'kd'))
  self.assertTrue(self._success)
577
+
578
def test_dssm_with_kd(self):
  """Distill a multi-tower teacher into a DSSM student."""
  teacher = 'samples/model_config/multi_tower_on_taobao.config'
  student = 'samples/model_config/dssm_kd_on_taobao.config'
  self._test_kd(teacher, student)

def test_deepfm_multi_class_with_kd(self):
  """Distill a multi-class DeepFM teacher into a smaller DeepFM."""
  teacher = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
  student = 'samples/model_config/deepfm_multi_cls_small.config'
  self._test_kd(teacher, student)

def test_mind(self):
  """MIND on Taobao must train and evaluate."""
  config_path = 'samples/model_config/mind_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_mind_with_time_id(self):
  """MIND with a time-id feature must train and evaluate."""
  config_path = 'samples/model_config/mind_on_taobao_with_time.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_with_regression(self):
  """DeepFM regression on Avazu must train and evaluate."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_reg.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_with_sigmoid_l2_loss(self):
  """DeepFM with sigmoid-L2 loss must train and evaluate."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_sigmoid_l2.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_with_embedding_learning_rate(self):
  """DeepFM with a dedicated embedding learning rate."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_emblr_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_with_eval_online(self):
  """DeepFM with online evaluation enabled."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_eval_online_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_with_eval_online_gauc(self):
  """DeepFM with online GAUC evaluation enabled."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_eval_online_gauc_ctr.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_mmoe(self):
  """MMoE on Taobao must train and evaluate."""
  config_path = 'samples/model_config/mmoe_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_mmoe_backbone(self):
  """MMoE expressed through the backbone config."""
  config_path = 'samples/model_config/mmoe_backbone_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_mmoe_with_multi_loss(self):
  """MMoE with several losses per task."""
  config_path = 'samples/model_config/mmoe_on_taobao_with_multi_loss.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_mmoe_deprecated(self):
  """MMoE driven by the deprecated config layout."""
  config_path = 'samples/model_config/mmoe_on_taobao_deprecated.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_simple_multi_task(self):
  """Simple multi-task model on Taobao."""
  config_path = 'samples/model_config/simple_multi_task_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_simple_multi_task_backbone(self):
  """Simple multi-task model expressed through the backbone config."""
  config_path = 'samples/model_config/simple_multi_task_backbone_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_esmm(self):
  """ESMM on Taobao must train and evaluate."""
  config_path = 'samples/model_config/esmm_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_tag_kv_input(self):
  """Key-value tag input must train and evaluate."""
  config_path = 'samples/model_config/kv_tag.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_aitm(self):
  """AITM on Taobao must train and evaluate."""
  config_path = 'samples/model_config/aitm_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl(self):
  """DBMTL on Taobao must train and evaluate."""
  config_path = 'samples/model_config/dbmtl_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_backbone(self):
  """DBMTL expressed through the backbone config."""
  config_path = 'samples/model_config/dbmtl_backbone_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_cmbf(self):
  """DBMTL with a CMBF tower on MovieLens."""
  config_path = 'samples/model_config/dbmtl_cmbf_on_movielens.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_uniter(self):
  """DBMTL with a UNITER tower on MovieLens."""
  config_path = 'samples/model_config/dbmtl_uniter_on_movielens.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_with_multi_loss(self):
  """DBMTL with several losses per task."""
  config_path = 'samples/model_config/dbmtl_on_taobao_with_multi_loss.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)
698
+
699
def test_early_stop(self):
  """Early stopping in single-process train/eval."""
  config_path = 'samples/model_config/multi_tower_early_stop_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_early_stop_custom(self):
  """Custom early-stopping hook in single-process train/eval."""
  config_path = 'samples/model_config/custom_early_stop_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_early_stop_dis(self):
  """Early stopping under distributed train/eval."""
  config_path = 'samples/model_config/multi_tower_early_stop_on_taobao.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_latest_export_with_asset(self):
  """Distributed run with latest-export (including assets)."""
  config_path = 'samples/model_config/din_on_taobao_latest_export.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_incompatible_restore(self):
  """Fine-tune after growing hash buckets, forcing shape-compatible restore."""

  def _restore_with_bigger_buckets(cfg):
    features = cfg.feature_config.features
    features[0].hash_bucket_size += 20000
    features[1].hash_bucket_size += 100
    # Point fine-tuning at the first run *before* renaming model_dir.
    cfg.train_config.fine_tune_checkpoint = cfg.model_dir
    cfg.model_dir += '_finetune'
    cfg.train_config.force_restore_shape_compatible = True
    finetune_dir = os.path.join(self._test_dir, 'finetune')
    return test_utils.test_single_train_eval(cfg, finetune_dir)

  self._success = test_utils.test_single_train_eval(
      'samples/model_config/taobao_fg.config',
      self._test_dir,
      post_check_func=_restore_with_bigger_buckets)
  self.assertTrue(self._success)

def test_dbmtl_variational_dropout(self):
  """DBMTL with variational dropout, then check feature selection."""
  config_path = 'samples/model_config/dbmtl_variational_dropout.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_dbmtl_variational_dropout_feature_num(self):
  """DBMTL variational dropout with a feature-count limit."""
  config_path = 'samples/model_config/dbmtl_variational_dropout_feature_num.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_essm_variational_dropout(self):
  """ESMM with variational dropout, then check feature selection."""
  config_path = 'samples/model_config/esmm_variational_dropout_on_taobao.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_fm_variational_dropout(self):
  """FM with variational dropout, then check feature selection."""
  config_path = 'samples/model_config/fm_variational_dropout_on_taobao.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_deepfm_with_combo_feature_variational_dropout(self):
  """DeepFM with combo features and variational dropout."""
  config_path = 'samples/model_config/deepfm_combo_variational_dropout_on_avazu_ctr.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_dbmtl_sequence_variational_dropout(self):
  """DBMTL on sequence features with variational dropout."""
  config_path = 'samples/model_config/dbmtl_variational_dropout_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_din_variational_dropout(self):
  """DIN with variational dropout, then check feature selection."""
  config_path = 'samples/model_config/din_varitional_dropout_on_taobao.config'
  self._success = test_utils.test_single_train_eval(
      config_path,
      self._test_dir,
      post_check_func=test_utils.test_feature_selection)
  self.assertTrue(self._success)

def test_rocket_launching(self):
  """Rocket Launching must train and evaluate."""
  config_path = 'samples/model_config/rocket_launching.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_rocket_launching_feature_based(self):
  """Feature-based Rocket Launching variant."""
  config_path = 'samples/model_config/rocket_launching_feature_based.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_rocket_launching_with_rtp_input(self):
  """Rocket Launching fed by RTP-format input."""
  config_path = 'samples/model_config/rocket_launching_with_rtp_input.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_mmoe(self):
  """DBMTL with MMoE experts on Taobao."""
  config_path = 'samples/model_config/dbmtl_mmoe_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_train_with_ps_worker(self):
  """Parameter-server/worker distributed training."""
  config_path = 'samples/model_config/multi_tower_on_taobao.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skip('Timeout on CI machine')
def test_fit_on_eval(self):
  """Distributed run that continues fitting on the eval set."""
  config_path = 'samples/model_config/multi_tower_on_taobao.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path,
      self._test_dir,
      total_steps=10,
      num_evaluator=1,
      fit_on_eval=True)
  self.assertTrue(self._success)

def test_unbalance_data(self):
  """One epoch of distributed training over unbalanced input shards."""
  config_path = 'samples/model_config/multi_tower_on_taobao_unblanace.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path,
      self._test_dir,
      total_steps=0,
      num_epoch=1,
      num_evaluator=1)
  self.assertTrue(self._success)
834
+
835
def test_train_with_ps_worker_with_evaluator(self):
  """PS/worker training with one evaluator yields exactly one final export.

  After a successful distributed run, ``train/export/final`` under the test
  dir must contain exactly one ``*/saved_model.pb``.
  """
  self._success = test_utils.test_distributed_train_eval(
      'samples/model_config/multi_tower_on_taobao.config',
      self._test_dir,
      num_evaluator=1)
  self.assertTrue(self._success)
  final_export_dir = os.path.join(self._test_dir, 'train/export/final')
  all_saved_files = glob.glob(
      os.path.join(final_export_dir, '*', 'saved_model.pb'))
  # Lazy %-args: formatting only happens if the record is emitted.
  logging.info('final_export_dir=%s all_saved_files=%s', final_export_dir,
               ','.join(all_saved_files))
  # assertEqual reports the actual count on failure.
  self.assertEqual(len(all_saved_files), 1)
846
+
847
def test_train_with_ps_worker_chief_redundant(self):
  """PS/worker training with the chief-redundant setting."""
  config_path = 'samples/model_config/multi_tower_on_taobao_chief_redundant.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_deepfm_embed_input(self):
  """Distributed DeepFM fed pre-computed embedding input."""
  config_path = 'samples/model_config/deepfm_with_embed.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_tower_embed_input(self):
  """Distributed multi-tower fed pre-computed embedding input."""
  config_path = 'samples/model_config/multi_tower_with_embed.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_tfrecord_input(self):
  """Distributed DeepFM reading TFRecord input."""
  config_path = 'samples/model_config/deepfm_on_criteo_tfrecord.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_batch_tfrecord_input(self):
  """Distributed DeepFM reading batched TFRecord input."""
  config_path = 'samples/model_config/deepfm_on_criteo_batch_tfrecord.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_autodis_embedding(self):
  """DeepFM using AutoDis numeric embeddings."""
  config_path = 'samples/model_config/deepfm_on_criteo_with_autodis.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_periodic_embedding(self):
  """DeepFM using periodic numeric embeddings."""
  config_path = 'samples/model_config/deepfm_on_criteo_with_periodic.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sample_weight(self):
  """DeepFM trained with per-sample weights."""
  config_path = 'samples/model_config/deepfm_with_sample_weight.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_dssm_sample_weight(self):
  """DSSM trained with per-sample weights."""
  config_path = 'samples/model_config/dssm_with_sample_weight.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_with_sample_weight(self):
  """DSSM negative sampling combined with per-sample weights."""
  config_path = 'samples/model_config/dssm_neg_sampler_with_sample_weight.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(
    LooseVersion(tf.__version__) != LooseVersion('2.3.0'),
    'MultiWorkerMirroredStrategy need tf version == 2.3')
def test_train_with_multi_worker_mirror(self):
  """Multi-tower under MultiWorkerMirroredStrategy."""
  config_path = 'samples/model_config/multi_tower_multi_worker_mirrored_strategy_on_taobao.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(
    LooseVersion(tf.__version__) != LooseVersion('2.3.0'),
    'MultiWorkerMirroredStrategy need tf version == 2.3')
def test_train_mmoe_with_multi_worker_mirror(self):
  """MMoE under MultiWorkerMirroredStrategy."""
  config_path = 'samples/model_config/mmoe_mirrored_strategy_on_taobao.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_fg_dtype(self):
  """FG pipeline exercising feature dtype handling."""
  config_path = 'samples/model_config/taobao_fg_test_dtype.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(six.PY2, 'Only run in python3')
def test_share_not_used(self):
  """Config with a shared embedding that is never used."""
  config_path = 'samples/model_config/share_not_used.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)
931
+
932
def test_sequence_autoint(self):
  """AutoInt on Taobao sequence features."""
  config_path = 'samples/model_config/autoint_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_dbmtl(self):
  """DBMTL on Taobao sequence features."""
  config_path = 'samples/model_config/dbmtl_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_dcn(self):
  """DCN on Taobao sequence features."""
  config_path = 'samples/model_config/dcn_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_dssm(self):
  """DSSM on Taobao sequence features."""
  config_path = 'samples/model_config/dssm_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_esmm(self):
  """ESMM on Taobao sequence features."""
  config_path = 'samples/model_config/esmm_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_mmoe(self):
  """MMoE on Taobao sequence features."""
  config_path = 'samples/model_config/mmoe_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_ple(self):
  """PLE on Taobao sequence features."""
  config_path = 'samples/model_config/ple_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_rocket_launching(self):
  """Rocket Launching on Taobao sequence features."""
  config_path = 'samples/model_config/rocket_launching_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_simple_multi_task(self):
  """Simple multi-task model on Taobao sequence features."""
  config_path = 'samples/model_config/simple_multi_task_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_sequence_wide_and_deep(self):
  """Wide & Deep on Taobao sequence features."""
  config_path = 'samples/model_config/wide_and_deep_on_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_numeric_boundary_sequence_dbmtl(self):
  """DBMTL on a numeric sequence feature bucketized by boundaries."""
  config_path = 'samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_numeric_hash_bucket_sequence_dbmtl(self):
  """DBMTL on a numeric sequence feature using hash buckets."""
  config_path = 'samples/model_config/dbmtl_on_numeric_hash_bucket_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_numeric_raw_sequence_dbmtl(self):
  """DBMTL on a raw numeric sequence feature."""
  config_path = 'samples/model_config/dbmtl_on_numeric_raw_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_numeric_num_buckets_sequence_dbmtl(self):
  """DBMTL on a numeric sequence feature using num_buckets."""
  config_path = 'samples/model_config/dbmtl_on_numeric_num_buckets_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_numeric_boundary_sequence_dbmtl(self):
  """DBMTL on several boundary-bucketized numeric sequence features."""
  config_path = 'samples/model_config/dbmtl_on_multi_numeric_boundary_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_numeric_hash_bucket_sequence_dbmtl(self):
  """DBMTL on several hash-bucketed numeric sequence features."""
  config_path = 'samples/model_config/dbmtl_on_multi_numeric_hash_bucket_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_numeric_raw_sequence_dbmtl(self):
  """DBMTL on several raw numeric sequence features."""
  config_path = 'samples/model_config/dbmtl_on_multi_numeric_raw_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_numeric_num_buckets_sequence_dbmtl(self):
  """DBMTL on several num_buckets numeric sequence features."""
  config_path = 'samples/model_config/dbmtl_on_multi_numeric_num_buckets_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_sequence_dbmtl(self):
  """DBMTL on multiple sequence features."""
  config_path = 'samples/model_config/dbmtl_on_multi_sequence_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_multi_optimizer(self):
  """Distributed Wide & Deep with two optimizers."""
  config_path = 'samples/model_config/wide_and_deep_two_opti.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_embedding_separate_optimizer(self):
  """Distributed DeepFM with a separate Adagrad optimizer for embeddings."""
  config_path = 'samples/model_config/deepfm_combo_on_avazu_embed_adagrad.config'
  self._success = test_utils.test_distributed_train_eval(
      config_path, self._test_dir)
  self.assertTrue(self._success)

def test_expr_feature(self):
  """Multi-tower using expression features."""
  config_path = 'samples/model_config/multi_tower_on_taobao_for_expr.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)

def test_gzip_data(self):
  """DIN reading gzip-compressed input."""
  config_path = 'samples/model_config/din_on_gzip_data.config'
  self._success = test_utils.test_single_train_eval(config_path, self._test_dir)
  self.assertTrue(self._success)
1067
+
1068
def _run_cmd_config_param(self, extra_cmd_args, expected_dim):
  """Train/eval deepfm_multi_cls with a command-line config override.

  Args:
    extra_cmd_args: extra command-line arguments forwarded to training; they
      are expected to override ``model_config.deepfm.wide_output_dim``.
    expected_dim: value that field must have in the pipeline config saved
      at ``<test_dir>/train/pipeline.config``.
  """

  def _post_check_config(pipeline_config):
    # Re-read the config training actually saved, so the command-line
    # override (not the in-memory argument) is what gets verified.
    train_saved_config_path = os.path.join(self._test_dir,
                                           'train/pipeline.config')
    saved_config = config_util.get_configs_from_pipeline_file(
        train_saved_config_path)
    wide_output_dim = saved_config.model_config.deepfm.wide_output_dim
    # assertEqual instead of a bare `assert`, which `python -O` strips.
    self.assertEqual(
        wide_output_dim, expected_dim,
        'invalid model_config.deepfm.wide_output_dim=%d' % wide_output_dim)
    # Other post-check callbacks in this suite return a bool that feeds
    # `_success`; report success explicitly instead of returning None.
    return True

  self._success = test_utils.test_single_train_eval(
      'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
      self._test_dir,
      post_check_func=_post_check_config,
      extra_cmd_args=extra_cmd_args)
  # Previously never asserted, so a failed run could pass silently.
  self.assertTrue(self._success)

def test_cmd_config_param(self):
  """Space-separated ``--key value`` override is applied."""
  self._run_cmd_config_param('--model_config.deepfm.wide_output_dim 8', 8)

def test_cmd_config_param_v2(self):
  """``--key=value`` override is applied."""
  self._run_cmd_config_param('--model_config.deepfm.wide_output_dim=1', 1)

def test_cmd_config_param_v3(self):
  """``--key="value"`` (quoted) override is applied."""
  self._run_cmd_config_param('--model_config.deepfm.wide_output_dim="3"', 3)
1118
+
1119
def test_distribute_eval_deepfm_multi_cls(self):
  """Distributed eval using ``deepfm_distribute_eval_multi_cls_on_avazu_ctr``."""
  config = 'samples/model_config/deepfm_distribute_eval_multi_cls_on_avazu_ctr.config'
  eval_dir = 'data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_distribute_eval_deepfm_single_cls(self):
  """Distributed eval using the DeepFM ``combo`` avazu sample config."""
  config = 'samples/model_config/deepfm_distribute_eval_combo_on_avazu_ctr.config'
  eval_dir = 'data/test/distribute_eval_test/dwd_distribute_eval_avazu_out_test_combo'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_distribute_eval_dssm_pointwise_classification(self):
  """Distributed eval using the DSSM pointwise-classification taobao config."""
  config = 'samples/model_config/dssm_distribute_eval_pointwise_classification_on_taobao.config'
  eval_dir = 'data/test/distribute_eval_test/dssm_distribute_eval_pointwise_classification_taobao_ckpt'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_distribute_eval_dssm_reg(self):
  """Distributed eval using ``dssm_distribute_eval_reg_on_taobao``."""
  config = 'samples/model_config/dssm_distribute_eval_reg_on_taobao.config'
  eval_dir = 'data/test/distribute_eval_test/dssm_distribute_eval_reg_taobao_ckpt'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_distribute_eval_dropout(self):
  """Distributed eval using ``dropoutnet_distribute_eval_on_taobao``."""
  config = 'samples/model_config/dropoutnet_distribute_eval_on_taobao.config'
  eval_dir = 'data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_distribute_eval_esmm(self):
  """Distributed eval using ``esmm_distribute_eval_on_taobao``."""
  config = 'samples/model_config/esmm_distribute_eval_on_taobao.config'
  eval_dir = 'data/test/distribute_eval_test/esmm_distribute_eval_taobao_ckpt'
  self._success = test_utils.test_distributed_eval(config, eval_dir,
                                                   self._test_dir)
  self.assertTrue(self._success)

def test_share_no_used(self):
  """Single train/eval run with ``share_embedding_not_used.config``."""
  config = 'samples/model_config/share_embedding_not_used.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_sequence_feature(self):
  """Single train/eval with ``dssm_neg_sampler_sequence_feature.config``."""
  config = 'samples/model_config/dssm_neg_sampler_sequence_feature.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_neg_sampler_need_key_feature(self):
  """Single train/eval with ``dssm_neg_sampler_need_key_feature.config``."""
  config = 'samples/model_config/dssm_neg_sampler_need_key_feature.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_on_multi_numeric_boundary_need_key_feature(self):
  """Single train/eval with the DBMTL numeric-boundary need-key-feature config."""
  config = 'samples/model_config/dbmtl_on_multi_numeric_boundary_need_key_feature_taobao.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_on_multi_numeric_boundary_allow_key_transform(self):
  """Single train/eval with the DBMTL numeric-boundary allow-key-transform config."""
  config = 'samples/model_config/dbmtl_on_multi_numeric_boundary_allow_key_transform.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

def test_dbmtl_on_multi_numeric_boundary_aux_hist_seq(self):
  """Single train/eval with the DBMTL sequence-feature aux_hist_seq config."""
  config = 'samples/model_config/dbmtl_on_numeric_boundary_sequence_feature_aux_hist_seq_taobao.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_multi_tower_recall_neg_sampler_sequence_feature(self):
  """Single train/eval with the multi-tower recall neg-sampler sequence config."""
  config = 'samples/model_config/multi_tower_recall_neg_sampler_sequence_feature.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_multi_tower_recall_neg_sampler_only_sequence_feature(self):
  """Single train/eval with the multi-tower recall only-sequence config."""
  config = 'samples/model_config/multi_tower_recall_neg_sampler_only_sequence_feature.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(hvd is None, 'horovod is not installed')
def test_horovod(self):
  """Distributed train/eval of ``deepfm_combo_on_avazu_ctr`` with ``use_hvd``."""
  config = 'samples/model_config/deepfm_combo_on_avazu_ctr.config'
  self._success = test_utils.test_distributed_train_eval(
      config, self._test_dir, use_hvd=True)
  self.assertTrue(self._success)

# The skip message previously read 'horovod and sok is not installed', which
# is wrong when only one of the two is missing.
@unittest.skipIf(hvd is None or sok is None,
                 'horovod or sok is not installed')
def test_sok(self):
  """Distributed train/eval of ``multi_tower_on_taobao_sok`` with ``use_hvd``.

  Skipped when either horovod (``hvd``) or sok is unavailable.
  """
  self._success = test_utils.test_distributed_train_eval(
      'samples/model_config/multi_tower_on_taobao_sok.config',
      self._test_dir,
      use_hvd=True)
  self.assertTrue(self._success)

@unittest.skipIf(
    six.PY2 or tf_version.split('.')[0] != '2',
    'only run on python3 and tf 2.x')
def test_train_parquet(self):
  """Single train/eval of ``dlrm_on_criteo_parquet`` with an env flag set.

  Sets the environment variable named by ``constant.NO_ARITHMETRIC_OPTI``
  to '1' only for this run. The previous value is restored afterwards, so
  the flag no longer leaks into later tests in the same process.
  """
  env_key = constant.NO_ARITHMETRIC_OPTI
  saved_value = os.environ.get(env_key)
  os.environ[env_key] = '1'
  try:
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/dlrm_on_criteo_parquet.config', self._test_dir)
  finally:
    if saved_value is None:
      os.environ.pop(env_key, None)
    else:
      os.environ[env_key] = saved_value
  self.assertTrue(self._success)

@unittest.skipIf(hvd is None, 'horovod is not installed')
def test_train_parquet_embedding_parallel(self):
  """Distributed train/eval of ``dlrm_on_criteo_parquet_ep`` with ``use_hvd``."""
  config = 'samples/model_config/dlrm_on_criteo_parquet_ep.config'
  self._success = test_utils.test_distributed_train_eval(
      config, self._test_dir, use_hvd=True)
  self.assertTrue(self._success)

@unittest.skipIf(hvd is None, 'horovod is not installed')
def test_train_parquet_embedding_parallel_v2(self):
  """Distributed train/eval of ``dlrm_on_criteo_parquet_ep_v2`` with ``use_hvd``."""
  config = 'samples/model_config/dlrm_on_criteo_parquet_ep_v2.config'
  self._success = test_utils.test_distributed_train_eval(
      config, self._test_dir, use_hvd=True)
  self.assertTrue(self._success)

def test_pdn(self):
  """Single train/eval with ``pdn_on_taobao.config``."""
  config = 'samples/model_config/pdn_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_senet(self):
  """Single train/eval with ``dssm_senet_on_taobao.config``."""
  config = 'samples/model_config/dssm_senet_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_backbone_on_taobao(self):
  """Single train/eval with ``dssm_on_taobao_backbone.config``."""
  config = 'samples/model_config/dssm_on_taobao_backbone.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dssm_senet_backbone_on_taobao(self):
  """Single train/eval with ``dssm_senet_on_taobao_backbone.config``."""
  config = 'samples/model_config/dssm_senet_on_taobao_backbone.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_parallel_dssm_backbone_on_taobao(self):
  """Single train/eval with ``parallel_dssm_on_taobao_backbone.config``."""
  config = 'samples/model_config/parallel_dssm_on_taobao_backbone.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

def test_xdeefm_backbone_on_taobao(self):
  """Single train/eval with ``xdeepfm_on_taobao_backbone.config``.

  NOTE: the method name misspells "xdeepfm". It is kept unchanged so that
  existing test selections by name keep working.
  """
  config = 'samples/model_config/xdeepfm_on_taobao_backbone.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)

@unittest.skipIf(gl is None, 'graphlearn is not installed')
def test_dat_on_taobao(self):
  """Single train/eval with ``dat_on_taobao.config``."""
  config = 'samples/model_config/dat_on_taobao.config'
  self._success = test_utils.test_single_train_eval(config, self._test_dir)
  self.assertTrue(self._success)


if __name__ == '__main__':
  # Run as a script: hand control to TensorFlow's test runner.
  tf.test.main()