easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,223 @@
1
+ # Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =============================================================================
15
+ # """Evaluation of Top k hitrate."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import json
21
+ import logging
22
+ import os
23
+ import sys
24
+
25
+ import graphlearn as gl
26
+ import tensorflow as tf
27
+
28
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
29
+ from easy_rec.python.utils import config_util
30
+ from easy_rec.python.utils import io_util
31
+ from easy_rec.python.utils.config_util import process_multi_file_input_path
32
+ from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
33
+ from easy_rec.python.utils.hit_rate_utils import load_graph
34
+ from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
35
+ from easy_rec.python.utils.hive_utils import HiveUtils
36
+
37
if tf.__version__ >= '2.0':
  # The script uses the TF1 API surface (tf.app.flags, tf.gfile, tf.train.*).
  tf = tf.compat.v1

from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds  # NOQA

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# Input / output locations.
tf.app.flags.DEFINE_string('item_emb_table', '', 'item embedding table name')
tf.app.flags.DEFINE_string('gt_table', '', 'ground truth table name')
tf.app.flags.DEFINE_string('hitrate_details_result', '',
                           'hitrate detail file path')
tf.app.flags.DEFINE_string('total_hitrate_result', '',
                           'total hitrate result file path')

# Evaluation settings. Integer flags use int literals as defaults, consistent
# with batch_size / emb_dim (they previously used strings such as '5').
tf.app.flags.DEFINE_string('pipeline_config_path', '', 'pipeline config path')
tf.app.flags.DEFINE_integer('batch_size', 512, 'batch size')
tf.app.flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
# NOTE(review): recall_type and is_on_ds are not read by main() in this file.
tf.app.flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
tf.app.flags.DEFINE_integer('top_k', 5, 'top_k hitrate.')
tf.app.flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
tf.app.flags.DEFINE_bool('knn_strict', False, 'use exact search.')
tf.app.flags.DEFINE_integer('timeout', 60, 'timeout')
tf.app.flags.DEFINE_integer('num_interests', 1, 'max number of interests')
tf.app.flags.DEFINE_string('gt_table_field_sep', '\t', 'gt_table_field_sep')
tf.app.flags.DEFINE_string('item_emb_table_field_sep', '\t',
                           'item_emb_table_field_sep')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')

FLAGS = tf.app.flags.FLAGS
68
+
69
+
70
def compute_hitrate(g, gt_all, hitrate_writer, gt_table=None):
  """Compute hitrate of each worker.

  Args:
    g: a GL Graph instance.
    gt_all: iterable of ground-truth batches. Each batch is an iterable of
      records passed to `compute_hitrate_batch` (for example the batches
      produced by `gt_hdfs` or `HiveUtils.hive_read_lines`).
    hitrate_writer: writable file object for per-record details. One
      tab-separated, newline-terminated line is written per source id:
      src_id, topk recalls, topk dists, hitrate, bad cases, bad dists.
    gt_table: ground truth table; not used, kept for backward compatibility.

  Returns:
    total_hits: total hits of this worker.
    total_gt_count: total count of ground truth items of this worker.
  """
  total_hits = 0.0
  total_gt_count = 0.0

  for gt_record in gt_all:
    gt_record = list(gt_record)
    (hits, gt_count, src_ids, recall_ids, recall_distances, hitrates,
     bad_cases, bad_dists) = compute_hitrate_batch(g, gt_record, FLAGS.emb_dim,
                                                   FLAGS.num_interests,
                                                   FLAGS.top_k)
    total_hits += hits
    total_gt_count += gt_count

    src_ids = [str(ids) for ids in src_ids]
    hitrates = [str(hitrate) for hitrate in hitrates]
    topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
    # Each recall_distances entry is a sequence of sequences: the inner ones
    # are joined with '|', the outer level with ','.
    topk_dists = [
        ','.join('|'.join(str(x) for x in dist) for dist in dists)
        for dists in recall_distances
    ]
    bad_cases = [','.join(str(x) for x in bad_case) for bad_case in bad_cases]
    bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]

    # Terminate every line with '\n'. A plain '\n'.join(...) per batch left
    # no separator between consecutive batches, gluing the last line of one
    # batch onto the first line of the next.
    hitrate_writer.write(''.join(
        '\t'.join(line) + '\n'
        for line in zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases,
                        bad_dists)))
  print('total_hits: ', total_hits)
  print('total_gt_count: ', total_gt_count)
  return total_hits, total_gt_count
