easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,237 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+ import time
7
+ import traceback
8
+
9
+ import oss2
10
+
11
+ try:
12
+ from datahub import DataHub
13
+ from datahub.exceptions import InvalidOperationException
14
+ from datahub.exceptions import ResourceExistException
15
+ # from datahub.exceptions import LimitExceededException
16
+ # from datahub.exceptions import ResourceNotFoundException
17
+ # from datahub.models import BlobRecord
18
+ # from datahub.models import CursorType
19
+ from datahub.models import FieldType
20
+ from datahub.models import RecordSchema
21
+ from datahub.models import RecordType
22
+ from datahub.models import TupleRecord
23
+ except Exception:
24
+ logging.error(
25
+ 'DataHub is not installed, please installed it by: pip install pydatahub')
26
+ DataHub = None
27
+
28
+ try:
29
+ from odps import ODPS
30
+ from odps.df import DataFrame
31
+ except Exception:
32
+ ODPS = None
33
+ DataFrame = None
34
+
35
+
36
class OdpsOSSConfig:
    """Settings and DataHub/ODPS helpers used by the ODPS integration tests.

    The instance stores OSS, ODPS and DataHub connection settings, can fill
    them from ``key=value`` style config files, and can (re)create the test
    DataHub project/topic and load it with rows read from an ODPS table.
    """

    # (line prefix, attribute name) pairs understood by load_oss_config.
    _OSS_CONFIG_FIELDS = (
        ('accessKeyID=', 'oss_key'),
        ('accessKeySecret=', 'oss_secret'),
        ('endpoint=', 'endpoint'),
    )

    # (line prefix, attribute name) pairs understood by load_odps_config.
    _ODPS_CONFIG_FIELDS = (
        ('project_name=', 'project_name'),
        ('end_point=', 'odps_endpoint'),
        ('access_id=', 'dh_id'),
        ('access_key=', 'dh_key'),
    )

    def __init__(self, script_path='./samples/odps_script'):
        """Initialize default settings.

        Args:
          script_path: value stored as ``self.script_path``.
        Raises:
          KeyError: if ALI_OSS_KEY or ALI_OSS_SEC is missing from the
            environment.
        """
        self.time_stamp = int(time.time())
        temp_dir = os.environ.get('TMPDIR', '/tmp')
        # Per-run experiment directory, made unique by the timestamp.
        self.exp_dir = 'easy_rec_odps_test_%d' % self.time_stamp
        self.temp_dir = os.path.join(temp_dir, self.exp_dir)
        self.log_dir = os.path.join(self.temp_dir, 'logs/')

        # public buckets with read-only access
        self.ali_bucket_endpoint = 'http://oss-cn-beijing.aliyuncs.com'
        self.ali_bucket_name = 'easyrec'
        self.script_path = script_path
        # read only access
        self.ali_oss_key = os.environ['ALI_OSS_KEY']
        self.ali_oss_secret = os.environ['ALI_OSS_SEC']

        # Filled in by load_oss_config.
        self.oss_key = ''
        self.oss_secret = ''
        self.endpoint = ''
        self.arn = 'acs:ram::xxx:role/aliyunodpspaidefaultrole'
        self.bucket_name = ''

        self.odpscmd_path = os.environ.get('ODPS_CMD_PATH', 'odpscmd')
        self.odps_config_path = ''

        # Filled in by load_odps_config.
        self.project_name = ''

        self.dh_id = ''
        self.dh_key = ''

        self.dh_endpoint = 'https://dh-cn-beijing.aliyuncs.com'
        self.dh_topic = 'easy_rec_test'
        self.dh_project = 'easy_rec_test'

        self.odps_endpoint = ''

        # Client handles, created by init_dh_and_odps.
        self.dh = None
        self.odps = None

        # default to algo_public
        self.algo_project = None
        self.algo_res_project = None
        self.algo_version = None
        self.algo_name = 'easy_rec_ext'

        # default to outer environment
        # the difference are ossHost buckets arn settings
        self.is_outer = True

    def load_oss_config(self, config_path):
        """Read the OSS key, secret and endpoint from a config file.

        Spaces are removed from every line before matching, so both
        ``accessKeyID=abc`` and ``accessKeyID = abc`` are accepted. Lines
        with an unknown prefix are ignored.

        Args:
          config_path: path of the OSS config file.
        """
        with open(config_path, 'r') as fin:
            for line_str in fin:
                line_str = line_str.strip().replace(' ', '')
                for prefix, attr_name in self._OSS_CONFIG_FIELDS:
                    if line_str.startswith(prefix):
                        setattr(self, attr_name,
                                line_str[len(prefix):].strip())
                        break

    def load_odps_config(self, config_path):
        """Read project name, endpoint and access id/key from a config file.

        The path is remembered in ``self.odps_config_path``. Spaces are
        removed from every line before matching; lines with an unknown
        prefix are ignored.

        Args:
          config_path: path of the ODPS config file.
        """
        self.odps_config_path = config_path
        with open(config_path, 'r') as fin:
            for line_str in fin:
                line_str = line_str.strip().replace(' ', '')
                for prefix, attr_name in self._ODPS_CONFIG_FIELDS:
                    if line_str.startswith(prefix):
                        setattr(self, attr_name, line_str[len(prefix):])
                        break

    def clean_topic(self, dh_project):
        """Delete every topic (and its subscriptions) of a DataHub project.

        Args:
          dh_project: DataHub project name; when empty an error is logged
            and nothing is deleted.
        """
        if not dh_project:
            logging.error('project is empty .')
            # Without a project name there is nothing meaningful to list.
            return
        topic_names = self.dh.list_topic(dh_project).topic_names
        for topic_name in topic_names:
            self.clean_subscription(topic_name)
            self.dh.delete_topic(dh_project, topic_name)

    def clean_project(self):
        """Remove the test DataHub project (``self.dh_project``) if present."""
        project_names = self.dh.list_project().project_names
        for dh_project in project_names:
            if dh_project == self.dh_project:
                self.clean_topic(dh_project)
                try:
                    self.dh.delete_project(dh_project)
                except InvalidOperationException:
                    # Best effort: keep going, but leave a trace in the log.
                    logging.warning('failed to delete datahub project %s: %s'
                                    % (dh_project, traceback.format_exc()))

    def clean_subscription(self, topic_name):
        """Delete the subscriptions listed for a topic of ``self.dh_project``.

        Args:
          topic_name: DataHub topic name.
        """
        subscriptions = self.dh.list_subscription(
            self.dh_project, topic_name, '', 1, 100).subscriptions
        for subscription in subscriptions:
            self.dh.delete_subscription(self.dh_project, topic_name,
                                        subscription)

    def get_input_type(self, input_type):
        """Map a type name to a DataHub FieldType.

        Args:
          input_type: one of INT64, INT32, STRING, BOOLEAN, FLOAT32, FLOAT64.
        Return:
          the matching FieldType, or None for an unknown name.
        """
        # Built per call because FieldType only exists when the datahub
        # import at the top of the module succeeded.
        DhDict = {
            'INT64': FieldType.BIGINT,
            'INT32': FieldType.BIGINT,
            'STRING': FieldType.STRING,
            'BOOLEAN': FieldType.BOOLEAN,
            'FLOAT32': FieldType.DOUBLE,
            'FLOAT64': FieldType.DOUBLE
        }

        return DhDict.get(input_type)

    def init_dh_and_odps(self):
        """Create the DataHub/ODPS clients and fill the test topic.

        Drops any previous test project, recreates project and tuple topic
        using the schema of the ODPS table ``deepfm_train_<time_stamp>``,
        then writes the table rows to the topic ten times. Failures of the
        DataHub calls are logged, not raised.
        """
        self.dh = DataHub(self.dh_id, self.dh_key, self.dh_endpoint)
        self.odps = ODPS(self.dh_id, self.dh_key, self.project_name,
                         self.odps_endpoint)
        self.odpsTable = 'deepfm_train_%s' % self.time_stamp
        self.clean_project()
        read_odps = DataFrame(self.odps.get_table(self.odpsTable))
        col_name = read_odps.schema.names
        col_type = [
            self.get_input_type(str(i)) for i in read_odps.schema.types
        ]
        try:
            self.dh.create_project(self.dh_project, comment='EasyRecTest')
            logging.info('create project success!')
        except ResourceExistException:
            logging.warning('project %s already exist!' % self.dh_project)
        except Exception:
            logging.error(traceback.format_exc())
        record_schema = RecordSchema.from_lists(col_name, col_type)
        try:
            # project_name, topic_name, shard_count, life_cycle,
            # record_schema, comment
            self.dh.create_tuple_topic(
                self.dh_project,
                self.dh_topic,
                7,
                3,
                record_schema,
                comment='EasyRecTest')
            logging.info('create tuple topic %s success!' % self.dh_topic)
        except ResourceExistException:
            logging.info('topic %s already exist!' % self.dh_topic)
        except Exception as ex:
            logging.error('exception:%s' % str(ex))
            logging.error(traceback.format_exc())
        try:
            self.dh.wait_shards_ready(self.dh_project, self.dh_topic)
            logging.info('datahub[%s,%s] shards all ready' %
                         (self.dh_project, self.dh_topic))
            topic_result = self.dh.get_topic(self.dh_project, self.dh_topic)
            if topic_result.record_type != RecordType.TUPLE:
                logging.error('invalid topic type: %s' %
                              str(topic_result.record_type))
            # Use the schema reported by the topic for the uploaded records.
            record_schema = topic_result.record_schema
            t = self.odps.get_table(self.odpsTable)
            with t.open_reader() as reader:
                record_list = [
                    TupleRecord(values=data.values, schema=record_schema)
                    for data in reader
                ]
                # The same batch is written ten times.
                for _ in range(10):
                    self.dh.put_records(self.dh_project, self.dh_topic,
                                        record_list)
        except Exception as ex:
            logging.error('exception: %s' % str(ex))
            logging.error(traceback.format_exc())
