easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/test/export_test.py
@@ -0,0 +1,513 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # Date: 2020-10-06
4
+ # Filename:export_test.py
5
+ import functools
6
+ import json
7
+ import logging
8
+ import os
9
+ import unittest
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ from tensorflow.python.platform import gfile
14
+
15
+ import easy_rec
16
+ from easy_rec.python.inference.predictor import Predictor
17
+ from easy_rec.python.utils import config_util
18
+ from easy_rec.python.utils import test_utils
19
+ from easy_rec.python.utils.test_utils import RunAsSubprocess
20
+
21
+
22
+ class ExportTest(tf.test.TestCase):
23
+
24
def setUp(self):
  """Log the test class name and the test method that is about to run."""
  case_name = type(self).__name__
  method_name = self._testMethodName
  logging.info('Testing %s.%s' % (case_name, method_name))
26
+
27
def tearDown(self):
  """Call test_utils.set_gpu_id(None) after every test.

  NOTE(review): presumably this releases the GPU picked for the test;
  confirm against test_utils.set_gpu_id.
  """
  gpu_id = None
  test_utils.set_gpu_id(gpu_id)
29
+
30
@RunAsSubprocess
def _predict_and_check(self,
                       data_path,
                       saved_model_dir,
                       cmp_result,
                       keys=['probs'],
                       separator=',',
                       tol=1e-4):
  """Predict with a saved model and compare the outputs to reference values.

  Args:
    data_path: text file read line by line; each stripped line is one sample.
    saved_model_dir: directory passed to Predictor.
    cmp_result: indexable of dicts with the expected value for every key,
      indexed in the same order as the prediction outputs.
    keys: output names compared for every sample.
    separator: delimiter used to split a line, applied only when the
      predictor reports more than one input name.
    tol: exclusive upper bound on the maximum absolute difference.

  Raises:
    AssertionError: if any compared output differs by tol or more.
  """
  predictor = Predictor(saved_model_dir)
  with open(data_path, 'r') as fin:
    inputs = []
    for raw_line in fin:
      sample = raw_line.strip()
      if len(predictor.input_names) > 1:
        # multi-input models take a list of fields instead of the raw line
        sample = sample.split(separator)
      inputs.append(sample)
    output_res = predictor.predict(inputs, batch_size=32)

  for idx, predicted in enumerate(output_res):
    for key in keys:
      actual = predicted[key]
      expected = cmp_result[idx][key]
      diff = np.max(np.abs(actual - expected))
      assert diff < tol, \
          'too much difference: %.6f for %s, tol=%.6f' \
          % (diff, key, tol)
