easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,170 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import ctypes
5
+ import glob
6
+ import json
7
+ import logging
8
+ import os
9
+ import subprocess
10
+ import time
11
+
12
+ import numpy as np
13
+ from google.protobuf import text_format
14
+
15
+ from easy_rec.python.protos import dataset_pb2
16
+ from easy_rec.python.protos import pipeline_pb2
17
+ from easy_rec.python.protos import tf_predict_pb2
18
+
19
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

# Prebuilt LaRec processor tarball. The script unpacks it under ./processor
# and loads PROCESSOR_ENTRY_LIB from it through ctypes.
PROCESSOR_VERSION = 'LaRec-0.9.5d-b1b1604-TF-2.5.0-Linux'
PROCESSOR_FILE = PROCESSOR_VERSION + '.tar.gz'
# HTTPS: the archive ships a native library that gets loaded into this
# process, so it must not be fetched over plain HTTP where it could be
# tampered with in transit.
PROCESSOR_URL = 'https://easyrec.oss-cn-beijing.aliyuncs.com/processor/' + PROCESSOR_FILE
PROCESSOR_ENTRY_LIB = 'processor/' + PROCESSOR_VERSION + '/larec/libtf_predictor.so'
26
+
27
+
28
def build_array_proto(array_proto, data, dtype):
  """Fill a request input message with a 1-D column of values.

  Args:
    array_proto: message to populate in place (the caller passes an entry of
      ``PredictRequest.inputs``).
    data: sequence of raw values for one input field, one per sample.
    dtype: a ``dataset_pb2.DatasetConfig`` field type (STRING, FLOAT, DOUBLE,
      INT32 or INT64).

  Returns:
    The same ``array_proto``.

  Raises:
    ValueError: if ``dtype`` is not one of the supported field types. (A plain
      ``assert`` would be stripped under ``python -O`` and silently produce a
      message with a shape but no values.)
  """
  cfg = dataset_pb2.DatasetConfig
  # field type -> (repeated value field, per-value converter, tf dtype)
  converters = {
      cfg.STRING: ('string_val', lambda x: x.encode('utf-8'),
                   tf_predict_pb2.DT_STRING),
      cfg.FLOAT: ('float_val', float, tf_predict_pb2.DT_FLOAT),
      cfg.DOUBLE: ('double_val', float, tf_predict_pb2.DT_DOUBLE),
      cfg.INT32: ('int_val', int, tf_predict_pb2.DT_INT32),
      cfg.INT64: ('int64_val', np.int64, tf_predict_pb2.DT_INT64),
  }
  if dtype not in converters:
    raise ValueError('invalid datatype[%s]' % str(dtype))
  value_field, convert, tf_dtype = converters[dtype]

  array_proto.array_shape.dim.append(len(data))
  getattr(array_proto, value_field).extend([convert(x) for x in data])
  array_proto.dtype = tf_dtype
  return array_proto