205
+
206
def get_oss_bucket(oss_key, oss_secret, endpoint, bucket_name):
    """Create an oss2.Bucket handle from credentials.

    Args:
      oss_key: oss access_key
      oss_secret: oss access_secret
      endpoint: oss endpoint
      bucket_name: oss bucket name
    Return:
      oss2.Bucket instance, or None when either credential is None
    """
    has_credentials = oss_key is not None and oss_secret is not None
    if not has_credentials:
        logging.info('oss_key or oss_secret is None')
        return None
    return oss2.Bucket(oss2.Auth(oss_key, oss_secret), endpoint, bucket_name)
223
+
224
+
225
def delete_oss_path(bucket, in_prefix, bucket_name):
    """Remove every object under an oss path, then the path key itself.

    Args:
      bucket: oss2.Bucket instance
      in_prefix: oss path prefix to be removed
      bucket_name: bucket_name
    """
    # Drop the 'oss://<bucket>/' part so only the object key prefix remains.
    bucket_url = 'oss://' + bucket_name + '/'
    key_prefix = in_prefix.replace(bucket_url, '')
    listed_objects = oss2.ObjectIterator(bucket, prefix=key_prefix)
    for entry in listed_objects:
        bucket.delete_object(entry.key)
    # Finally delete the object whose key equals the prefix itself.
    bucket.delete_object(key_prefix)
    logging.info('delete oss path: %s, completed.' % in_prefix)