112
+
113
+
114
def gt_hdfs(gt_table, batch_size, gt_file_sep):
  """Read ground-truth records from text files in batches.

  Args:
    gt_table: a file path, a directory (every file inside it is read), or a
      comma-separated list of paths / glob patterns.
    batch_size: maximum number of records per yielded batch.
    gt_file_sep: field separator used inside each line.

  Yields:
    Lists of at most `batch_size` tuples, one tuple per non-blank line, with
    field 0 (id) and field 3 (emb_num) converted to int.
  """
  if '*' in gt_table or ',' in gt_table:
    file_paths = tf.gfile.Glob(gt_table.split(','))
  elif tf.gfile.IsDirectory(gt_table):
    file_paths = tf.gfile.Glob(os.path.join(gt_table, '*'))
  else:
    file_paths = tf.gfile.Glob(gt_table)

  batch_list = []
  for file_path in file_paths:
    with tf.gfile.GFile(file_path, 'r') as fin:
      for gt in fin:
        # NOTE(review): strip() also removes whitespace separators such as the
        # default '\t', so an empty first/last column would shift fields.
        line = gt.strip()
        if not line:
          # Blank lines used to crash with IndexError on gt_list[3].
          continue
        gt_list = line.split(gt_file_sep)
        # make id, emb_num to int
        gt_list[0], gt_list[3] = int(gt_list[0]), int(gt_list[3])
        batch_list.append(tuple(gt_list))
        if len(batch_list) >= batch_size:
          yield batch_list
          batch_list = []
  if batch_list:
    yield batch_list
137
+
138
+
139
def main():
  """Compute top-k hitrate for this task of a TF_CONFIG-described cluster.

  PS tasks only run the TF server. Each worker task:
  - initializes the GL graph;
  - writes per-record details to `<hitrate_details_result>/part-<task_index>`;
  - writes the global hitrate to `total_hitrate_result` once the reduced
    worker count equals the number of workers.
  """
  tf_config = json.loads(os.environ['TF_CONFIG'])
  worker_count = len(tf_config['cluster']['worker'])
  task_index = tf_config['task']['index']
  job_name = tf_config['task']['type']

  hitrate_details_result = FLAGS.hitrate_details_result
  total_hitrate_result = FLAGS.total_hitrate_result
  i_emb_table = FLAGS.item_emb_table
  gt_table = FLAGS.gt_table

  pipeline_config = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  logging.info('i_emb_table %s', i_emb_table)

  input_type = pipeline_config.data_config.input_type
  input_type_name = DatasetConfig.InputType.Name(input_type)
  if input_type_name == 'CSVInput':
    i_emb_table = process_multi_file_input_path(i_emb_table)
  else:
    hive_utils = HiveUtils(
        data_config=pipeline_config.data_config,
        hive_config=pipeline_config.hive_train_input)
    i_emb_table = hive_utils.get_table_location(i_emb_table)

  g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
                 FLAGS.knn_strict)
  gl.set_tracker_mode(0)
  gl.set_field_delimiter(FLAGS.item_emb_table_field_sep)

  cluster = tf.train.ClusterSpec({
      'ps': tf_config['cluster']['ps'],
      'worker': tf_config['cluster']['worker']
  })
  server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

  if job_name == 'ps':
    server.join()
    return

  # GL hosts reuse the worker hostnames with port 8880 + i. The previous
  # ':888' + str(i) is identical for i < 10 but yields invalid ports
  # (88810, 88811, ... > 65535) once there are 10 or more workers.
  worker_hosts = ','.join(
      '%s:%d' % (host.split(':')[0], 8880 + i)
      for i, host in enumerate(tf_config['cluster']['worker']))
  g.init(task_index=task_index, task_count=worker_count, hosts=worker_hosts)

  if input_type_name == 'CSVInput':
    gt_all = gt_hdfs(gt_table, FLAGS.batch_size, FLAGS.gt_table_field_sep)
  else:
    gt_reader = HiveUtils(
        data_config=pipeline_config.data_config,
        hive_config=pipeline_config.hive_train_input,
        selected_cols='*')
    gt_all = gt_reader.hive_read_lines(gt_table, FLAGS.batch_size)
  if not tf.gfile.IsDirectory(hitrate_details_result):
    tf.gfile.MakeDirs(hitrate_details_result)
  hitrate_details_result = os.path.join(hitrate_details_result,
                                        'part-%s' % task_index)

  # The with-block closes the details file even if anything below raises.
  with tf.gfile.GFile(hitrate_details_result, 'w') as details_writer:
    print('Start compute hitrate...')
    total_hits, total_gt_count = compute_hitrate(g, gt_all, details_writer,
                                                 gt_table)
    var_total_hitrate, var_worker_count = reduce_hitrate(
        cluster, total_hits, total_gt_count, task_index)

    with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(task_index == 0)) as sess:
      outs = sess.run([var_total_hitrate, var_worker_count])

    # write after all workers have completed the calculation of hitrate.
    print('outs: ', outs)
    if outs[1] == worker_count:
      logging.info(outs)
      with tf.gfile.GFile(total_hitrate_result, 'w') as total_writer:
        total_writer.write(str(outs[0]))

  g.close()
  print('Compute hitrate done.')