49
+
50
+
51
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input_path', type=str, default=None, help='input data path')
  parser.add_argument(
      '--output_path', type=str, default=None, help='output data path')
  parser.add_argument(
      '--libc_path',
      type=str,
      default='/lib64/libc.so.6',
      help='libc.so.6 path')
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model directory')
  parser.add_argument(
      '--test_dir', type=str, default=None, help='test directory')
  args = parser.parse_args()

  # Fetch and unpack the prebuilt processor on first use. Commands are passed
  # as argument lists (no shell), so paths are never re-parsed by /bin/sh.
  if not os.path.exists('processor'):
    os.mkdir('processor')
  if not os.path.exists(PROCESSOR_ENTRY_LIB):
    processor_tar = os.path.join('processor', PROCESSOR_FILE)
    if not os.path.exists(processor_tar):
      subprocess.check_output(['wget', PROCESSOR_URL, '-O', processor_tar])
    subprocess.check_output(['tar', '-zvxf', PROCESSOR_FILE], cwd='processor')
  assert os.path.exists(
      PROCESSOR_ENTRY_LIB), 'invalid processor path: %s' % PROCESSOR_ENTRY_LIB

  assert os.path.exists(args.libc_path), '%s does not exist' % args.libc_path
  assert args.saved_model_dir is not None and os.path.isdir(
      args.saved_model_dir
  ), '%s is not a valid directory' % args.saved_model_dir
  assert args.input_path is not None and os.path.exists(
      args.input_path), '%s does not exist' % args.input_path
  assert args.output_path is not None, 'output_path is not set'
  # test_dir is where the train/model.ckpt-* checkpoints are looked up below.
  assert args.test_dir is not None and os.path.isdir(
      args.test_dir), '%s is not a valid directory' % args.test_dir

  pipeline_config = pipeline_pb2.EasyRecConfig()
  pipeline_config_path = os.path.join(args.saved_model_dir,
                                      'assets/pipeline.config')
  with open(pipeline_config_path) as fin:
    config_str = fin.read()
  text_format.Merge(config_str, pipeline_config)

  data_config = pipeline_config.data_config

  # Non-label input fields in config order. The request gets one column per
  # field, and each column is paired with its own field config, so labels
  # may appear anywhere in input_fields, not only first.
  feature_fields = [
      x for x in data_config.input_fields
      if x.input_name not in data_config.label_fields
  ]
  columns = [[] for _ in feature_fields]

  with open(args.input_path, 'r') as fin:
    for line_no, line_str in enumerate(fin, 1):
      line_str = line_str.strip()
      if not line_str:
        # A blank line would add a value to the first column only and
        # misalign every column after it.
        continue
      # Feature values are the last rtp_separator-delimited segment,
      # separated from each other by chr(2).
      line_toks = line_str.split(data_config.rtp_separator)[-1].split(chr(2))
      if len(line_toks) > len(columns):
        raise ValueError('line %d has %d feature values, expected at most %d' %
                         (line_no, len(line_toks), len(columns)))
      for column, tok in zip(columns, line_toks):
        column.append(tok)

  req = tf_predict_pb2.PredictRequest()
  req.signature_name = 'serving_default'
  for field, column in zip(feature_fields, columns):
    build_array_proto(req.inputs[field.input_name], column, field.input_type)

  tf_predictor = ctypes.cdll.LoadLibrary(PROCESSOR_ENTRY_LIB)
  tf_predictor.saved_model_init.restype = ctypes.c_void_p
  handle = tf_predictor.saved_model_init(args.saved_model_dir.encode('utf-8'))
  if not handle:
    # ctypes maps a NULL c_void_p result to None; it cannot be used below.
    raise RuntimeError('failed to load saved_model from %s' %
                       args.saved_model_dir)
  logging.info('saved_model handle=%d', handle)

  res_p = None
  try:
    num_steps = pipeline_config.train_config.num_steps
    logging.info('num_steps=%d', num_steps)

    # last_step could be greater than num_steps for sync_replicas: false
    ckpt_pattern = os.path.join(args.test_dir, 'train/model.ckpt-*.index')
    iters = [
        int(x.split('-')[-1].replace('.index', ''))
        for x in glob.glob(ckpt_pattern)
    ]
    if not iters:
      raise ValueError('no checkpoint matches %s' % ckpt_pattern)
    last_step = max(iters)
    logging.info('last_step=%d', last_step)

    # Poll until both the sparse and dense parts reach last_step, or give up
    # after 300 seconds.
    sparse_step = ctypes.c_int(0)
    dense_step = ctypes.c_int(0)
    start_ts = time.time()
    while sparse_step.value < last_step or dense_step.value < last_step:
      tf_predictor.saved_model_step(
          ctypes.c_void_p(handle), ctypes.byref(sparse_step),
          ctypes.byref(dense_step))
      time.sleep(1)
      if time.time() - start_ts > 300:
        logging.warning(
            'could not reach last_step, sparse_step=%d dense_step=%d',
            sparse_step.value, dense_step.value)
        break

    data_bin = req.SerializeToString()
    save_path = os.path.join(args.saved_model_dir, 'req.pb')
    with open(save_path, 'wb') as fout:
      fout.write(data_bin)
    logging.info('save request to %s', save_path)

    tf_predictor.saved_model_predict.restype = ctypes.c_void_p
    out_len = ctypes.c_int(0)
    res_p = tf_predictor.saved_model_predict(
        ctypes.c_void_p(handle), data_bin, ctypes.c_int32(len(data_bin)),
        ctypes.byref(out_len))
    if not res_p:
      raise RuntimeError('saved_model_predict returned no response')
    res_bytes = bytearray(ctypes.string_at(res_p, out_len.value))
    res = tf_predict_pb2.PredictResponse()
    res.ParseFromString(res_bytes)

    with open(args.output_path, 'w') as fout:
      logits = res.outputs['logits'].float_val
      probs = res.outputs['probs'].float_val
      for logit, prob in zip(logits, probs):
        fout.write(json.dumps({'logits': logit, 'probs': prob}) + '\n')
  finally:
    # Free native resources on every path, not only on success.
    tf_predictor.saved_model_release(ctypes.c_void_p(handle))
    if res_p:
      libc = ctypes.cdll.LoadLibrary(args.libc_path)
      libc.free(ctypes.c_void_p(res_p))