@@ -0,0 +1,54 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.utils import test_utils
9
+
10
+ if tf.__version__ >= '2.0':
11
+ tf = tf.compat.v1
12
+ gfile = tf.gfile
13
+
14
+
15
class CheckTest(tf.test.TestCase):
    """Train/eval (check mode) and pre-check tests run on sample configs."""

    def setUp(self):
        """Create the scratch directory and reset the success flag."""
        self._test_dir = test_utils.get_tmp_dir()
        self._success = True
        test_name = (type(self).__name__, self._testMethodName)
        logging.info('Testing %s.%s' % test_name)
        logging.info('test dir: %s' % self._test_dir)

    def tearDown(self):
        """Reset the GPU id; remove the scratch directory only on success."""
        test_utils.set_gpu_id(None)
        if not self._success:
            # The directory is left in place when the test did not succeed.
            return
        test_utils.clean_up(self._test_dir)

    def _train_eval_with_check(self, config_path):
        """Run test_single_train_eval with check_mode=True and assert it."""
        self._success = test_utils.test_single_train_eval(
            config_path, self._test_dir, check_mode=True)
        self.assertTrue(self._success)

    def _pre_check(self, config_path):
        """Run test_single_pre_check on a config and assert it passes."""
        self._success = test_utils.test_single_pre_check(
            config_path, self._test_dir)
        self.assertTrue(self._success)

    def test_csv_input_train_with_check(self):
        self._train_eval_with_check(
            'samples/model_config/dbmtl_on_taobao.config')

    def test_rtp_input_train_with_check(self):
        self._train_eval_with_check('samples/model_config/taobao_fg.config')

    def test_csv_input_with_pre_check(self):
        self._pre_check('samples/model_config/dbmtl_on_taobao.config')

    def test_rtp_input_with_pre_check(self):
        # NOTE(review): same config as the csv pre-check above; the rtp
        # train test uses taobao_fg.config - confirm which one is intended.
        self._pre_check('samples/model_config/dbmtl_on_taobao.config')