57
+
58
+ def _extract_data(self, input_path, output_path, offset=1, separator=','):
59
+ with open(input_path, 'r') as fin:
60
+ with open(output_path, 'w') as fout:
61
+ for line_str in fin:
62
+ line_str = line_str.strip()
63
+ line_toks = line_str.split(separator)
64
+ if offset > 0:
65
+ line_toks = line_toks[offset:]
66
+ fout.write('%s\n' % (separator.join(line_toks)))
67
+
68
+ def _extract_rtp_data(self, input_path, output_path, separator=';'):
69
+ with open(input_path, 'r') as fin:
70
+ with open(output_path, 'w') as fout:
71
+ for line_str in fin:
72
+ line_str = line_str.strip()
73
+ line_toks = line_str.split(separator)
74
+ fout.write('%s\n' % line_toks[-1])
75
+
76
+ def test_multi_tower(self):
77
+ self._export_test('samples/model_config/multi_tower_export.config',
78
+ self._extract_data)
79
+
80
+ def test_filter_input(self):
81
+ self._export_test('samples/model_config/export_filter_input.config',
82
+ self._extract_data)
83
+
84
+ def test_mmoe(self):
85
+ self._export_test(
86
+ 'samples/model_config/mmoe_on_taobao.config',
87
+ functools.partial(self._extract_data, offset=2),
88
+ keys=['probs_ctr', 'probs_cvr'])
89
+
90
+ def test_fg(self):
91
+ self._export_test(
92
+ 'samples/model_config/taobao_fg.config',
93
+ self._extract_rtp_data,
94
+ separator='')
95
+
96
+ def test_fg_export(self):
97
+ self._export_test(
98
+ 'samples/model_config/taobao_fg_export.config',
99
+ self._extract_rtp_data,
100
+ separator='',
101
+ test_multi=False)
102
+
103
def test_export_with_asset(self):
  """Export a trained taobao_fg model with an asset file and a done marker.

  Trains through test_utils.test_single_train_eval, runs
  easy_rec.python.export through test_utils.run_cmd, then asserts that the
  asset, the pipeline config and the ExportDone file exist under the first
  path matched inside the export directory.
  """
  pipeline_config_path = 'samples/model_config/taobao_fg.config'
  test_dir = test_utils.get_tmp_dir()

  # prepare model
  trained_ok = test_utils.test_single_train_eval(
      pipeline_config_path, test_dir=test_dir)
  self.assertTrue(trained_ok)
  test_utils.set_gpu_id(None)

  config_path = os.path.join(test_dir, 'pipeline.config')
  export_root = os.path.join(test_dir, 'export/')
  export_cmd = """
    python -m easy_rec.python.export
      --pipeline_config_path %s
      --export_dir %s
      --asset_files fg.json:samples/model_config/taobao_fg.json
      --export_done_file ExportDone
  """ % (
      config_path,
      export_root,
  )
  log_path = '%s/log_%s.txt' % (test_dir, 'export')
  proc = test_utils.run_cmd(export_cmd, log_path)
  proc.wait()
  self.assertTrue(proc.returncode == 0)

  # the checks run against the first entry matched under the export root
  export_dir = gfile.Glob(export_root + '*')[0]
  expected_entries = ('/assets/taobao_fg.json', '/assets/pipeline.config',
                      '/ExportDone')
  for entry in expected_entries:
    assert gfile.Exists(export_dir + entry)
132
+
133
def test_export_with_out_in_ckpt_config(self):
  """Export from an explicit checkpoint using the original sample config.

  The export runs as the post-check hook of
  test_utils.test_single_train_eval and is given the sample config path
  together with the latest checkpoint found in the trained model_dir.
  """
  test_dir = test_utils.get_tmp_dir()
  logging.info('test dir: %s' % test_dir)

  pipeline_config_path = 'samples/model_config/mmoe_on_taobao.config'

  def _post_check_func(pipeline_config):
    # returns True only when the export subprocess exits with code 0
    latest_ckpt = tf.train.latest_checkpoint(pipeline_config.model_dir)
    target_dir = os.path.join(test_dir, 'train/export/no_config')
    export_cmd = """
      python -m easy_rec.python.export
        --pipeline_config_path %s
        --checkpoint_path %s
        --export_dir %s
    """ % (pipeline_config_path, latest_ckpt, target_dir)
    log_path = '%s/log_%s.txt' % (test_dir, 'export')
    proc = test_utils.run_cmd(export_cmd, log_path)
    proc.wait()
    return proc.returncode == 0

  # prepare model
  succeeded = test_utils.test_single_train_eval(
      pipeline_config_path,
      test_dir=test_dir,
      post_check_func=_post_check_func)
  self.assertTrue(succeeded)
