easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/input/parquet_input_v3.py
@@ -0,0 +1,203 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.input_utils import get_type_defaults
+
+try:
+  from tensorflow.python.data.experimental.ops import parquet_dataset_ops
+  from tensorflow.python.data.experimental.ops import parquet_pybind
+  from tensorflow.python.data.experimental.ops import dataframe
+  from tensorflow.python.ops import gen_ragged_conversion_ops
+  from tensorflow.python.ops.work_queue import WorkQueue
+  _has_deep_rec = True
+except Exception:
+  _has_deep_rec = False
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class ParquetInputV3(Input):
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None,
+               **kwargs):
+    if not _has_deep_rec:
+      raise RuntimeError('You should install DeepRec first.')
+    super(ParquetInputV3,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+
+    self._ignore_val_dict = {}
+    for f in data_config.input_fields:
+      if f.HasField('ignore_val'):
+        self._ignore_val_dict[f.input_name] = get_type_defaults(
+            f.input_type, f.ignore_val)
+
+    self._true_type_dict = {}
+    for fc in self._feature_configs:
+      if fc.feature_type in [fc.IdFeature, fc.TagFeature, fc.SequenceFeature]:
+        if fc.hash_bucket_size > 0 or len(
+            fc.vocab_list) > 0 or fc.HasField('vocab_file'):
+          self._true_type_dict[fc.input_names[0]] = tf.string
+        else:
+          self._true_type_dict[fc.input_names[0]] = tf.int64
+        if len(fc.input_names) > 1:
+          self._true_type_dict[fc.input_names[1]] = tf.float32
+      if fc.feature_type == fc.RawFeature:
+        self._true_type_dict[fc.input_names[0]] = tf.float32
+
+    self._reserve_fields = None
+    self._reserve_types = None
+    if 'reserve_fields' in kwargs and 'reserve_types' in kwargs:
+      self._reserve_fields = kwargs['reserve_fields']
+      self._reserve_types = kwargs['reserve_types']
+
+    # In ParquetDataset, multi-value fields keep their input type.
+    self._multi_value_types = {}
+
+  def _ignore_and_cast(self, name, value):
+    ignore_value = self._ignore_val_dict.get(name, None)
+    if ignore_value:
+      if isinstance(value, tf.SparseTensor):
+        indices = tf.where(tf.equal(value.values, ignore_value))
+        value = tf.SparseTensor(
+            tf.gather_nd(value.indices, indices),
+            tf.gather_nd(value.values, indices), value.dense_shape)
+      elif isinstance(value, tf.Tensor):
+        indices = tf.where(tf.not_equal(value, ignore_value), name='indices')
+        value = tf.SparseTensor(
+            indices=indices,
+            values=tf.gather_nd(value, indices),
+            dense_shape=tf.shape(value, out_type=tf.int64))
+    dtype = self._true_type_dict.get(name, None)
+    if dtype:
+      value = tf.cast(value, dtype)
+    return value
+
+  def _parse_dataframe_value(self, value):
+    if len(value.nested_row_splits) == 0:
+      return value.values
+    value.values.set_shape([None])
+    sparse_value = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
+        value.nested_row_splits, value.values)
+    return tf.SparseTensor(sparse_value.sparse_indices,
+                           sparse_value.sparse_values,
+                           sparse_value.sparse_dense_shape)
+
+  def _parse_dataframe(self, df):
+    inputs = {}
+    for k, v in df.items():
+      if k in self._effective_fields:
+        if isinstance(v, dataframe.DataFrame.Value):
+          v = self._parse_dataframe_value(v)
+      elif k in self._label_fields:
+        if isinstance(v, dataframe.DataFrame.Value):
+          v = v.values
+      elif k in self._reserve_fields:
+        if isinstance(v, dataframe.DataFrame.Value):
+          v = v.values
+      else:
+        continue
+      inputs[k] = v
+    return inputs
+
+  def _build(self, mode, params):
+    input_files = []
+    for sub_path in self._input_path.strip().split(','):
+      input_files.extend(tf.gfile.Glob(sub_path))
+    file_num = len(input_files)
+    logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
+                 (self._task_index, file_num, self._task_num))
+
+    task_index = self._task_index
+    task_num = self._task_num
+    if self._data_config.chief_redundant:
+      task_index = max(self._task_index - 1, 0)
+      task_num = max(self._task_num - 1, 1)
+
+    if self._data_config.pai_worker_queue and \
+        mode == tf.estimator.ModeKeys.TRAIN:
+      work_queue = WorkQueue(
+          input_files,
+          num_epochs=self.num_epochs,
+          shuffle=self._data_config.shuffle)
+      my_files = work_queue.input_dataset()
+    else:
+      my_files = []
+      for file_id in range(file_num):
+        if (file_id % task_num) == task_index:
+          my_files.append(input_files[file_id])
+
+    parquet_fields = parquet_pybind.parquet_fields(input_files[0])
+    parquet_input_fields = []
+    for f in parquet_fields:
+      if f.name in self._input_fields:
+        parquet_input_fields.append(f)
+
+    all_fields = set(self._effective_fields)
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      all_fields |= set(self._label_fields)
+    if self._reserve_fields:
+      all_fields |= set(self._reserve_fields)
+
+    selected_fields = []
+    for f in parquet_input_fields:
+      if f.name in all_fields:
+        selected_fields.append(f)
+
+    num_parallel_reads = min(self._data_config.num_parallel_calls,
+                             len(input_files) // task_num)
+    dataset = parquet_dataset_ops.ParquetDataset(
+        my_files,
+        batch_size=self._batch_size,
+        fields=selected_fields,
+        drop_remainder=self._data_config.drop_remainder,
+        num_parallel_reads=num_parallel_reads)
+    # partition_count=task_num,
+    # partition_index=task_index)
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.map(
+        self._parse_dataframe,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    # preprocess is necessary to transform data
+    # so that it can be fed into FeatureColumns
+    dataset = dataset.map(
+        map_func=self._preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
+
+  def _preprocess(self, field_dict):
+    for k, v in field_dict.items():
+      field_dict[k] = self._ignore_and_cast(k, v)
+    return super(ParquetInputV3, self)._preprocess(field_dict)
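Since the diff shows only the class itself, a minimal usage sketch may help orient readers. Everything below (make_input_fn, the example path, the empty params dict) is hypothetical; it only mirrors the constructor and _build signatures visible in the hunk above, and it assumes a DeepRec build of TensorFlow so that ParquetDataset is importable.

import tensorflow as tf

from easy_rec.python.input.parquet_input_v3 import ParquetInputV3


def make_input_fn(data_config, feature_configs, input_path):
  """Wraps ParquetInputV3._build into an estimator-style input_fn (sketch)."""

  def _input_fn():
    parquet_input = ParquetInputV3(
        data_config,      # DatasetConfig proto: batch_size, shuffle, ...
        feature_configs,  # feature configs parsed from the pipeline config
        input_path,       # comma-separated globs, e.g. 'oss://bkt/*.parquet'
        task_index=0,
        task_num=1)
    # In TRAIN mode, _build returns a tf.data.Dataset of (features, labels).
    return parquet_input._build(tf.estimator.ModeKeys.TRAIN, params={})

  return _input_fn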
easy_rec/python/input/rtp_input.py
@@ -0,0 +1,225 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
+from easy_rec.python.utils.check_utils import check_split
+from easy_rec.python.utils.check_utils import check_string_to_number
+from easy_rec.python.utils.input_utils import string_to_number
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class RTPInput(Input):
+  """RTPInput for parsing the rtp fg new input format.
+
+  The new format (csv in csv) of rtp output is:
+     label0, item_id, ..., user_id, features
+  where the column separator (,) can be specified by
+  data_config.rtp_separator. Inside the features column, individual
+  features are separated by data_config.separator, and multiple values of
+  one feature are separated by a multi-value delimiter, such as:
+     ...20beautysmartParis...
+  The label columns and the features column are specified by
+  data_config.selected_cols as indices, since the csv file has no header;
+  e.g. 0,1,4 means columns 0 and 1 are labels and column 4 (the last
+  selected column) holds the features.
+  """
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(RTPInput,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+    logging.info('input_fields: %s label_fields: %s' %
+                 (','.join(self._input_fields), ','.join(self._label_fields)))
+    self._rtp_separator = self._data_config.rtp_separator
+    if not isinstance(self._rtp_separator, str):
+      self._rtp_separator = self._rtp_separator.encode('utf-8')
+    self._selected_cols = [
+        int(x) for x in self._data_config.selected_cols.split(',')
+    ]
+    self._num_cols = -1
+    self._feature_col_id = self._selected_cols[-1]
+    logging.info('rtp separator = %s' % self._rtp_separator)
+
+  def _parse_csv(self, line):
+    record_defaults = ['' for i in range(self._num_cols)]
+
+    # the actual features are in one single column
+    record_defaults[self._feature_col_id] = self._data_config.separator.join([
+        str(self.get_type_defaults(t, v))
+        for x, t, v in zip(self._input_fields, self._input_field_types,
+                           self._input_field_defaults)
+        if x not in self._label_fields
+    ])
+
+    check_list = [
+        tf.py_func(
+            check_split, [line, self._rtp_separator,
+                          len(record_defaults)],
+            Tout=tf.bool)
+    ] if self._check_mode else []
+    with tf.control_dependencies(check_list):
+      fields = tf.string_split(line, self._rtp_separator, skip_empty=False)
+
+    fields = tf.reshape(fields.values, [-1, len(record_defaults)])
+
+    labels = []
+    for idx, x in enumerate(self._selected_cols[:-1]):
+      field = fields[:, x]
+      fname = self._input_fields[idx]
+      ftype = self._input_field_types[idx]
+      tf_type = get_tf_type(ftype)
+      if field.dtype in [tf.string]:
+        check_list = [
+            tf.py_func(check_string_to_number, [field, fname], Tout=tf.bool)
+        ] if self._check_mode else []
+        with tf.control_dependencies(check_list):
+          field = tf.string_to_number(field, tf_type)
+      labels.append(field)
+
+    # only for features, labels excluded
+    record_types = [
+        t for x, t in zip(self._input_fields, self._input_field_types)
+        if x not in self._label_fields
+    ]
+    # assume that the last field is the generated feature column
+    logging.info('field_delim = %s' % self._data_config.separator)
+    feature_str = fields[:, self._feature_col_id]
+    check_list = [
+        tf.py_func(
+            check_split,
+            [feature_str, self._data_config.separator,
+             len(record_types)],
+            Tout=tf.bool)
+    ] if self._check_mode else []
+    with tf.control_dependencies(check_list):
+      fields = str_split_by_chr(
+          feature_str, self._data_config.separator, skip_empty=False)
+    tmp_fields = tf.reshape(fields.values, [-1, len(record_types)])
+    rtp_record_defaults = [
+        str(self.get_type_defaults(t, v))
+        for x, t, v in zip(self._input_fields, self._input_field_types,
+                           self._input_field_defaults)
+        if x not in self._label_fields
+    ]
+    fields = []
+    for i in range(len(record_types)):
+      field = string_to_number(tmp_fields[:, i], record_types[i],
+                               rtp_record_defaults[i], i)
+      fields.append(field)
+
+    field_keys = [x for x in self._input_fields if x not in self._label_fields]
+    effective_fids = [field_keys.index(x) for x in self._effective_fields]
+    inputs = {field_keys[x]: fields[x] for x in effective_fids}
+
+    for x in range(len(self._label_fields)):
+      inputs[self._label_fields[x]] = labels[x]
+    return inputs
+
+  def _build(self, mode, params):
+    if not isinstance(self._input_path, list):
+      self._input_path = self._input_path.split(',')
+    file_paths = []
+    for x in self._input_path:
+      file_paths.extend(tf.gfile.Glob(x))
+    assert len(file_paths) > 0, 'match no files with %s' % self._input_path
+
+    # try to figure out the number of columns from one file
+    with tf.gfile.GFile(file_paths[0], 'r') as fin:
+      num_lines = 0
+      for line_str in fin:
+        line_tok = line_str.strip().split(self._rtp_separator)
+        if self._num_cols != -1:
+          assert self._num_cols == len(line_tok), \
+              'num selected cols is %d, not equal to %d, current line is: %s, please check rtp_separator and data.' % \
+              (self._num_cols, len(line_tok), line_str)
+        self._num_cols = len(line_tok)
+        num_lines += 1
+        if num_lines > 10:
+          break
+    logging.info('num selected cols = %d' % self._num_cols)
+
+    record_defaults = [
+        self.get_type_defaults(t, v)
+        for x, t, v in zip(self._input_fields, self._input_field_types,
+                           self._input_field_defaults)
+        if x in self._label_fields
+    ]
+
+    # the features are in one single column
+    record_defaults.append(
+        self._data_config.separator.join([
+            str(self.get_type_defaults(t, v))
+            for x, t, v in zip(self._input_fields, self._input_field_types,
+                               self._input_field_defaults)
+            if x not in self._label_fields
+        ]))
+
+    num_parallel_calls = self._data_config.num_parallel_calls
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      logging.info('train files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+      if self._data_config.file_shard:
+        dataset = self._safe_shard(dataset)
+
+      if self._data_config.shuffle:
+        # shuffle input files
+        dataset = dataset.shuffle(len(file_paths))
+
+      # too many readers reading the same file will cause performance issues,
+      # as the same data will be read multiple times
+      parallel_num = min(num_parallel_calls, len(file_paths))
+      dataset = dataset.interleave(
+          tf.data.TextLineDataset,
+          cycle_length=parallel_num,
+          num_parallel_calls=parallel_num)
+
+      if not self._data_config.file_shard:
+        dataset = self._safe_shard(dataset)
+
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      logging.info('eval files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.TextLineDataset(file_paths)
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.batch(batch_size=self._data_config.batch_size)
+
+    dataset = dataset.map(
+        self._parse_csv,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    # preprocess is necessary to transform data
+    # so that it can be fed into FeatureColumns
+    dataset = dataset.map(
+        map_func=self._preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
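The "csv in csv" layout described in the docstring is easier to see with a concrete line. The sketch below is purely illustrative: the separator characters, column order, and selected_cols value are assumptions chosen only to match the parsing logic above.

# Hypothetical sample of the rtp "csv in csv" layout parsed by RTPInput.
# Assume data_config.rtp_separator = ';', data_config.separator = chr(2),
# and data_config.selected_cols = '0,1,4': columns 0 and 1 are labels and
# column 4 carries all features joined by the feature separator.
FEATURE_SEP = chr(2)

columns = [
    '1',         # col 0: label0, e.g. click
    '0',         # col 1: label1, e.g. pay
    'item_101',  # col 2: item_id (not selected)
    'user_42',   # col 3: user_id (not selected)
    FEATURE_SEP.join(['20', 'beauty', 'smart', 'Paris']),  # col 4: features
]
line = ';'.join(columns)
# _parse_csv first splits `line` by rtp_separator, converts the label
# columns to numbers, then splits column 4 by data_config.separator into
# the per-feature fields.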
easy_rec/python/input/rtp_input_v2.py
@@ -0,0 +1,145 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class RTPInputV2(Input):
+  """RTPInput for parsing the original rtp fg input format.
+
+  The original rtp format is not efficient for training, so its
+  performance has to be tuned.
+  """
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(RTPInputV2,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+
+  def _parse_rtp(self, lines):
+    tf_types = [tf.string for x in self._input_field_types]
+
+    def _parse_one_line_tf(line):
+      line = tf.expand_dims(line, axis=0)
+      field_toks = tf.string_split(line, '\002').values
+      field_vals = tf.string_split(field_toks, '\003').values
+      field_vals = tf.reshape(field_vals, [-1, 2])
+      keys = field_vals[:, 0]
+      vals = field_vals[:, 1]
+      temp_vals = [
+          str(
+              self.get_type_defaults(self._input_field_types[i],
+                                     self._input_field_defaults[i]))
+          for i in range(len(self._input_fields))
+      ]
+      for i, key in enumerate(self._input_fields):
+        msk = tf.equal(key, keys)
+        val = tf.boolean_mask(vals, msk)
+        def_val = self.get_type_defaults(self._input_field_types[i],
+                                         self._input_field_defaults[i])
+        temp_vals[i] = tf.cond(
+            tf.reduce_any(msk), lambda: tf.reduce_join(val, separator=','),
+            lambda: tf.constant(str(def_val)))
+      return temp_vals
+
+    fields = tf.map_fn(
+        _parse_one_line_tf,
+        lines,
+        tf_types,
+        parallel_iterations=64,
+        name='parse_one_line_tf_map_fn')
+
+    def _convert(x, target_type, name):
+      if target_type in [DatasetConfig.FLOAT, DatasetConfig.DOUBLE]:
+        return tf.string_to_number(
+            x, tf.float32, name='convert_input_flt32/%s' % name)
+      elif target_type == DatasetConfig.INT32:
+        return tf.string_to_number(
+            x, tf.int32, name='convert_input_int32/%s' % name)
+      elif target_type == DatasetConfig.INT64:
+        return tf.string_to_number(
+            x, tf.int64, name='convert_input_int64/%s' % name)
+      return x
+
+    inputs = {
+        self._input_fields[x]: _convert(fields[x], self._input_field_types[x],
+                                        self._input_fields[x])
+        for x in self._effective_fids
+    }
+
+    for x in self._label_fids:
+      inputs[self._input_fields[x]] = fields[x]
+    return inputs
+
+  def _build(self, mode, params):
+    if not isinstance(self._input_path, list):
+      self._input_path = self._input_path.split(',')
+    file_paths = []
+    for x in self._input_path:
+      file_paths.extend(tf.gfile.Glob(x))
+    assert len(file_paths) > 0, 'match no files with %s' % self._input_path
+
+    num_parallel_calls = self._data_config.num_parallel_calls
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      logging.info('train files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+      if self._data_config.file_shard:
+        dataset = self._safe_shard(dataset)
+
+      if self._data_config.shuffle:
+        # shuffle input files
+        dataset = dataset.shuffle(len(file_paths))
+
+      # too many readers reading the same file will cause performance issues,
+      # as the same data will be read multiple times
+      parallel_num = min(num_parallel_calls, len(file_paths))
+      dataset = dataset.interleave(
+          tf.data.TextLineDataset,
+          cycle_length=parallel_num,
+          num_parallel_calls=parallel_num)
+
+      if not self._data_config.file_shard:
+        dataset = self._safe_shard(dataset)
+
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      logging.info('eval files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.TextLineDataset(file_paths)
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.batch(self._data_config.batch_size)
+    dataset = dataset.map(
+        self._parse_rtp, num_parallel_calls=num_parallel_calls)
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+    dataset = dataset.map(
+        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
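For contrast with the newer format, the sketch below builds one line in the original key-value rtp layout implied by _parse_one_line_tf; the field names and values are made up, and only the '\002'/'\003' separators come from the code above.

# Hypothetical sample line for RTPInputV2: fields are separated by '\002'
# and each field is a key/value pair separated by '\003'.
fields = {
    'label': '1',
    'user_id': 'user_42',
    'item_id': 'item_101',
}
line = '\002'.join('%s\003%s' % (k, v) for k, v in fields.items())
# _parse_one_line_tf splits by '\002' then '\003', reshapes the tokens into
# [-1, 2] key/value pairs, and substitutes a field's default value whenever
# its key is missing from the line.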
easy_rec/python/input/tfrecord_input.py
@@ -0,0 +1,100 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils.tf_utils import get_tf_type
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class TFRecordInput(Input):
+
+  def __init__(self,
+               data_config,
+               feature_config,
+               input_path,
+               task_index=0,
+               task_num=1,
+               check_mode=False,
+               pipeline_config=None):
+    super(TFRecordInput,
+          self).__init__(data_config, feature_config, input_path, task_index,
+                         task_num, check_mode, pipeline_config)
+
+    self.feature_desc = {}
+    for x, t, d, s in zip(self._input_fields, self._input_field_types,
+                          self._input_field_defaults, self._input_dims):
+      d = self.get_type_defaults(t, d)
+      t = get_tf_type(t)
+      if s == 1:
+        self.feature_desc[x] = tf.FixedLenFeature(
+            dtype=t, shape=[s], default_value=d)
+      else:
+        self.feature_desc[x] = tf.FixedLenFeature(
+            dtype=t, shape=[s], default_value=[d] * s)
+
+  def _parse_tfrecord(self, example):
+    try:
+      inputs = tf.parse_single_example(example, features=self.feature_desc)
+    except AttributeError:
+      inputs = tf.io.parse_single_example(example, features=self.feature_desc)
+    return inputs
+
+  def _build(self, mode, params):
+    if not isinstance(self._input_path, list):
+      self._input_path = self._input_path.split(',')
+    file_paths = []
+    for x in self._input_path:
+      file_paths.extend(tf.gfile.Glob(x))
+    assert len(file_paths) > 0, 'match no files with %s' % self._input_path
+
+    num_parallel_calls = self._data_config.num_parallel_calls
+    data_compression_type = self._data_config.data_compression_type
+    if mode == tf.estimator.ModeKeys.TRAIN:
+      logging.info('train files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+      if self._data_config.shuffle:
+        # shuffle input files
+        dataset = dataset.shuffle(len(file_paths))
+      # too many readers reading the same file will cause performance issues,
+      # as the same data will be read multiple times
+      parallel_num = min(num_parallel_calls, len(file_paths))
+      dataset = dataset.interleave(
+          lambda x: tf.data.TFRecordDataset(
+              x, compression_type=data_compression_type),
+          cycle_length=parallel_num,
+          num_parallel_calls=parallel_num)
+      dataset = dataset.shard(self._task_num, self._task_index)
+      if self._data_config.shuffle:
+        dataset = dataset.shuffle(
+            self._data_config.shuffle_buffer_size,
+            seed=2020,
+            reshuffle_each_iteration=True)
+      dataset = dataset.repeat(self.num_epochs)
+    else:
+      logging.info('eval files[%d]: %s' %
+                   (len(file_paths), ','.join(file_paths)))
+      dataset = tf.data.TFRecordDataset(
+          file_paths, compression_type=data_compression_type)
+      dataset = dataset.repeat(1)
+
+    dataset = dataset.map(
+        self._parse_tfrecord, num_parallel_calls=num_parallel_calls)
+    dataset = dataset.batch(self._data_config.batch_size)
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+    dataset = dataset.map(
+        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: (self._get_features(x)))
+    return dataset
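To round out the picture, here is a small hedged sketch of writing records that TFRecordInput's FixedLenFeature schema could consume; the field names, types, and sizes are illustrative assumptions, not part of the package, and would have to match data_config.input_fields in a real pipeline.

import tensorflow as tf


def make_example(label, tags):
  """Builds one tf.train.Example with a scalar and a fixed-size field."""
  feature = {
      # a scalar field, matching FixedLenFeature(shape=[1])
      'label': tf.train.Feature(
          float_list=tf.train.FloatList(value=[label])),
      # a fixed-size multi-value field, matching FixedLenFeature(shape=[s])
      'tags': tf.train.Feature(
          int64_list=tf.train.Int64List(value=tags)),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))


with tf.io.TFRecordWriter('/tmp/sample.tfrecord') as writer:
  writer.write(make_example(1.0, [3, 7, 11]).SerializeToString())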