51
+
52
+
53
+ if __name__ == '__main__':
54
+ tf.test.main()
@@ -0,0 +1,394 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import csv
4
+ import json
5
+ import logging
6
+ import os
7
+ import shutil
8
+
9
+ import numpy as np
10
+ import tensorflow as tf
11
+
12
+ from easy_rec.python.inference.csv_predictor import CSVPredictor
13
+ from easy_rec.python.inference.predictor import Predictor
14
+ from easy_rec.python.utils import config_util
15
+ from easy_rec.python.utils import test_utils
16
+ from easy_rec.python.utils.test_utils import RunAsSubprocess
17
+
18
+
19
class PredictorTest(tf.test.TestCase):
  """Tests for ``Predictor.predict`` with list and dict shaped inputs.

  ``setUp`` fails the test when ``test_utils.get_available_gpus`` returns no
  gpu, so these tests only run on machines with a free gpu.
  """

  # Input names used to build dict inputs; the i-th name is read from csv
  # column i + 2, i.e. the first two columns of each row are skipped.
  _FIELD_KEYS = (
      'pid', 'adgroup_id', 'cate_id', 'campaign_id', 'customer', 'brand',
      'user_id', 'cms_segid', 'cms_group_id', 'final_gender_code',
      'age_level', 'pvalue_level', 'shopping_level', 'occupation',
      'new_user_class_level', 'tag_category_list', 'tag_brand_list', 'price')

  def setUp(self):
    """Select the first available gpu and remember the default input file."""
    self.gpus = test_utils.get_available_gpus()
    self.assertTrue(len(self.gpus) > 0, 'no available gpu on this machine')
    logging.info('available gpus %s' % self.gpus)
    test_utils.set_gpu_id(self.gpus[0])
    logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))
    self._test_path = 'data/test/inference/taobao_infer_data.txt'

  def tearDown(self):
    """Reset the gpu selection made in setUp."""
    test_utils.set_gpu_id(None)

  def _read_csv_rows(self, path):
    """Return all rows of the csv file at ``path`` as lists of strings."""
    with open(path, 'r') as fin:
      return list(csv.reader(fin))

  def _read_feature_dicts(self, path):
    """Return one ``{field name: csv value}`` dict per row of ``path``."""
    return [{
        key: row[idx + 2] for idx, key in enumerate(self._FIELD_KEYS)
    } for row in self._read_csv_rows(path)]

  @RunAsSubprocess
  def test_pred_list(self):
    """Predict with list inputs on the tb_multitower export."""
    predictor = Predictor('data/test/inference/tb_multitower_export/')
    inputs = [row[2:] for row in self._read_csv_rows(self._test_path)]
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 100)

  @RunAsSubprocess
  def test_lookup_pred(self):
    """Predict with list inputs on the lookup export."""
    predictor = Predictor('data/test/inference/lookup_export')
    lookup_pred_path = 'data/test/inference/lookup_data_test80.csv'
    # Only the first column is dropped for this data file.
    inputs = [row[1:] for row in self._read_csv_rows(lookup_pred_path)]
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 80)

  @RunAsSubprocess
  def test_pred_dict(self):
    """Predict with dict inputs on the tb_multitower export."""
    predictor = Predictor('data/test/inference/tb_multitower_export/')
    inputs = self._read_feature_dicts(self._test_path)
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 100)

  @RunAsSubprocess
  def test_pred_placeholder_named_by_input(self):
    """Predict with dict inputs whose tag fields are lists of strings."""
    predictor = Predictor(
        'data/test/inference/tb_multitower_placeholder_rename_export/')
    # These two inputs are fed a fixed list instead of the csv value.
    list_valued = ('tag_category_list', 'tag_brand_list')
    inputs = [{
        key: ['12', '23'] if key in list_valued else row[idx + 2]
        for idx, key in enumerate(self._FIELD_KEYS)
    } for row in self._read_csv_rows(self._test_path)]
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 100)

  @RunAsSubprocess
  def test_fm_pred_list(self):
    """Predict with list inputs on the fm export."""
    predictor = Predictor('data/test/inference/fm_export/')
    inputs = [row[2:] for row in self._read_csv_rows(self._test_path)]
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 100)

  @RunAsSubprocess
  def test_fm_pred_dict(self):
    """Predict with dict inputs on the fm export."""
    predictor = Predictor('data/test/inference/fm_export/')
    inputs = self._read_feature_dicts(self._test_path)
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 100)
123
+
124
+
125
class PredictorTestOnDS(tf.test.TestCase):
  """Tests for ``CSVPredictor.predict_impl`` writing ``part-0.csv`` locally."""

  # Header expected from the tb_multitower export when both reserved_cols and
  # output_cols are 'ALL_COLUMNS' and the output separator is ';'.
  _ALL_COLUMNS_HEADER = (
      'logits;probs;clk;buy;pid;adgroup_id;cate_id;campaign_id;customer;'
      'brand;user_id;cms_segid;cms_group_id;final_gender_code;age_level;pvalue_level;'
      'shopping_level;occupation;new_user_class_level;tag_category_list;tag_brand_list;price'
  )

  def setUp(self):
    """Create the per-test directory and log which test is starting."""
    self._test_dir = test_utils.get_tmp_dir()
    self._test_output_path = None
    logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))

  def tearDown(self):
    """Remove the prediction output directory, if any, and reset the gpu id."""
    # NOTE(review): if RunAsSubprocess runs the test body in a child process,
    # the path assigned there is not visible here and nothing is removed --
    # confirm against test_utils.
    if self._test_output_path and (os.path.exists(self._test_output_path)):
      shutil.rmtree(self._test_output_path)
    test_utils.set_gpu_id(None)

  def _load_pipeline_config(self, saved_model_dir):
    """Load ``assets/pipeline.config`` stored inside ``saved_model_dir``."""
    pipeline_config_path = os.path.join(saved_model_dir,
                                        'assets/pipeline.config')
    return config_util.get_configs_from_pipeline_file(pipeline_config_path,
                                                      False)

  def _predict_to_output(self, predictor, input_path, reserved_cols,
                         output_cols):
    """Run ``predictor`` on ``input_path`` as the only slice (0 of 1)."""
    predictor.predict_impl(
        input_path,
        self._test_output_path,
        reserved_cols=reserved_cols,
        output_cols=output_cols,
        slice_id=0,
        slice_num=1)

  def _read_output_lines(self):
    """Return the lines of ``part-0.csv`` under the test output directory."""
    with open(os.path.join(self._test_output_path, 'part-0.csv'), 'r') as fin:
      return fin.readlines()

  def _assert_output(self, expected_header=None):
    """Check that the output has 101 lines and, optionally, its header."""
    output_res = self._read_output_lines()
    self.assertEqual(len(output_res), 101)
    if expected_header is not None:
      self.assertEqual(output_res[0].strip(), expected_header)

  @RunAsSubprocess
  def test_local_pred(self):
    """Predict a csv file without header, keeping all columns."""
    test_input_path = 'data/test/inference/taobao_infer_data.txt'
    self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
    saved_model_dir = 'data/test/inference/tb_multitower_export/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)
    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        output_sep=';',
        selected_cols='')

    self._predict_to_output(predictor, test_input_path, 'ALL_COLUMNS',
                            'ALL_COLUMNS')
    self._assert_output(self._ALL_COLUMNS_HEADER)

  @RunAsSubprocess
  def test_local_pred_with_header(self):
    """Predict a csv file that starts with a header line."""
    test_input_path = 'data/test/inference/taobao_infer_data_with_header.txt'
    self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
    saved_model_dir = 'data/test/inference/tb_multitower_export/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)
    pipeline_config.data_config.with_header = True

    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        with_header=True,
        output_sep=';',
        selected_cols='')

    self._predict_to_output(predictor, test_input_path, 'ALL_COLUMNS',
                            'ALL_COLUMNS')
    self._assert_output(self._ALL_COLUMNS_HEADER)

  @RunAsSubprocess
  def test_local_pred_without_config(self):
    """Predict through ``test_utils.test_single_predict``."""
    test_input_path = 'data/test/inference/taobao_infer_data.txt'
    self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
    saved_model_dir = 'data/test/inference/tb_multitower_export/'
    self._success = test_utils.test_single_predict(self._test_dir,
                                                   test_input_path,
                                                   self._test_output_path,
                                                   saved_model_dir)
    self.assertTrue(self._success)
    self._assert_output()

  @RunAsSubprocess
  def test_local_pred_with_part_col(self):
    """Predict while keeping only some reserved and output columns."""
    test_input_path = 'data/test/inference/taobao_infer_data.txt'
    self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result')
    saved_model_dir = 'data/test/inference/tb_multitower_export/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)

    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        output_sep=';',
        selected_cols='')

    self._predict_to_output(predictor, test_input_path,
                            'clk,buy,user_id,adgroup_id', 'probs')
    self._assert_output('probs;clk;buy;user_id;adgroup_id')

  @RunAsSubprocess
  def test_local_pred_rtp(self):
    """Predict on the rtp export, keeping all columns."""
    test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt'
    self._test_output_path = os.path.join(self._test_dir,
                                          'taobao_test_feature_result')
    saved_model_dir = 'data/test/inference/tb_multitower_rtp_export/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)

    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        output_sep=';',
        selected_cols='0,3')
    self._predict_to_output(predictor, test_input_path, 'ALL_COLUMNS',
                            'ALL_COLUMNS')
    self._assert_output('logits;probs;clk;no_used_1;no_used_2;features')

  @RunAsSubprocess
  def test_local_pred_rtp_with_part_col(self):
    """Predict on the rtp export, keeping only some reserved columns."""
    test_input_path = 'data/test/inference/taobao_infer_rtp_data.txt'
    self._test_output_path = os.path.join(self._test_dir,
                                          'taobao_test_feature_result')
    saved_model_dir = 'data/test/inference/tb_multitower_rtp_export/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)

    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        output_sep=';',
        selected_cols='0,3')
    self._predict_to_output(predictor, test_input_path,
                            'clk,features,no_used_1', 'ALL_COLUMNS')
    self._assert_output('logits;probs;clk;features;no_used_1')

  @RunAsSubprocess
  def test_local_pred_embedding(self):
    """Export item embeddings and compare the first data line exactly."""
    test_input_path = 'data/test/inference/taobao_item_feature_data.csv'
    self._test_output_path = os.path.join(self._test_dir, 'taobao_item_feature')
    saved_model_dir = 'data/test/inference/dssm_item_model/'
    pipeline_config = self._load_pipeline_config(saved_model_dir)
    predictor = CSVPredictor(
        saved_model_dir,
        pipeline_config.data_config,
        ds_vector_recall=True,
        output_sep=';',
        selected_cols='pid,adgroup_id,cate_id,campaign_id,customer,brand,price')

    self._predict_to_output(predictor, test_input_path, 'adgroup_id',
                            'item_emb')

    output_res = self._read_output_lines()
    # Index 1 is the line right after the first line of the output file.
    self.assertEqual(
        output_res[1],
        '-0.187066,-0.027638,-0.117294,0.115318,-0.273561,0.035698,-0.055832,'
        '0.226849,-0.105808,-0.152751,0.081528,-0.183329,0.134619,0.185392,'
        '0.096774,0.104428,0.161868,0.269710,-0.268538,0.138760,-0.170105,'
        '0.232625,-0.121130,0.198466,-0.078941,0.017774,0.268834,-0.238553,0.084058,'
        '-0.269466,-0.289651,0.179517;620392\n')