159
+
160
+ def test_multi_class_predict(self):
161
+ self._export_test(
162
+ 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config',
163
+ extract_data_func=self._extract_data,
164
+ keys=['probs', 'logits', 'probs_y', 'logits_y', 'y'])
165
+
166
def _export_test(self,
                 pipeline_config_path,
                 extract_data_func=None,
                 separator=',',
                 keys=['probs'],
                 test_multi=True):
  """Train, export and predict, then compare saved-model and checkpoint outputs.

  Args:
    pipeline_config_path: pipeline config used to train the model.
    extract_data_func: optional callable(input_path, output_path) that
      rewrites the eval data into the file fed to the saved models.
    separator: field separator forwarded to _predict_and_check.
    keys: output names forwarded to _predict_and_check. The list default is
      kept for interface compatibility; it is not mutated here.
    test_multi: when True, the export written to train/export/multi is
      checked in addition to train/export/final.
  """
  test_dir = test_utils.get_tmp_dir()
  logging.info('test dir: %s' % test_dir)

  # prepare model
  self.assertTrue(
      test_utils.test_single_train_eval(
          pipeline_config_path, test_dir=test_dir))
  test_utils.set_gpu_id(None)

  # prepare two version config: pipeline_v2.config is saved with the
  # multi_placeholder flag flipped relative to pipeline.config, so the
  # single/multi names are swapped when the trained config is already multi
  config_path_single = os.path.join(test_dir, 'pipeline.config')
  config_path_multi = os.path.join(test_dir, 'pipeline_v2.config')
  pipeline_config = config_util.get_configs_from_pipeline_file(
      config_path_single)
  if pipeline_config.export_config.multi_placeholder:
    config_path_single, config_path_multi = config_path_multi, config_path_single
  pipeline_config.export_config.multi_placeholder =\
      not pipeline_config.export_config.multi_placeholder
  config_util.save_pipeline_config(pipeline_config, test_dir,
                                   'pipeline_v2.config')

  # prepare two version export dir
  # NOTE(review): train/export/final is never written here; presumably the
  # training step above produces it - confirm in test_single_train_eval.
  export_dir_single = os.path.join(test_dir, 'train/export/final')
  export_dir_multi = os.path.join(test_dir, 'train/export/multi')
  export_cmd = """
    python -m easy_rec.python.export
      --pipeline_config_path %s
      --export_dir %s
  """ % (config_path_multi, export_dir_multi)
  proc = test_utils.run_cmd(export_cmd,
                            '%s/log_%s.txt' % (test_dir, 'export'))
  proc.wait()
  self.assertTrue(proc.returncode == 0)

  # use checkpoint to prepare result
  result_path = os.path.join(test_dir, 'result.txt')
  predict_cmd = """
    python -m easy_rec.python.predict
      --pipeline_config_path %s
      --output_path %s
  """ % (config_path_single, result_path)
  # The command is already fully formatted; formatting it again with an
  # empty tuple would raise if an interpolated path contained '%'.
  proc = test_utils.run_cmd(predict_cmd,
                            '%s/log_%s.txt' % (test_dir, 'predict'))
  proc.wait()
  self.assertTrue(proc.returncode == 0)
  # every line of the result file is parsed as one JSON document
  with open(result_path, 'r') as fin:
    cmp_result = [json.loads(line_str.strip()) for line_str in fin]

  test_data_path = pipeline_config.eval_input_path
  if extract_data_func is not None:
    tmp_data_path = os.path.join(test_dir, 'pred_input_data')
    extract_data_func(test_data_path, tmp_data_path)
    test_data_path = tmp_data_path
  self._predict_and_check(
      test_data_path,
      export_dir_single,
      cmp_result,
      keys=keys,
      separator=separator)
  if test_multi:
    self._predict_and_check(
        test_data_path,
        export_dir_multi,
        cmp_result,
        keys=keys,
        separator=separator)
  # reached only when every check above passed
  test_utils.clean_up(test_dir)