if __name__ == '__main__':
  # filter_unknown_args presumably drops argv entries that FLAGS does not
  # define, so extra launcher arguments do not break flag parsing.
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  main()
@@ -0,0 +1,138 @@
1
+ # Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =============================================================================
15
+ """Evaluation of Top k hitrate."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import sys
21
+
22
+ import tensorflow as tf
23
+
24
+ from easy_rec.python.utils import io_util
25
+ from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
26
+ from easy_rec.python.utils.hit_rate_utils import load_graph
27
+ from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
28
+
29
flags = tf.app.flags
FLAGS = flags.FLAGS
# Distributed job layout.
flags.DEFINE_integer('task_index', None, 'Task index')
flags.DEFINE_integer('task_count', None, 'Task count')
flags.DEFINE_string('job_name', None, 'worker or ps or aligraph')
flags.DEFINE_string('ps_hosts', '', 'ps hosts')
flags.DEFINE_string('worker_hosts', '', 'worker hosts')
# Comma-separated input / output table names.
flags.DEFINE_string('tables', '', 'input odps tables name')
flags.DEFINE_string('outputs', '', 'output odps tables name')
# Evaluation settings; integer flags use int literals as defaults.
flags.DEFINE_integer('batch_size', 512, 'batch size')
flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
flags.DEFINE_integer('top_k', 5, 'top_k hitrate.')
flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
flags.DEFINE_bool('knn_strict', False, 'use exact search.')
flags.DEFINE_integer('timeout', 60, 'timeout')
flags.DEFINE_integer('num_interests', 1, 'max number of interests')
46
+
47
+
48
def compute_hitrate(g, gt_reader, hitrate_writer):
  """Compute hitrate of each worker.

  Reads `gt_reader` in batches of FLAGS.batch_size until `read` raises
  tf.python_io.OutOfRangeException. For every source id it writes one row to
  `hitrate_writer` with columns
  (src_id, topk recalls, topk dists, hitrate, bad cases, bad dists).

  Args:
    g: a GL Graph instance.
    gt_reader: odps reader of input trigger_items_table.
    hitrate_writer: odps writer of hitrate table.

  Returns:
    total_hits: total hits of this worker.
    total_gt_count: total count of ground truth items of this worker.
  """
  total_hits = 0.0
  total_gt_count = 0.0
  while True:
    # Only the read signals end of input. Keep the rest outside the try so an
    # unrelated OutOfRangeException is not silently mistaken for EOF.
    try:
      gt_record = gt_reader.read(FLAGS.batch_size)
    except tf.python_io.OutOfRangeException:
      break

    (hits, gt_count, src_ids, recall_ids, recall_distances, hitrates,
     bad_cases, bad_dists) = compute_hitrate_batch(g, gt_record, FLAGS.emb_dim,
                                                   FLAGS.num_interests,
                                                   FLAGS.top_k)
    total_hits += hits
    total_gt_count += gt_count
    topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
    # NOTE(review): the CSV/Hive variant '|'-joins nested distance lists; here
    # each element is str()'d directly. Confirm the format for
    # num_interests > 1.
    topk_dists = [
        ','.join(str(x) for x in dists) for dists in recall_distances
    ]
    bad_cases = [','.join(str(x) for x in case) for case in bad_cases]
    bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]

    hitrate_writer.write(
        list(
            zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases,
                bad_dists)),
        indices=[0, 1, 2, 3, 4, 5])
  return total_hits, total_gt_count
84
+
85
+
86
def main():
  """Compute top-k hitrate on PAI from ODPS tables.

  Table flags:
  - FLAGS.tables: comma-separated input tables. For recall_type 'u2i' it must
    hold exactly two tables; otherwise the last two are used. Either way they
    are (item embedding table, ground-truth table).
  - FLAGS.outputs: the details table and the total-hitrate table.

  PS tasks only run the TF server. Workers compute and reduce the hitrate.
  """
  worker_count = len(FLAGS.worker_hosts.split(','))
  input_tables = FLAGS.tables.split(',')
  if FLAGS.recall_type == 'u2i':
    # Unpacking raises ValueError unless exactly two tables are given.
    i_emb_table, gt_table = input_tables
  else:
    i_emb_table, gt_table = input_tables[-2], input_tables[-1]
  # Both branches previously called load_graph with identical arguments.
  g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
                 FLAGS.knn_strict)
  hitrate_details_table, total_hitrate_table = FLAGS.outputs.split(',')

  cluster = tf.train.ClusterSpec({
      'ps': FLAGS.ps_hosts.split(','),
      'worker': FLAGS.worker_hosts.split(',')
  })
  server = tf.train.Server(
      cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
  if FLAGS.job_name == 'ps':
    server.join()
    return

  g.init(task_index=FLAGS.task_index, task_count=worker_count)
  gt_reader = tf.python_io.TableReader(
      gt_table,
      slice_id=FLAGS.task_index,
      slice_count=worker_count,
      capacity=2048)
  details_writer = tf.python_io.TableWriter(
      hitrate_details_table, slice_id=FLAGS.task_index)
  # Release the reader/writer even if computation or reduction raises.
  try:
    print('Start compute hitrate...')
    total_hits, total_gt_count = compute_hitrate(g, gt_reader, details_writer)
    var_total_hitrate, var_worker_count = reduce_hitrate(
        cluster, total_hits, total_gt_count, FLAGS.task_index)

    with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(FLAGS.task_index == 0)) as sess:
      outs = sess.run([var_total_hitrate, var_worker_count])

    # write after all workers have completed the calculation of hitrate.
    if outs[1] == worker_count:
      with tf.python_io.TableWriter(total_hitrate_table) as total_writer:
        total_writer.write([outs[0]], indices=[0])
  finally:
    gt_reader.close()
    details_writer.close()
  g.close()
  print('Compute hitrate done.')


if __name__ == '__main__':
  # filter_unknown_args presumably drops argv entries that FLAGS does not
  # define, so extra launcher arguments do not break flag parsing.
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  main()
@@ -0,0 +1,120 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+ import logging
5
+ import os
6
+ import sys
7
+
8
+ import tensorflow as tf
9
+
10
+ from easy_rec.python.input.input import Input
11
+ from easy_rec.python.utils import config_util
12
+ from easy_rec.python.utils import fg_util
13
+ from easy_rec.python.utils import io_util
14
+ from easy_rec.python.utils.check_utils import check_env_and_input_path
15
+ from easy_rec.python.utils.check_utils import check_sequence
16
+
17
if tf.__version__ >= '2.0':
  # Use the TF1-compatible API surface (tf.app, tf.Session, ...) under TF2.
  tf = tf.compat.v1

logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s')

# Command line flags of the data check tool.
tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config file.')
tf.app.flags.DEFINE_multi_string(
    'data_input_path', None, help='data input path')

FLAGS = tf.app.flags.FLAGS
30
+
31
+
32
def _get_input_fn(data_config,
                  feature_configs,
                  data_path=None,
                  export_config=None):
  """Create the input_fn used to read data in check mode.

  Args:
    data_config: dataset config; its `input_type` selects the Input subclass.
    feature_configs: feature configs handed to the Input subclass.
    data_path: input data path handed to the Input subclass.
    export_config: optional export config forwarded to `create_input`; only
      used to build input_fn when exporting models.

  Returns:
    the input_fn built by the selected Input subclass.
  """
  type_to_name = dict(
      (type_id, name) for name, type_id in data_config.InputType.items())
  input_class = Input.create_class(type_to_name[data_config.input_type])

  # Shard the data by worker when running under a TF_CONFIG cluster spec.
  tf_config_json = os.environ.get('TF_CONFIG')
  if tf_config_json is None:
    worker_num, task_index = 1, 0
  else:
    tf_config = json.loads(tf_config_json)
    worker_num = len(tf_config['cluster']['worker'])
    task_index = tf_config['task']['index']

  return input_class(
      data_config,
      feature_configs,
      data_path,
      task_index=task_index,
      task_num=worker_num,
      check_mode=True).create_input(export_config)
69
+
70
+
71
def loda_pipeline_config(pipeline_config_path):
  """Read the pipeline config and prepare its feature configs for checking.

  If `fg_json_path` is set, the fg json is loaded into the config through
  fg_util.load_fg_json_to_config; shared feature configs are then expanded.

  Args:
    pipeline_config_path: path to the pipeline config file.

  Returns:
    the loaded (and expanded) pipeline config.
  """
  config = config_util.get_configs_from_pipeline_file(pipeline_config_path,
                                                       False)
  has_fg_json = bool(config.fg_json_path)
  if has_fg_json:
    fg_util.load_fg_json_to_config(config)
  config_util.auto_expand_share_feature_configs(config)
  return config
78
+
79
+
80
def run_check(pipeline_config, input_path):
  """Read all input data once and run check_sequence on every batch.

  Args:
    pipeline_config: the loaded pipeline config (as returned by
      loda_pipeline_config).
    input_path: comma separated input path(s) to check.
  """
  logging.info('data_input_path: %s', input_path)
  check_env_and_input_path(pipeline_config, input_path)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  eval_input_fn = _get_input_fn(pipeline_config.data_config, feature_configs,
                                input_path)
  input_iter = eval_input_fn(
      mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
  # Build the get_next op exactly once: calling get_next() inside the loop
  # adds new ops to the graph on every iteration, growing it without bound.
  input_feas, _ = input_iter.get_next()
  with tf.Session() as sess:
    try:
      while True:
        features = sess.run(input_feas)
        check_sequence(pipeline_config, features)
    except tf.errors.OutOfRangeError:
      # Raised once the one-shot iterator has consumed all input data.
      logging.info('pre-check finish...')
102
+
103
+
104
+ def main(argv):
105
+ assert FLAGS.pipeline_config_path, 'pipeline_config_path should not be empty when checking!'
106
+ pipeline_config = loda_pipeline_config(FLAGS.pipeline_config_path)
107
+
108
+ if FLAGS.data_input_path:
109
+ input_path = ','.join(FLAGS.data_input_path)
110
+ else:
111
+ assert pipeline_config.train_input_path or pipeline_config.eval_input_path, \
112
+ 'input_path should not be empty when checking!'
113
+ input_path = pipeline_config.train_input_path + ',' + pipeline_config.eval_input_path
114
+
115
+ run_check(pipeline_config, input_path)
116
+
117
+
118
if __name__ == '__main__':
  # Filter sys.argv through io_util.filter_unknown_args (presumably dropping
  # arguments not declared in FLAGS) before tf.app.run() parses it.
  known_argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = known_argv
  tf.app.run()
@@ -0,0 +1,111 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import json
5
+ import logging
6
+ import os
7
+ import sys
8
+
9
+ import numpy as np
10
+
11
+ import easy_rec
12
+ from easy_rec.python.inference.predictor import Predictor
13
+
14
# Configure logging before anything logs. A module-level logging.warning() on
# an unconfigured root logger implicitly calls basicConfig() with defaults
# (WARNING level). That made a later basicConfig(level=INFO) a no-op and
# silenced every subsequent logging.info().
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

try:
  import tensorflow as tf
  # Load the custom op library shipped with easy_rec (presumably needed by
  # saved models that use it).
  tf.load_op_library(os.path.join(easy_rec.ops_dir, 'libembed_op.so'))
except Exception as ex:
  # Best effort: a failure here is only reported, not fatal.
  logging.warning('exception: %s', ex)
22
+
23
+ if __name__ == '__main__':
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument(
26
+ '--saved_model_dir', type=str, default=None, help='saved model directory')
27
+ parser.add_argument(
28
+ '--input_path', type=str, default=None, help='input feature path')
29
+ parser.add_argument('--save_path', type=str, default=None, help='save path')
30
+ parser.add_argument(
31
+ '--cmp_res_path', type=str, default=None, help='compare result path')
32
+ parser.add_argument(
33
+ '--cmp_key', type=str, default='probs', help='compare key')
34
+ parser.add_argument(
35
+ '--rtp_fea_id',
36
+ type=int,
37
+ default=-1,
38
+ help='rtp feature column index, default to the last column')
39
+ parser.add_argument('--tol', type=float, default=1e-5, help='tolerance')
40
+ parser.add_argument(
41
+ '--label_id',
42
+ nargs='*',
43
+ type=int,
44
+ help='the label column, which is to be excluded')
45
+ parser.add_argument(
46
+ '--separator',
47
+ type=str,
48
+ default='',
49
+ help='separator between features, default to \\u0002')
50
+ parser.add_argument(
51
+ '--rtp_separator',
52
+ type=str,
53
+ default='',
54
+ help='separator, default to \\u0001')
55
+ args = parser.parse_args()
56
+
57
+ if not args.saved_model_dir:
58
+ logging.error('saved_model_dir is not set')
59
+ sys.exit(1)
60
+
61
+ if not args.input_path:
62
+ logging.error('input_path is not set')
63
+ sys.exit(1)
64
+
65
+ if args.label_id is None:
66
+ args.label_id = []
67
+
68
+ logging.info('input_path: ' + args.input_path)
69
+ logging.info('save_path: ' + args.save_path)
70
+ logging.info('separator: ' + args.separator)
71
+
72
+ predictor = Predictor(args.saved_model_dir)
73
+ if len(predictor.input_names) == 1:
74
+ assert len(
75
+ args.label_id
76
+ ) == 0, 'label_id should not be set if rtp feature format is used.'
77
+
78
+ with open(args.input_path, 'r') as fin:
79
+ batch_input = []
80
+ for line_str in fin:
81
+ line_str = line_str.strip()
82
+ line_tok = line_str.split(args.rtp_separator)
83
+ feature = line_tok[args.rtp_fea_id]
84
+ feature = [
85
+ x for fid, x in enumerate(feature.split(args.separator))
86
+ if fid not in args.label_id
87
+ ]
88
+ if 'features' in predictor.input_names:
89
+ feature = args.separator.join(feature)
90
+ batch_input.append(feature)
91
+ output = predictor.predict(batch_input)
92
+
93
+ if args.save_path:
94
+ fout = open(args.save_path, 'w')
95
+ for one in output:
96
+ fout.write(str(one) + '\n')
97
+ fout.close()
98
+
99
+ if args.cmp_res_path:
100
+ logging.info('compare result path: ' + args.cmp_res_path)
101
+ logging.info('compare key: ' + args.cmp_key)
102
+ logging.info('tolerance: ' + str(args.tol))
103
+ with open(args.cmp_res_path, 'r') as fin:
104
+ for line_id, line_str in enumerate(fin):
105
+ line_str = line_str.strip()
106
+ line_pred = json.loads(line_str)
107
+ assert np.abs(
108
+ line_pred[args.cmp_key] -
109
+ output[line_id][args.cmp_key]) < args.tol, 'line[%d]: %.8f' % (
110
+ line_id,
111
+ np.abs(line_pred[args.cmp_key] - output[line_id][args.cmp_key]))
@@ -0,0 +1,55 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import sys
7
+
8
+ from kafka import KafkaConsumer
9
+ from kafka.structs import TopicPartition
10
+
11
+ logging.basicConfig(
12
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
13
+
14
+ if __name__ == '__main__':
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--servers', type=str, default='localhost:9092')
17
+ parser.add_argument('--topic', type=str, default=None)
18
+ parser.add_argument('--group', type=str, default='consumer')
19
+ parser.add_argument('--partitions', type=str, default=None)
20
+ parser.add_argument('--timeout', type=float, default=float('inf'))
21
+ parser.add_argument('--save_dir', type=str, default=None)
22
+ args = parser.parse_args()
23
+
24
+ if args.topic is None:
25
+ logging.error('--topic is not set')
26
+ sys.exit(1)
27
+
28
+ servers = args.servers.split(',')
29
+ consumer = KafkaConsumer(
30
+ group_id=args.group,
31
+ bootstrap_servers=servers,
32
+ consumer_timeout_ms=args.timeout * 1000)
33
+
34
+ if args.partitions is not None:
35
+ partitions = [int(x) for x in args.partitions.split(',')]
36
+ else:
37
+ partitions = consumer.partitions_for_topic(args.topic)
38
+ logging.info('partitions: %s' % partitions)
39
+
40
+ topics = [
41
+ TopicPartition(topic=args.topic, partition=part_id)
42
+ for part_id in partitions
43
+ ]
44
+ consumer.assign(topics)
45
+ consumer.seek_to_beginning()
46
+
47
+ record_id = 0
48
+ for x in consumer:
49
+ logging.info('%d: key=%s\toffset=%d\ttimestamp=%d\tlen=%d' %
50
+ (record_id, x.key, x.offset, x.timestamp, len(x.value)))
51
+ if args.save_dir is not None:
52
+ save_path = os.path.join(args.save_dir, x.key)
53
+ with open(save_path, 'wb') as fout:
54
+ fout.write(x.value)
55
+ record_id += 1