@@ -0,0 +1,124 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import logging
8
+ from datetime import datetime
9
+
10
+ import common_io
11
+ import numpy as np
12
+ import tensorflow as tf
13
+
14
# graphlearn is optional at import time; VectorRetrieve calls into it when
# constructed.
try:
  import graphlearn as gl
except Exception:  # noqa: B902 -- best effort, but let KeyboardInterrupt through
  logging.warning(
      'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
  )

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
23
+
24
+
25
class VectorRetrieve(object):
  """Top-n nearest-neighbour search of query vectors against doc vectors.

  A GraphLearn knn index is built over ``doc_table``. For every query row
  read from this worker's slice of ``query_table``, rows of
  (query id, neighbour id, distance) are written to ``out_table``.
  """

  def __init__(self,
               query_table,
               doc_table,
               out_table,
               ndim,
               delimiter=',',
               batch_size=4,
               index_type='ivfflat',
               nlist=10,
               nprobe=2,
               distance=1,
               m=8):
    """Retrieve top n neighbours by query vector.

    Args:
      query_table: query vector table
      doc_table: document vector table
      out_table: output table
      ndim: int, number of feature dimensions
      delimiter: delimiter for feature vectors
      batch_size: query batch size
      index_type: search model `flat`, `ivfflat`, `ivfpq`, `gpu_ivfflat`
      nlist: number of split part on each worker
      nprobe: probe part on each worker
      distance: type of distance, 0 is l2 distance(default), 1 is inner product.
      m: number of dimensions for each node after compress
    """
    self.query_table = query_table
    self.doc_table = doc_table
    self.out_table = out_table
    self.ndim = ndim
    self.delimiter = delimiter
    self.batch_size = batch_size

    gl.set_inter_threadnum(8)
    gl.set_knn_metric(distance)
    knn_option = gl.IndexOption()
    knn_option.name = 'knn'
    knn_option.index_type = index_type
    knn_option.nlist = nlist
    knn_option.nprobe = nprobe
    knn_option.m = m
    self.knn_option = knn_option

  def __call__(self, top_n, task_index, task_count, *args, **kwargs):
    """Run retrieval for this worker's slice of the query table.

    Args:
      top_n: number of neighbours to retrieve per query.
      task_index: index of this worker; selects the query slice and is passed
        to the graph and the output writer.
      task_count: total number of workers.
      *args: unused.
      **kwargs: unused.
    """
    g = gl.Graph()
    g.node(
        self.doc_table,
        'doc',
        decoder=gl.Decoder(
            attr_types=['float'] * self.ndim, attr_delimiter=self.delimiter),
        option=self.knn_option)
    g.init(task_index=task_index, task_count=task_count)

    query_reader = common_io.table.TableReader(
        self.query_table, slice_id=task_index, slice_count=task_count)
    output_table_writer = None
    try:
      num_records = query_reader.get_row_count()
      # Ceiling division so a trailing partial batch is counted; at least 1 so
      # the progress ratio never divides by zero.
      total_batch_num = max(
          1, (num_records + self.batch_size - 1) // self.batch_size)
      print('total input records: {}'.format(num_records))
      print('total_batch_num: {}'.format(total_batch_num))
      print('output_table: {}'.format(self.out_table))

      output_table_writer = common_io.table.TableWriter(self.out_table,
                                                        task_index)
      batch_num = 0
      count = 0
      # The loop only ends through the exception handler; presumably the
      # reader raises once this slice is exhausted.
      while True:
        try:
          batch_query_nodes, batch_query_feats = zip(
              *query_reader.read(self.batch_size, allow_smaller_final_batch=True))
          batch_num += 1
          print('{} process: {:.2f}'.format(datetime.now().time(),
                                            batch_num / total_batch_num))
          feats = to_np_array(batch_query_feats, self.delimiter)
          rt_ids, rt_dists = g.search('doc', feats, gl.KnnOption(k=top_n))

          for query_node, nodes, dists in zip(batch_query_nodes, rt_ids,
                                              rt_dists):
            query = np.array([query_node] * len(nodes), dtype='int64')
            # Materialize: zip is a lazy iterator on Python 3.
            output_table_writer.write(
                list(zip(query, nodes, dists)), (0, 1, 2),
                allow_type_cast=False)
            count += 1
            if count % 100 == 0:
              print('write ', count, ' query nodes totally')
        except Exception as e:  # noqa: B902
          print(e)
          # End of data and a genuine failure both land here; flag the latter
          # (fewer queries written than the slice reported).
          if count < num_records:
            logging.warning(
                'retrieval stopped after %d of %d query records: %s', count,
                num_records, e)
          break

      print('==finished==')
    finally:
      if output_table_writer is not None:
        output_table_writer.close()
      query_reader.close()
      g.close()
119
+
120
+
121
def to_np_array(batch_query_feats, attr_delimiter):
  """Parse delimited feature strings into a 2-D float32 array.

  Args:
    batch_query_feats: iterable of strings, each holding one vector's values
      joined by ``attr_delimiter``. All strings must hold the same number of
      values.
    attr_delimiter: separator between values inside one string.

  Returns:
    np.ndarray of dtype float32 with one row per input string.
  """
  # Build real lists. On Python 3 ``map`` is a lazy iterator, and numpy cannot
  # convert a list of map objects into a float array.
  return np.array(
      [[float(v) for v in feat.split(attr_delimiter)]
       for feat in batch_query_feats],
      dtype='float32')
File without changes
@@ -0,0 +1,117 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.input.input import Input
8
+ from easy_rec.python.utils.tf_utils import get_tf_type
9
+
10
# On TF 2.x, expose the TF1-style API (tf.gfile, tf.parse_sequence_example,
# ...) used below via tf.compat.v1; on TF 1.x keep tf unchanged.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
12
+
13
+
14
class BatchTFRecordInput(Input):
  """BatchTFRecordInput use for batch read from tfrecord.

  For example, there is a tfrecord which one feature(key)
  correspond to n data(value).
  batch_size needs to be a multiple of n.

  Each record is parsed as a SequenceExample whose sequence features carry
  ``data_config.n_data_batch_tfrecord`` rows. ``batch_size // n`` records are
  batched together and their rows are then flattened into a single batch.
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    super(BatchTFRecordInput,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)
    assert data_config.HasField(
        'n_data_batch_tfrecord'), 'Need to set n_data_batch_tfrecord in config.'
    # _build divides batch_size by this value.
    assert data_config.n_data_batch_tfrecord > 0, \
        'n_data_batch_tfrecord must be positive.'
    self._input_shapes = [x.input_shape for x in data_config.input_fields]
    self.feature_desc = {}
    for x, t, d, s in zip(self._input_fields, self._input_field_types,
                          self._input_field_defaults, self._input_shapes):
      # NOTE(review): the converted default is not passed on to the feature
      # spec (FixedLenSequenceFeature accepts default_value) -- confirm intent.
      d = self.get_type_defaults(t, d)
      t = get_tf_type(t)
      self.feature_desc[x] = tf.io.FixedLenSequenceFeature(
          dtype=t, shape=s, allow_missing=True)

  def _parse_tfrecord(self, example):
    """Parse a batch of serialized records and flatten their rows.

    Returns:
      dict mapping input name to a tensor whose first two axes (records in the
      batch, rows per record) are merged into one; trailing axes are kept.
    """
    try:
      _, features, _ = tf.parse_sequence_example(
          example, sequence_features=self.feature_desc)
    except AttributeError:
      _, features, _ = tf.io.parse_sequence_example(
          example, sequence_features=self.feature_desc)
    # Below code will reduce one dimension when the data dimension > 2.
    return {
        key: tf.reshape(
            value,
            [-1] + [x for i, x in enumerate(value.shape) if i not in (0, 1)])
        for key, value in features.items()
    }

  def _build(self, mode, params):
    """Build the tf.data pipeline over the configured tfrecord files."""
    if not isinstance(self._input_path, list):
      self._input_path = self._input_path.split(',')
    file_paths = []
    for x in self._input_path:
      file_paths.extend(tf.gfile.Glob(x))
    assert len(file_paths) > 0, 'match no files with %s' % self._input_path

    num_parallel_calls = self._data_config.num_parallel_calls
    data_compression_type = self._data_config.data_compression_type
    if mode == tf.estimator.ModeKeys.TRAIN:
      logging.info('train files[%d]: %s' %
                   (len(file_paths), ','.join(file_paths)))
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      if self._data_config.shuffle:
        # Shuffle input files with a fixed seed. Every worker must see the
        # same file order, otherwise the record-level shard() below no longer
        # partitions the data and workers overlap or skip records.
        dataset = dataset.shuffle(len(file_paths), seed=2020)
      # too many readers read the same file will cause performance issues
      # as the same data will be read multiple times
      parallel_num = min(num_parallel_calls, len(file_paths))
      dataset = dataset.interleave(
          lambda x: tf.data.TFRecordDataset(
              x, compression_type=data_compression_type),
          cycle_length=parallel_num,
          num_parallel_calls=parallel_num)
      dataset = dataset.shard(self._task_num, self._task_index)
      if self._data_config.shuffle:
        dataset = dataset.shuffle(
            self._data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
      dataset = dataset.repeat(self.num_epochs)
    else:
      logging.info('eval files[%d]: %s' %
                   (len(file_paths), ','.join(file_paths)))
      dataset = tf.data.TFRecordDataset(
          file_paths, compression_type=data_compression_type)
      dataset = dataset.repeat(1)

    # We read n data from tfrecord one time, so batch batch_size // n records
    # (at least one) to get roughly batch_size samples.
    n_data = self._data_config.n_data_batch_tfrecord
    cur_batch = max(1, self._data_config.batch_size // n_data)
    if cur_batch * n_data != self._data_config.batch_size:
      logging.warning(
          'batch_size[%d] is not a multiple of n_data_batch_tfrecord[%d], '
          'batches will hold about %d samples' %
          (self._data_config.batch_size, n_data, cur_batch * n_data))
    dataset = dataset.batch(cur_batch)
    dataset = dataset.map(
        self._parse_tfrecord, num_parallel_calls=num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(
        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset
@@ -0,0 +1,259 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import concurrent
5
+ import concurrent.futures
6
+ import glob
7
+ import logging
8
+ import os
9
+ import queue
10
+ import time
11
+
12
+ import numpy as np
13
+
14
+
15
class BinaryDataset:
  """Batched, shardable reader for Criteo-style binary feature files.

  Every sample is stored across three parallel files that share the same
  index in ``label_bins``, ``dense_bins`` and ``category_bins``:

  * label file: one int32 per sample (4 bytes),
  * dense file: 13 float32 per sample (52 bytes),
  * category file: 26 uint32 per sample (104 bytes).

  The number of samples in each file triple is derived from the label file
  size. Samples are split over ``global_size`` workers so that every worker
  gets the same number of samples. When the total is not divisible by
  ``global_size``, the workers with ``global_rank >= remainder`` start one
  sample early and so overlap their predecessor by one sample.

  ``dataset[i]`` returns ``(dense, category, label)`` numpy arrays of shapes
  ``[n, 13]``, ``[n, 26]`` and ``[n, 1]`` for this worker's i-th batch. ``n``
  is ``batch_size`` for every batch except possibly the last one.

  Args:
    label_bins: paths of the label files, ordered consistently with the
      other two lists.
    dense_bins: paths of the dense-feature files.
    category_bins: paths of the categorical-feature files.
    batch_size: number of samples per batch.
    drop_last: drop the trailing partial batch instead of returning it.
    prefetch: number of batches read ahead on a thread pool. Values <= 1
      read synchronously. Prefetching assumes sequential access starting
      from index 0, as done by ``for ... in dataset``.
    global_rank: index of this worker.
    global_size: total number of workers.
  """

  def __init__(
      self,
      label_bins,
      dense_bins,
      category_bins,
      batch_size=1,
      drop_last=False,
      prefetch=1,
      global_rank=0,
      global_size=1,
  ):
    # Bound before anything can fail, so that __del__ can always release the
    # descriptors opened so far. Presumably this also avoids depending on the
    # ``os`` module global during interpreter shutdown.
    self._os_close_func = os.close
    self._label_file_arr = []
    self._dense_file_arr = []
    self._category_file_arr = []

    # Each label record is one int32 (4 bytes), so the label file size
    # determines how many samples the file triple holds.
    self._sample_num_arr = [
        os.path.getsize(label_bin) // 4 for label_bin in label_bins
    ]
    total_sample_num = sum(self._sample_num_arr)
    logging.info('total number samples = %d' % total_sample_num)
    self._total_sample_num = total_sample_num

    self._batch_size = batch_size

    self._compute_global_start_pos(total_sample_num, batch_size, global_rank,
                                   global_size, drop_last)

    # Only the files overlapping this worker's sample range are opened. The
    # remaining slots stay None.
    self._label_file_arr = [None for _ in self._sample_num_arr]
    self._dense_file_arr = [None for _ in self._sample_num_arr]
    self._category_file_arr = [None for _ in self._sample_num_arr]
    for tmp_file_id in range(self._start_file_id, self._end_file_id + 1):
      self._label_file_arr[tmp_file_id] = os.open(label_bins[tmp_file_id],
                                                  os.O_RDONLY)
      self._dense_file_arr[tmp_file_id] = os.open(dense_bins[tmp_file_id],
                                                  os.O_RDONLY)
      self._category_file_arr[tmp_file_id] = os.open(
          category_bins[tmp_file_id], os.O_RDONLY)

    self._prefetch = min(prefetch, self._num_entries)
    self._prefetch_queue = queue.Queue()
    # ThreadPoolExecutor rejects max_workers <= 0. That would otherwise happen
    # when this worker has no batches or when prefetch <= 0.
    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=max(1, self._prefetch))

  def _compute_global_start_pos(self, total_sample_num, batch_size, global_rank,
                                global_size, drop_last):
    """Compute this worker's sample range and the start of every batch.

    Sets ``_num_samples``, ``_num_entries`` (the number of batches),
    ``_last_batch_size``, ``_start_file_id`` / ``_start_file_pos``,
    ``_start_pos_arr`` (one ``(file_id, pos_in_file)`` row per batch) and the
    exclusive end position ``_end_file_id`` / ``_end_file_pos``.
    """
    # ensure all workers have the same number of samples
    avg_sample_num = total_sample_num // global_size
    res_num = total_sample_num % global_size
    self._num_samples = avg_sample_num
    if res_num > 0:
      self._num_samples += 1
      if global_rank < res_num:
        global_start_pos = (avg_sample_num + 1) * global_rank
      else:
        # Start one sample early so that this worker still gets
        # avg_sample_num + 1 samples. It overlaps its predecessor by one.
        global_start_pos = avg_sample_num * global_rank + res_num - 1
    else:
      global_start_pos = avg_sample_num * global_rank

    self._num_entries = self._num_samples // batch_size
    self._last_batch_size = batch_size
    if not drop_last and (self._num_samples % batch_size != 0):
      self._num_entries += 1
      self._last_batch_size = self._num_samples % batch_size
    logging.info('num_batches = %d num_samples = %d' %
                 (self._num_entries, self._num_samples))

    # Locate the file that contains global_start_pos. curr_pos is the global
    # index of the first sample of file start_file_id, so the size of the
    # file being skipped is added *before* moving on to the next file.
    start_file_id = 0
    curr_pos = 0
    while curr_pos + self._sample_num_arr[start_file_id] <= global_start_pos:
      curr_pos += self._sample_num_arr[start_file_id]
      start_file_id += 1
    self._start_file_id = start_file_id
    self._start_file_pos = global_start_pos - curr_pos

    logging.info('start_file_id = %d start_file_pos = %d' %
                 (start_file_id, self._start_file_pos))

    # find the start of each batch
    self._start_pos_arr = np.zeros([self._num_entries, 2], dtype=np.uint32)
    tmp_start_pos = self._start_file_pos
    num_files = len(self._sample_num_arr)
    for batch_id in range(self._num_entries):
      self._start_pos_arr[batch_id] = (start_file_id, tmp_start_pos)
      if batch_id == self._num_entries - 1:
        # The exclusive end of the last batch may sit exactly at the end of a
        # file. Only skip the files it has moved strictly past.
        tmp_start_pos += self._last_batch_size
        while (start_file_id < num_files and
               tmp_start_pos > self._sample_num_arr[start_file_id]):
          tmp_start_pos -= self._sample_num_arr[start_file_id]
          start_file_id += 1
      else:
        # A batch start must point at a real sample. A batch that ends
        # exactly at a file boundary therefore starts the next batch in the
        # next file.
        tmp_start_pos += batch_size
        while (start_file_id < num_files and
               tmp_start_pos >= self._sample_num_arr[start_file_id]):
          tmp_start_pos -= self._sample_num_arr[start_file_id]
          start_file_id += 1

    self._end_file_id = start_file_id
    self._end_file_pos = tmp_start_pos

    logging.info('end_file_id = %d end_file_pos = %d' %
                 (self._end_file_id, self._end_file_pos))

  def __del__(self):
    """Close every file descriptor opened by the constructor."""
    close = getattr(self, '_os_close_func', None)
    if close is None:
      return
    for fd_arr in (getattr(self, '_label_file_arr', ()),
                   getattr(self, '_dense_file_arr', ()),
                   getattr(self, '_category_file_arr', ())):
      for f in fd_arr:
        if f is not None:
          close(f)

  def __len__(self):
    """Return the number of batches this worker yields."""
    return self._num_entries

  def __getitem__(self, idx):
    """Return batch ``idx`` as ``(dense, category, label)``.

    Raises:
      IndexError: if ``idx`` is past the last batch. This ends ``for`` loops.
    """
    if idx >= self._num_entries:
      raise IndexError()

    if self._prefetch <= 1:
      return self._get(idx)

    if idx == 0:
      # A new pass begins. Discard read-ahead left over from an earlier pass
      # that stopped early; otherwise its batches would be returned here.
      self._prefetch_queue = queue.Queue()
      for i in range(self._prefetch):
        self._prefetch_queue.put(self._executor.submit(self._get, i))

    if idx < (self._num_entries - self._prefetch):
      self._prefetch_queue.put(
          self._executor.submit(self._get, idx + self._prefetch))

    return self._prefetch_queue.get().result()

  def _get(self, idx):
    """Read batch ``idx`` from disk; the batch may span several files."""
    # Convert to Python ints. Under NumPy 2 promotion rules, np.uint32 * int
    # stays uint32 and the byte offsets below could wrap past 4 GiB.
    curr_file_id = int(self._start_pos_arr[idx][0])
    start_read_pos = int(self._start_pos_arr[idx][1])
    # Only the last batch may be shorter than batch_size. Reading a full batch
    # there would run into the next worker's samples, or into files this
    # worker never opened.
    if idx == self._num_entries - 1:
      want = self._last_batch_size
    else:
      want = self._batch_size

    label_read_arr = []
    dense_read_arr = []
    cate_read_arr = []
    total_read_num = 0
    while (total_read_num < want and
           curr_file_id < len(self._sample_num_arr)):
      # Read what this batch still needs, capped at what remains in the
      # current file. The rest comes from the following file(s).
      tmp_read_num = min(want - total_read_num,
                         self._sample_num_arr[curr_file_id] - start_read_pos)

      label_read_arr.append(
          self._pread_array(self._label_file_arr[curr_file_id],
                            start_read_pos, tmp_read_num, np.int32, 1))
      dense_read_arr.append(
          self._pread_array(self._dense_file_arr[curr_file_id],
                            start_read_pos, tmp_read_num, np.float32, 13))
      cate_read_arr.append(
          self._pread_array(self._category_file_arr[curr_file_id],
                            start_read_pos, tmp_read_num, np.uint32, 26))

      curr_file_id += 1
      start_read_pos = 0
      total_read_num += tmp_read_num

    label = self._concat(label_read_arr)
    category = self._concat(cate_read_arr)
    dense = self._concat(dense_read_arr)
    return dense, category, label

  @staticmethod
  def _pread_array(fd, start, count, dtype, width):
    """Read ``count`` rows of ``width`` ``dtype`` values starting at row ``start``."""
    row_bytes = np.dtype(dtype).itemsize * width
    raw = os.pread(fd, row_bytes * count, row_bytes * start)
    return np.frombuffer(raw, dtype=dtype).reshape([count, width])

  @staticmethod
  def _concat(parts):
    """Stack per-file reads along axis 0, without copying a single part."""
    if len(parts) == 1:
      return parts[0]
    return np.concatenate(parts, axis=0)
210
+
211
+
212
if __name__ == '__main__':
  # Benchmark: read every batch of a binary dataset directory and log timing.
  # Without basicConfig the root logger stays at WARNING and every
  # logging.info message below would be silently dropped.
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument('--batch_size', type=int, default=1024, help='batch_size')
  parser.add_argument(
      '--dataset_dir', type=str, default='./', help='dataset_dir')
  parser.add_argument('--task_num', type=int, default=1, help='task number')
  parser.add_argument('--task_index', type=int, default=0, help='task index')
  parser.add_argument(
      '--prefetch_size', type=int, default=10, help='prefetch size')
  args = parser.parse_args()

  batch_size = args.batch_size
  dataset_dir = args.dataset_dir
  logging.info('batch_size = %d' % batch_size)
  logging.info('dataset_dir = %s' % dataset_dir)

  # Files are paired by sorted name, so the i-th label, dense and category
  # files must describe the same samples.
  label_files = sorted(glob.glob(os.path.join(dataset_dir, '*_label.bin')))
  dense_files = sorted(glob.glob(os.path.join(dataset_dir, '*_dense.bin')))
  category_files = sorted(
      glob.glob(os.path.join(dataset_dir, '*_category.bin')))
  if not label_files or not (len(label_files) == len(dense_files) ==
                             len(category_files)):
    parser.error('expected the same non-zero number of *_label.bin, '
                 '*_dense.bin and *_category.bin files in %s, got %d/%d/%d' %
                 (dataset_dir, len(label_files), len(dense_files),
                  len(category_files)))

  test_dataset = BinaryDataset(
      label_files,
      dense_files,
      category_files,
      batch_size=batch_size,
      drop_last=False,
      prefetch=args.prefetch_size,
      global_rank=args.task_index,
      global_size=args.task_num,
  )

  # Initialised so the summary below is safe when no batch is produced.
  step = -1
  start_time = time.time()
  for step, (dense, category, labels) in enumerate(test_dataset):
    if step == 0:
      # Time from the end of the first (warm-up) batch onward.
      logging.info('warmup over!')
      start_time = time.time()
    if step == 1000:
      logging.info('1000 steps time = %.3f' % (time.time() - start_time))

  if step < 0:
    logging.info('no batches to read')
  else:
    logging.info('total_steps = %d total_time = %.3f' %
                 (step + 1, time.time() - start_time))
    logging.info(
        'final step[%d] dense_shape=%s category_shape=%s labels_shape=%s' %
        (step, dense.shape, category.shape, labels.shape))