242
+
243
def _test_big_model_export(self,
                           pipeline_config_path,
                           test_data_path,
                           extract_data_func=None,
                           total_steps=50):
  """Train, export with the redis options, and compare predictions.

  Args:
    pipeline_config_path: pipeline config used to train the model.
    test_data_path: data file passed to the export command as an asset and
      to the predict command as its input.
    extract_data_func: optional callable(input_path, output_path) that
      rewrites test_data_path into the file fed to the saved model.
    total_steps: forwarded to test_utils.test_single_train_eval.

  Raises:
    KeyError: if redis_url or redis_passwd is missing from os.environ.
  """
  test_dir = test_utils.get_tmp_dir()
  logging.info('test dir: %s' % test_dir)

  lookup_op_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  tf.load_op_library(lookup_op_path)

  # prepare model
  self.assertTrue(
      test_utils.test_single_train_eval(
          pipeline_config_path, test_dir=test_dir, total_steps=total_steps))

  test_utils.set_gpu_id(None)
  # the pipeline.config is produced by the prepare model cmd
  config_path = os.path.join(test_dir, 'pipeline.config')
  export_dir = os.path.join(test_dir, 'export/')
  export_cmd = """
    python -m easy_rec.python.export
      --pipeline_config_path %s
      --export_dir %s
      --asset_files %s
      --redis_url %s
      --redis_passwd %s
      --redis_threads 1
      --redis_write_kv 1
      --verbose 1
  """ % (config_path, export_dir, test_data_path, os.environ['redis_url'],
         os.environ['redis_passwd'])
  proc = test_utils.run_cmd(export_cmd,
                            '%s/log_%s.txt' % (test_dir, 'export'))
  proc.wait()
  self.assertTrue(proc.returncode == 0)

  # continue with the first sub-directory whose name starts with three digits
  export_dir = gfile.Glob(export_dir + '[0-9][0-9][0-9]*')[0]
  _, test_data_name = os.path.split(test_data_path)
  assert gfile.Exists(export_dir + '/assets/' + test_data_name)

  # use checkpoint to prepare result
  result_path = os.path.join(test_dir, 'result.txt')
  predict_cmd = """
    python -m easy_rec.python.predict
      --pipeline_config_path %s
      --input_path %s
      --output_path %s
  """ % (config_path, test_data_path, result_path)
  # The command is already fully formatted; formatting it again with an
  # empty tuple would raise if an interpolated path contained '%'.
  proc = test_utils.run_cmd(predict_cmd,
                            '%s/log_%s.txt' % (test_dir, 'predict'))
  proc.wait()
  self.assertTrue(proc.returncode == 0)
  # every line of the result file is parsed as one JSON document
  with open(result_path, 'r') as fin:
    cmp_result = [json.loads(line_str.strip()) for line_str in fin]

  if extract_data_func is not None:
    tmp_data_path = os.path.join(test_dir, 'pred_input_data')
    extract_data_func(test_data_path, tmp_data_path)
    test_data_path = tmp_data_path
  self._predict_and_check(test_data_path, export_dir, cmp_result)
307
+
308
+ @unittest.skipIf(
309
+ 'redis_url' not in os.environ,
310
+ 'Only execute when redis is available: redis_url, redis_passwd')
311
+ def test_big_model_export(self):
312
+ pipeline_config_path = 'samples/model_config/multi_tower_export.config'
313
+ test_data_path = 'data/test/export/data.csv'
314
+ self._test_big_model_export(
315
+ pipeline_config_path,
316
+ test_data_path,
317
+ extract_data_func=self._extract_data)
318
+
319
+ @unittest.skipIf(
320
+ 'redis_url' not in os.environ,
321
+ 'Only execute when redis is available: redis_url, redis_passwd')
322
+ def test_big_model_deepfm_export(self):
323
+ pipeline_config_path = 'samples/model_config/deepfm_combo_on_avazu_ctr.config'
324
+ test_data_path = 'data/test/dwd_avazu_ctr_deepmodel_10w.csv'
325
+ self._test_big_model_export(
326
+ pipeline_config_path,
327
+ test_data_path,
328
+ extract_data_func=self._extract_data)
329
+
330
+ @unittest.skipIf(
331
+ 'redis_url' not in os.environ,
332
+ 'Only execute when redis is available: redis_url, redis_passwd')
333
+ def test_big_model_din_export(self):
334
+ pipeline_config_path = 'samples/model_config/din_on_taobao.config'
335
+ test_data_path = 'data/test/tb_data/taobao_test_data'
336
+ self._test_big_model_export(
337
+ pipeline_config_path,
338
+ test_data_path,
339
+ extract_data_func=functools.partial(self._extract_data, offset=2))
340
+
341
+ @unittest.skipIf(
342
+ 'redis_url' not in os.environ,
343
+ 'Only execute when redis is available: redis_url, redis_passwd')
344
+ def test_big_model_wide_and_deep_export(self):
345
+ pipeline_config_path = 'samples/model_config/wide_and_deep_two_opti.config'
346
+ test_data_path = 'data/test/dwd_avazu_ctr_deepmodel_10w.csv'
347
+ self._test_big_model_export(
348
+ pipeline_config_path,
349
+ test_data_path,
350
+ extract_data_func=functools.partial(self._extract_data))
351
+
352
@unittest.skipIf(
    'redis_url' not in os.environ or '-PAI' not in tf.__version__,
    'Only execute when pai-tf and redis is available: redis_url, redis_passwd'
)
def test_big_model_embedding_variable_export(self):
  """Big-model export check for the taobao_fg_ev config (1000 train steps)."""
  config_path = 'samples/model_config/taobao_fg_ev.config'
  data_path = 'data/test/rtp/taobao_valid_feature.txt'
  self._test_big_model_export(
      config_path, data_path, self._extract_rtp_data, total_steps=1000)