338
+
339
+
340
class PredictorTestV2(tf.test.TestCase):
  """Tests comparing ``Predictor`` probs on fg exports with stored values.

  ``setUp`` fails the test when ``test_utils.get_available_gpus`` returns no
  gpu, so these tests only run on machines with a free gpu.
  """

  # Input rows; the feature string is the last ';' separated token per line.
  _FEATURE_PATH = 'data/test/rtp/taobao_test_feature.txt'
  # One json object per line; its 'probs' entry is the expected value.
  _EXPECTED_PATH = 'data/test/rtp/taobao_fg_pred.out'

  def setUp(self):
    """Select the first available gpu and log which test is starting."""
    self.gpus = test_utils.get_available_gpus()
    self.assertTrue(len(self.gpus) > 0, 'no available gpu on this machine')
    logging.info('available gpus %s' % self.gpus)
    test_utils.set_gpu_id(self.gpus[0])
    logging.info('Testing %s.%s' % (type(self).__name__, self._testMethodName))

  def tearDown(self):
    """Reset the gpu selection made in setUp."""
    test_utils.set_gpu_id(None)

  def _read_features(self, split_fields):
    """Read the feature token of every line of the feature file.

    Args:
      split_fields: if True, split each feature token on the '\\002'
        character into a list; otherwise keep it as one string.

    Returns:
      list with one entry per input line.
    """
    inputs = []
    with open(self._FEATURE_PATH, 'r') as fin:
      for line_str in fin:
        feature = line_str.strip().split(';')[-1]
        if split_fields:
          feature = feature.split('\002')
        inputs.append(feature)
    return inputs

  def _assert_probs_close(self, output_res, tolerance):
    """Check ``output_res[i]['probs']`` against line i of the expected file."""
    with open(self._EXPECTED_PATH, 'r') as fin:
      for line_id, line_str in enumerate(fin):
        line_pred = json.loads(line_str.strip())
        diff = np.abs(line_pred['probs'] - output_res[line_id]['probs'])
        self.assertLess(diff, tolerance,
                        'probs mismatch at line %d' % line_id)

  @RunAsSubprocess
  def test_pred_multi(self):
    """Predict on the multi-placeholder export with split feature lists."""
    predictor = Predictor('data/test/inference/fg_export_multi')
    inputs = self._read_features(split_fields=True)
    output_res = predictor.predict(inputs, batch_size=32)
    self.assertEqual(len(output_res), 10000)
    self._assert_probs_close(output_res, 5e-6)

  @RunAsSubprocess
  def test_pred_single(self):
    """Predict on the single-placeholder export with raw feature strings."""
    predictor = Predictor('data/test/inference/fg_export_single')
    inputs = self._read_features(split_fields=False)
    output_res = predictor.predict(inputs, batch_size=32)
    self._assert_probs_close(output_res, 5e-5)
391
+
392
+
393
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()