364
+
365
@unittest.skipIf(
    'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
    'oss_sk' not in os.environ or 'oss_path' not in os.environ or
    '-PAI' not in tf.__version__,
    'Only execute oss params(oss_endpoint,oss_ak,oss_sk) are specified,'
    'and pai-tf is available.')
def test_big_model_embedding_variable_oss_export(self):
  """OSS big-model export check for the taobao_fg_ev config (100 steps)."""
  config_path = 'samples/model_config/taobao_fg_ev.config'
  data_path = 'data/test/rtp/taobao_valid_feature.txt'
  self._test_big_model_export_to_oss(
      config_path, data_path, self._extract_rtp_data, total_steps=100)
379
+
380
@unittest.skipIf(
    'oss_endpoint' not in os.environ or 'oss_ak' not in os.environ or
    'oss_sk' not in os.environ or 'oss_path' not in os.environ or
    '-PAI' not in tf.__version__,
    # The reason lists every variable the condition checks (oss_path was
    # missing) and keeps a space between the two concatenated pieces.
    'Only execute when oss params(oss_endpoint,oss_ak,oss_sk,oss_path) '
    'are specified, and pai-tf is available.')
def test_big_model_embedding_variable_v2_oss_export(self):
  """Export the v2 embedding-variable fg model to OSS and verify predictions.

  Skipped unless all of oss_endpoint, oss_ak, oss_sk and oss_path are set
  in the environment and the TensorFlow version string contains '-PAI'.
  Trains for 100 steps and converts the rtp-format input with
  `self._extract_rtp_data`.
  """
  pipeline_config_path = 'samples/model_config/taobao_fg_ev_v2.config'
  test_data_path = 'data/test/rtp/taobao_valid_feature.txt'
  self._test_big_model_export_to_oss(
      pipeline_config_path,
      test_data_path,
      self._extract_rtp_data,
      total_steps=100)
394
+
395
def _test_big_model_export_to_oss(self,
                                  pipeline_config_path,
                                  test_data_path,
                                  extract_data_func=None,
                                  total_steps=50):
  """Train a model, export it with the OSS options, and compare predictions.

  The helper trains/evaluates with `test_utils.test_single_train_eval`,
  runs `easy_rec.python.export` with the oss_* flags, runs
  `easy_rec.python.predict` on the checkpoint to obtain reference results,
  and finally hands the exported model and the reference results to
  `self._predict_and_check`.

  Args:
    pipeline_config_path: pipeline config used for the train/eval step.
    test_data_path: data file passed as `--asset_files` to export and as
      `--input_path` to predict.
    extract_data_func: optional callable invoked as
      `extract_data_func(test_data_path, tmp_data_path)`; when given, the
      file it writes replaces `test_data_path` for the final check.
    total_steps: number of training steps for the train/eval step.
  """
  test_dir = test_utils.get_tmp_dir()
  logging.info('test dir: %s', test_dir)

  lookup_op_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  tf.load_op_library(lookup_op_path)

  # prepare model
  self.assertTrue(
      test_utils.test_single_train_eval(
          pipeline_config_path, test_dir=test_dir, total_steps=total_steps))

  test_utils.set_gpu_id(None)
  # the pipeline.config is produced by the prepare model cmd
  config_path = os.path.join(test_dir, 'pipeline.config')
  export_dir = os.path.join(test_dir, 'export/')
  # NOTE(review): oss_ak / oss_sk are interpolated into the command line;
  # whether run_cmd logs the command is not visible here -- confirm they
  # cannot leak into logs.
  export_cmd = """
    python -m easy_rec.python.export
      --pipeline_config_path %s
      --export_dir %s
      --asset_files %s
      --oss_path %s
      --oss_endpoint %s
      --oss_ak %s --oss_sk %s
      --oss_threads 5
      --oss_timeout 10
      --oss_write_kv 1
      --verbose 1
  """ % (config_path, export_dir, test_data_path, os.environ['oss_path'],
         os.environ['oss_endpoint'], os.environ['oss_ak'],
         os.environ['oss_sk'])
  export_log = '%s/log_%s.txt' % (test_dir, 'export')
  proc = test_utils.run_cmd(export_cmd, export_log)
  proc.wait()
  # assertEqual reports the actual return code on failure.
  self.assertEqual(proc.returncode, 0,
                   'export command failed, see %s' % export_log)

  # Fail with a readable message instead of an IndexError when the export
  # step produced no output directory.
  export_dirs = gfile.Glob(export_dir + '[0-9][0-9][0-9]*')
  self.assertTrue(export_dirs,
                  'no exported model found under %s' % export_dir)
  export_dir = export_dirs[0]
  _, test_data_name = os.path.split(test_data_path)
  # A unittest assertion is used rather than a bare `assert`, which is
  # stripped when Python runs with -O.
  asset_path = export_dir + '/assets/' + test_data_name
  self.assertTrue(
      gfile.Exists(asset_path), 'missing exported asset: %s' % asset_path)

  # use checkpoint to prepare result
  result_path = os.path.join(test_dir, 'result.txt')
  predict_cmd = """
    python -m easy_rec.python.predict
      --pipeline_config_path %s
      --input_path %s
      --output_path %s
  """ % (config_path, test_data_path, result_path)
  predict_log = '%s/log_%s.txt' % (test_dir, 'predict')
  proc = test_utils.run_cmd(predict_cmd, predict_log)
  proc.wait()
  self.assertEqual(proc.returncode, 0,
                   'predict command failed, see %s' % predict_log)
  # Each line of the result file is parsed as one JSON document.
  with open(result_path, 'r') as fin:
    cmp_result = [json.loads(line_str.strip()) for line_str in fin]

  if extract_data_func is not None:
    tmp_data_path = os.path.join(test_dir, 'pred_input_data')
    extract_data_func(test_data_path, tmp_data_path)
    test_data_path = tmp_data_path
  self._predict_and_check(test_data_path, export_dir, cmp_result)
462
+
463
@unittest.skipIf(
    # The OSS export helper reads all four variables from os.environ, so
    # the test is skipped unless every one of them is present (checking
    # only oss_path would turn a partial setup into a KeyError).
    any(key not in os.environ
        for key in ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk')),
    'Only execute when oss is available: oss_path, oss_endpoint, oss_ak, oss_sk'
)
def test_big_model_export_to_oss(self):
  """Export the multi-tower sample model to OSS and verify predictions."""
  pipeline_config_path = 'samples/model_config/multi_tower_export.config'
  test_data_path = 'data/test/export/data.csv'
  self._test_big_model_export_to_oss(
      pipeline_config_path,
      test_data_path,
      extract_data_func=self._extract_data)
474
+
475
@unittest.skipIf(
    # The OSS export helper reads all four variables from os.environ, so
    # the test is skipped unless every one of them is present (checking
    # only oss_path would turn a partial setup into a KeyError).
    any(key not in os.environ
        for key in ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk')),
    'Only execute when oss is available: oss_path, oss_endpoint, oss_ak, oss_sk'
)
def test_big_model_deepfm_export_to_oss(self):
  """Export the DeepFM avazu sample model to OSS and verify predictions."""
  pipeline_config_path = 'samples/model_config/deepfm_combo_on_avazu_ctr.config'
  test_data_path = 'data/test/dwd_avazu_ctr_deepmodel_10w.csv'
  self._test_big_model_export_to_oss(
      pipeline_config_path,
      test_data_path,
      extract_data_func=self._extract_data)
486
+
487
@unittest.skipIf(
    # The OSS export helper reads all four variables from os.environ, so
    # the test is skipped unless every one of them is present (checking
    # only oss_path would turn a partial setup into a KeyError).
    any(key not in os.environ
        for key in ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk')),
    'Only execute when oss is available: oss_path, oss_endpoint, oss_ak, oss_sk'
)
def test_big_model_din_export_to_oss(self):
  """Export the DIN taobao sample model to OSS and verify predictions.

  The input is converted with `self._extract_data` called with `offset=2`.
  """
  pipeline_config_path = 'samples/model_config/din_on_taobao.config'
  test_data_path = 'data/test/tb_data/taobao_test_data'
  self._test_big_model_export_to_oss(
      pipeline_config_path,
      test_data_path,
      extract_data_func=functools.partial(self._extract_data, offset=2))
498
+
499
@unittest.skipIf(
    # The OSS export helper reads all four variables from os.environ, so
    # the test is skipped unless every one of them is present (checking
    # only oss_path would turn a partial setup into a KeyError).
    any(key not in os.environ
        for key in ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk')),
    'Only execute when oss is available: oss_path, oss_endpoint, oss_ak, oss_sk'
)
def test_big_model_wide_and_deep_export_to_oss(self):
  """Export the wide-and-deep sample model to OSS and verify predictions."""
  pipeline_config_path = 'samples/model_config/wide_and_deep_two_opti.config'
  test_data_path = 'data/test/dwd_avazu_ctr_deepmodel_10w.csv'
  self._test_big_model_export_to_oss(
      pipeline_config_path,
      test_data_path,
      extract_data_func=functools.partial(self._extract_data))
510
+
511
+
512
# Run the tests in this module through the TensorFlow test runner when the
# file is executed directly.
if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,70 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import unittest
5
+
6
+ import tensorflow as tf
7
+ from google.protobuf import text_format
8
+
9
+ from easy_rec.python.utils import config_util
10
+ from easy_rec.python.utils import fg_util
11
+ from easy_rec.python.utils import test_utils
12
+
13
# Switch to the TF1 compatibility API when running on TensorFlow 2 or newer.
# The major version is compared numerically: a plain string comparison such
# as tf.__version__ >= '2.0' orders '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
15
+
16
+
17
class FGTest(tf.test.TestCase):
  """Tests for loading fg json into pipeline configs and training with it."""

  def __init__(self, methodName='runTest'):
    # The default follows unittest's own default. The previous default,
    # 'FGTest', names no method on this class, so constructing the test
    # case without arguments always raised ValueError.
    super(FGTest, self).__init__(methodName=methodName)

  def setUp(self):
    """Create a temporary working directory for the test."""
    logging.info('Testing %s.%s', type(self).__name__, self._testMethodName)
    self._test_dir = test_utils.get_tmp_dir()
    # Tests that train a model overwrite this flag with the training result;
    # tearDown keeps the directory for inspection when it is falsy.
    self._success = True
    logging.info('test dir: %s', self._test_dir)

  def tearDown(self):
    """Reset the gpu id and remove the working directory on success."""
    test_utils.set_gpu_id(None)
    if self._success:
      test_utils.clean_up(self._test_dir)

  def test_fg_json_to_config(self):
    """Config built from fg json must equal the stored final config."""
    pipeline_config_path = 'samples/rtp_fg/fg_test_extensions.config'
    final_pipeline_config_path = 'samples/rtp_fg/fg_test_extensions_final.config'
    fg_path = 'samples/rtp_fg/fg_test_extensions.json'

    pipeline_config = config_util.get_configs_from_pipeline_file(
        pipeline_config_path)
    pipeline_config.fg_json_path = fg_path
    fg_util.load_fg_json_to_config(pipeline_config)
    pipeline_config_str = text_format.MessageToString(
        pipeline_config, as_utf8=True)

    final_pipeline_config = config_util.get_configs_from_pipeline_file(
        final_pipeline_config_path)
    final_pipeline_config_str = text_format.MessageToString(
        final_pipeline_config, as_utf8=True)
    # Both configs are compared in their text-format serialization.
    self.assertEqual(pipeline_config_str, final_pipeline_config_str)

  def test_fg_dtype(self):
    """Train/eval with the taobao fg dtype sample config."""
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/taobao_fg_test_dtype.config', self._test_dir)
    self.assertTrue(self._success)

  def test_fg_train(self):
    """Train/eval with the fg_train sample config."""
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/fg_train.config', self._test_dir)
    self.assertTrue(self._success)

  @unittest.skipIf('-PAI' not in tf.__version__,
                   'Only test when pai-tf is used.')
  def test_fg_train_ev(self):
    """Train/eval with the fg_train_ev sample config (pai-tf only)."""
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/fg_train_ev.config', self._test_dir)
    self.assertTrue(self._success)
67
+
68
+
69
# Run the tests in this module through the TensorFlow test runner when the
# file is executed directly.
if __name__ == '__main__':
  tf.test.main()