easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,189 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import logging
8
+ import os
9
+
10
+ import tensorflow as tf
11
+ from tensorflow.python.platform import gfile
12
+
13
+ from easy_rec.python.inference.predictor import SINGLE_PLACEHOLDER_FEATURE_KEY
14
+ from easy_rec.python.inference.predictor import Predictor
15
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
16
+ from easy_rec.python.utils.check_utils import check_split
17
+
18
# Use the TF1 compatibility API when running on TensorFlow 2 or newer.
# The major version is compared numerically because a plain string comparison
# is lexicographic and would order '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
20
+
21
+
22
class CSVPredictor(Predictor):
    """Predictor that reads delimited text files and writes CSV part files.

    Input lines are split with the separator taken from ``data_config``
    (``rtp_separator`` for RTP input types, ``separator`` otherwise).  Results
    are written to ``part-<slice_id>.csv`` with fields joined by
    ``output_sep``.
    """

    def __init__(self,
                 model_path,
                 data_config,
                 with_header=False,
                 ds_vector_recall=False,
                 fg_json_path=None,
                 profiling_file=None,
                 selected_cols=None,
                 output_sep=chr(1)):
        """Create the predictor.

        Args:
          model_path: forwarded to ``Predictor.__init__``.
          data_config: dataset config; its ``input_type`` and separators are
            read here.
          with_header: whether the first line of every input file is a header
            line holding the field names.
          ds_vector_recall: when true, ``selected_cols`` is a comma separated
            list of column names rather than column indices.
          fg_json_path: forwarded to ``Predictor.__init__``.
          profiling_file: forwarded to ``Predictor.__init__``.
          selected_cols: comma separated column indices (or names when
            ``ds_vector_recall`` is set).  Only consulted for RTP input and
            for vector recall.
          output_sep: separator used when writing output lines.

        Raises:
          ValueError: if ``ds_vector_recall`` is set but ``selected_cols`` is
            None.
        """
        super(CSVPredictor, self).__init__(model_path, profiling_file,
                                           fg_json_path)
        self._output_sep = output_sep
        self._ds_vector_recall = ds_vector_recall
        self._with_header = with_header

        input_type = DatasetConfig.InputType.Name(
            data_config.input_type).lower()
        self._is_rtp = 'rtp' in input_type
        if self._is_rtp:
            self._input_sep = data_config.rtp_separator
        else:
            self._input_sep = data_config.separator

        if ds_vector_recall:
            # The selected columns become the field names in this mode, so
            # fail early with a clear message instead of an AttributeError
            # on ``None.split``.
            if selected_cols is None:
                raise ValueError(
                    'selected_cols must be set when ds_vector_recall is True')
            self._selected_cols = selected_cols.split(',')
        elif selected_cols:
            self._selected_cols = [int(x) for x in selected_cols.split(',')]
        else:
            self._selected_cols = None

    def _rtp_column_names(self):
        """Name every column of a header-less RTP line.

        Selected columns (all but the last selected index) take consecutive
        names from ``self._input_fields``; other columns are named
        ``no_used_<index>``; the final column is the single placeholder
        feature column.
        """
        # Build the lookup once instead of slicing the list per column.
        selected = None
        if self._selected_cols:
            selected = set(self._selected_cols[:-1])
        names = []
        next_field = 0
        for col_idx in range(len(self._record_defaults) - 1):
            if selected is None or col_idx in selected:
                names.append(self._input_fields[next_field])
                next_field += 1
            else:
                names.append('no_used_%d' % col_idx)
        names.append(SINGLE_PLACEHOLDER_FEATURE_KEY)
        return names

    def _get_reserved_cols(self, reserved_cols):
        """Resolve the reserved column names.

        Args:
          reserved_cols: ``'ALL_COLUMNS'`` or a comma separated list of names.

        Returns:
          list of column names; blank entries of an explicit list are dropped.
        """
        if reserved_cols != 'ALL_COLUMNS':
            # Strip before filtering so that whitespace-only entries such as
            # the middle one in 'a, ,b' do not become empty column names.
            stripped = (x.strip() for x in reserved_cols.split(','))
            return [x for x in stripped if x]
        if self._is_rtp and not self._with_header:
            return self._rtp_column_names()
        return self._all_fields

    def _parse_line(self, line):
        """Decode a batch of text lines into a dict of column tensors.

        Args:
          line: string tensor holding the raw input lines.

        Returns:
          dict mapping column name to the decoded tensor of that column.
        """
        check_list = [
            tf.py_func(
                check_split,
                [line, self._input_sep,
                 len(self._record_defaults)],
                Tout=tf.bool)
        ]
        # Run the field-count check before decoding.
        with tf.control_dependencies(check_list):
            fields = tf.decode_csv(
                line,
                field_delim=self._input_sep,
                record_defaults=self._record_defaults,
                name='decode_csv')
        if self._is_rtp and not self._with_header:
            names = self._rtp_column_names()
        else:
            names = self._all_fields
        return dict(zip(names, fields))

    def _get_num_cols(self, file_paths):
        """Infer the number of columns from the head of the first file.

        Args:
          file_paths: list of input files; only the first one is inspected.

        Returns:
          number of columns, or -1 if the file has no lines.
        """
        # try to figure out number of fields from one file
        num_cols = -1
        with gfile.GFile(file_paths[0], 'r') as fin:
            for line_no, line_str in enumerate(fin):
                # Remove only the line terminator: a full strip() would also
                # drop leading/trailing empty fields when the separator itself
                # is whitespace (e.g. a tab), giving a wrong column count.
                line_tok = line_str.rstrip('\r\n').split(self._input_sep)
                if num_cols != -1:
                    assert num_cols == len(line_tok), (
                        'num selected cols is %d, not equal to %d, current line is: %s, please check input_sep and data.'
                        % (num_cols, len(line_tok), line_str))
                num_cols = len(line_tok)
                # Inspect at most the first 11 lines.
                if line_no >= 10:
                    break
        logging.info('num selected cols = %d', num_cols)
        return num_cols

    def _get_dataset(self, input_path, num_parallel_calls, batch_size,
                     slice_num, slice_id):
        """Build the input dataset of batched text lines.

        Args:
          input_path: comma separated glob patterns of the input files.
          num_parallel_calls: upper bound on files read in parallel.
          batch_size: number of lines per batch.
          slice_num: total number of shards.
          slice_id: index of the shard read by this predictor.

        Returns:
          a dataset of string batches.

        Raises:
          ValueError: if a header is expected but the first file is empty.
        """
        file_paths = [
            path for pattern in input_path.split(',')
            for path in gfile.Glob(pattern) if not path.endswith('_SUCCESS')
        ]
        assert len(file_paths) > 0, 'match no files with %s' % input_path

        if self._with_header:
            header = None
            with gfile.GFile(file_paths[0], 'r') as fin:
                for line_str in fin:
                    header = line_str.strip()
                    break
            if header is None:
                # Previously this surfaced as an AttributeError on the unset
                # ``_field_names`` attribute.
                raise ValueError(
                    'cannot read header, file is empty: %s' % file_paths[0])
            self._field_names = header.split(self._input_sep)
            print('field_names: %s' % ','.join(self._field_names))
            self._all_fields = self._field_names
        elif self._ds_vector_recall:
            self._all_fields = self._selected_cols
        else:
            self._all_fields = self._input_fields

        if self._is_rtp:
            num_cols = self._get_num_cols(file_paths)
            # Columns that are not selected are decoded as strings.
            self._record_defaults = ['' for _ in range(num_cols)]
            if not self._selected_cols:
                self._selected_cols = list(range(num_cols))
            for col_idx in self._selected_cols[:-1]:
                col_name = self._input_fields[col_idx]
                self._record_defaults[col_idx] = self._get_defaults(col_name)
        else:
            self._record_defaults = [
                self._get_defaults(col_name) for col_name in self._all_fields
            ]

        dataset = tf.data.Dataset.from_tensor_slices(file_paths)
        parallel_num = min(num_parallel_calls, len(file_paths))
        # Skip the first line of every file when a header is present.
        dataset = dataset.interleave(
            lambda x: tf.data.TextLineDataset(x).skip(int(self._with_header)),
            cycle_length=parallel_num,
            num_parallel_calls=parallel_num)
        dataset = dataset.shard(slice_num, slice_id)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=64)
        return dataset

    def _get_writer(self, output_path, slice_id):
        """Open ``part-<slice_id>.csv`` under output_path and write the header.

        Args:
          output_path: output directory; created if it does not exist.
          slice_id: shard index used in the output file name.

        Returns:
          the opened file writer; the caller is responsible for closing it.
        """
        if not gfile.Exists(output_path):
            gfile.MakeDirs(output_path)
        res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
        table_writer = gfile.GFile(res_path, 'w')
        header = self._output_sep.join(self._output_cols + self._reserved_cols)
        table_writer.write(header + '\n')
        return table_writer

    def _write_lines(self, table_writer, outputs):
        """Write one output row per line, fields joined by the output separator."""
        lines = '\n'.join(
            self._output_sep.join(str(val) for val in row) for row in outputs)
        table_writer.write(lines + '\n')

    def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
        """Collect the output columns followed by the reserved input columns."""
        reserve_vals = [outputs[x] for x in output_cols]
        reserve_vals.extend(all_vals[k] for k in reserved_cols)
        return reserve_vals

    @property
    def out_of_range_exception(self):
        """Exception class (not a tuple) that is raised when the input is exhausted."""
        return tf.errors.OutOfRangeError
@@ -0,0 +1,200 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import os
8
+ import time
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import tensorflow as tf
13
+ from tensorflow.python.platform import gfile
14
+
15
+ from easy_rec.python.inference.predictor import Predictor
16
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
17
+ from easy_rec.python.utils import tf_utils
18
+ from easy_rec.python.utils.hive_utils import HiveUtils
19
+ from easy_rec.python.utils.tf_utils import get_tf_type
20
+
21
# Use the TF1 compatibility API when running on TensorFlow 2 or newer.
# The major version is compared numerically because a plain string comparison
# is lexicographic and would order '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
23
+
24
+
25
class HiveParquetPredictor(Predictor):
    """Predictor that reads the parquet files of a Hive table.

    Results are first written as ``part-<slice_id>.csv`` files into a
    temporary HDFS directory and then loaded into a Hive table by
    ``load_to_table``.
    """

    def __init__(self,
                 model_path,
                 data_config,
                 hive_config,
                 fg_json_path=None,
                 profiling_file=None,
                 output_sep=chr(1),
                 all_cols=None,
                 all_col_types=None):
        """Create the predictor.

        Args:
          model_path: forwarded to ``Predictor.__init__``.
          data_config: dataset config; ``input_type`` and ``input_fields`` are
            read from it.
          hive_config: Hive connection config, passed to ``HiveUtils``.
          fg_json_path: forwarded to ``Predictor.__init__``.
          profiling_file: forwarded to ``Predictor.__init__``.
          output_sep: separator used when writing output lines.
          all_cols: iterable of the input table column names.
          all_col_types: iterable of the column types, aligned with all_cols.

        Raises:
          ValueError: if all_cols or all_col_types is None.
        """
        super(HiveParquetPredictor, self).__init__(model_path, profiling_file,
                                                   fg_json_path)

        self._data_config = data_config
        self._hive_config = hive_config
        self._output_sep = output_sep
        input_type = DatasetConfig.InputType.Name(
            data_config.input_type).lower()
        self._is_rtp = 'rtp' in input_type

        # Both lists are iterated below; the None defaults used to fail with
        # an opaque "'NoneType' object is not iterable" TypeError.
        if all_cols is None or all_col_types is None:
            raise ValueError('all_cols and all_col_types must be provided')
        self._all_cols = [x.strip() for x in all_cols if x != '']
        self._all_col_types = [x.strip() for x in all_col_types if x != '']
        self._record_defaults = [
            self._get_defaults(col_name, col_type)
            for col_name, col_type in zip(self._all_cols, self._all_col_types)
        ]

    def _get_reserved_cols(self, reserved_cols):
        """Resolve the reserved column names.

        Args:
          reserved_cols: ``'ALL_COLUMNS'`` or a comma separated list of names.

        Returns:
          list of column names; blank entries of an explicit list are dropped.
        """
        if reserved_cols == 'ALL_COLUMNS':
            return self._all_cols
        # Strip before filtering so that whitespace-only entries such as the
        # middle one in 'a, ,b' do not become empty column names.
        stripped = (x.strip() for x in reserved_cols.split(','))
        return [x for x in stripped if x]

    def _parse_line(self, *fields):
        """Map the positional column tensors to their column names."""
        return {self._all_cols[i]: field for i, field in enumerate(fields)}

    def _get_dataset(self, input_path, num_parallel_calls, batch_size,
                     slice_num, slice_id):
        """Build a dataset that yields batches read from the parquet files.

        Args:
          input_path: Hive table whose storage location is looked up.
          num_parallel_calls: not used by this implementation.
          batch_size: number of rows per yielded batch.
          slice_num: total number of shards.
          slice_id: index of the shard read by this predictor.

        Returns:
          a dataset whose elements are tuples with one 1-D tensor per column.
        """
        self._hive_util = HiveUtils(
            data_config=self._data_config, hive_config=self._hive_config)
        hdfs_path = self._hive_util.get_table_location(input_path)
        self._input_hdfs_path = gfile.Glob(os.path.join(hdfs_path, '*'))
        assert len(
            self._input_hdfs_path) > 0, 'match no files with %s' % input_path

        input_field_type_map = {
            x.input_name: x.input_type for x in self._data_config.input_fields
        }
        type_2_tftype = {
            'string': tf.string,
            'double': tf.double,
            'float': tf.float32,
            # Hive BIGINT is a 64-bit integer; int32 would silently wrap
            # large values such as ids.
            'bigint': tf.int64,
            'boolean': tf.bool
        }
        # Columns that are model inputs use the configured field type, all
        # other columns fall back to the Hive column type.
        list_type = tuple(
            get_tf_type(input_field_type_map[col_name])
            if col_name in input_field_type_map else
            type_2_tftype[col_type.lower()]
            for col_name, col_type in zip(self._all_cols, self._all_col_types))
        list_shapes = tuple(tf.TensorShape([None]) for _ in list_type)

        def parquet_read():
            """Yield one tuple of per-column numpy arrays per batch."""
            for file_path in self._input_hdfs_path:
                if file_path.endswith('SUCCESS'):
                    continue
                df = pd.read_parquet(file_path, engine='pyarrow')

                df.replace('', np.nan, inplace=True)
                df.replace('NULL', np.nan, inplace=True)
                total_records_num = len(df)

                for k, v in zip(self._all_cols, self._record_defaults):
                    # Assign back instead of ``df[k].fillna(v, inplace=True)``:
                    # the chained in-place form operates on a temporary and
                    # leaves ``df`` unchanged under pandas Copy-on-Write.
                    df[k] = df[k].fillna(v)

                for start_idx in range(0, total_records_num, batch_size):
                    end_idx = min(total_records_num, start_idx + batch_size)
                    batch_data = df[start_idx:end_idx]
                    yield tuple(
                        batch_data[k].to_numpy() for k in self._all_cols)

        dataset = tf.data.Dataset.from_generator(
            parquet_read, output_types=list_type, output_shapes=list_shapes)
        # Elements are already batches, so sharding distributes whole batches.
        dataset = dataset.shard(slice_num, slice_id)
        dataset = dataset.prefetch(buffer_size=64)
        return dataset

    def get_table_info(self, output_path):
        """Split ``table`` or ``table/partition=value`` into its parts.

        Args:
          output_path: table name, optionally followed by ``/name=value``.

        Returns:
          tuple (table_name, partition_name, partition_val); the partition
          entries are None when output_path has no partition part.
        """
        parts = output_path.split('/')
        if len(parts) != 2:
            return output_path, None, None
        table_name, partition = parts
        partition_name, partition_val = partition.split('=')
        return table_name, partition_name, partition_val

    def _get_writer(self, output_path, slice_id):
        """Open the temporary HDFS part file for this slice.

        Args:
          output_path: output table, optionally with a partition.
          slice_id: shard index used in the output file name.

        Returns:
          the opened file writer; the caller is responsible for closing it.
        """
        table_name, partition_name, partition_val = self.get_table_info(
            output_path)
        is_exist = self._hive_util.is_table_or_partition_exist(
            table_name, partition_name, partition_val)
        assert not is_exist, '%s is already exists. Please drop it.' % output_path

        output_path = output_path.replace('.', '/')
        self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
            self._hive_config.host, output_path)
        if not gfile.Exists(self._hdfs_path):
            gfile.MakeDirs(self._hdfs_path)
        res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
        return gfile.GFile(res_path, 'w')

    def _write_lines(self, table_writer, outputs):
        """Write one output row per line, fields joined by the output separator."""
        lines = '\n'.join(
            self._output_sep.join(str(val) for val in row) for row in outputs)
        table_writer.write(lines + '\n')

    def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
        """Collect the output columns followed by the reserved input columns."""
        reserve_vals = [outputs[x] for x in output_cols]
        reserve_vals.extend(all_vals[k] for k in reserved_cols)
        return reserve_vals

    def load_to_table(self, output_path, slice_num, slice_id):
        """Mark this slice as finished and, on slice 0, load the Hive table.

        Every slice writes an empty ``SUCCESS-<slice_id>`` marker.  Slice 0
        then polls every 10 seconds until the markers of all slices exist,
        creates the output table if needed and loads the temporary files.

        Args:
          output_path: output table, optionally with a partition.
          slice_num: total number of slices.
          slice_id: index of this slice.
        """
        res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
        # The context manager closes the marker file even if the write fails.
        with gfile.GFile(res_path, 'w') as success_writer:
            success_writer.write('')

        if slice_id != 0:
            return

        for other_id in range(slice_num):
            res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % other_id)
            while not gfile.Exists(res_path):
                time.sleep(10)

        table_name, partition_name, partition_val = self.get_table_info(
            output_path)
        col_defs = []
        for output_col_name in self._output_cols:
            tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
            col_type = tf_utils.get_col_type(tf_type)
            col_defs.append(output_col_name + ' ' + col_type)

        for output_col_name in self._reserved_cols:
            assert output_col_name in self._all_cols, 'Column: %s not exists.' % output_col_name
            idx = self._all_cols.index(output_col_name)
            col_defs.append(output_col_name + ' ' + self._all_col_types[idx])
        schema = ','.join(col_defs)

        if partition_name and partition_val:
            sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
                  (table_name, schema, partition_name)
            self._hive_util.run_sql(sql)
            sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
                  (self._hdfs_path, table_name, partition_name, partition_val)
            self._hive_util.run_sql(sql)
        else:
            sql = 'create table if not exists %s (%s)' % \
                  (table_name, schema)
            self._hive_util.run_sql(sql)
            sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
                  (self._hdfs_path, table_name)
            self._hive_util.run_sql(sql)

    @property
    def out_of_range_exception(self):
        """Exception class (not a tuple) that is raised when the input is exhausted."""
        return tf.errors.OutOfRangeError
@@ -0,0 +1,166 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import os
8
+ import time
9
+
10
+ import tensorflow as tf
11
+ from tensorflow.python.platform import gfile
12
+
13
+ from easy_rec.python.inference.predictor import Predictor
14
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
15
+ from easy_rec.python.utils import tf_utils
16
+ from easy_rec.python.utils.hive_utils import HiveUtils
17
+
18
# Use the TF1 API surface when running under TensorFlow 2.x or newer.
# The major version is compared numerically: a plain string comparison
# misorders multi-digit majors ('10.0' sorts before '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
20
+
21
+
22
class HivePredictor(Predictor):
  """Predictor that reads its input from a Hive table and writes back to Hive.

  Input rows are read as delimited text lines from the files under the input
  table's location. Predictions are staged as ``part-<slice_id>.csv`` files in
  a temporary HDFS directory and loaded into the output table by
  ``load_to_table``.
  """

  def __init__(self,
               model_path,
               data_config,
               hive_config,
               fg_json_path=None,
               profiling_file=None,
               output_sep=chr(1),
               all_cols=None,
               all_col_types=None):
    """Initialize the predictor.

    Args:
      model_path: forwarded to ``Predictor.__init__``.
      data_config: dataset config; supplies the input type and the field
        separators used to parse input lines.
      hive_config: Hive config; its ``host`` is used to build the staging path.
      fg_json_path: forwarded to ``Predictor.__init__``.
      profiling_file: forwarded to ``Predictor.__init__``.
      output_sep: separator placed between the fields of a written row.
      all_cols: iterable of input column names; ``None`` means no columns.
      all_col_types: iterable of column types paired positionally with
        ``all_cols``; ``None`` means no columns.
    """
    super(HivePredictor, self).__init__(model_path, profiling_file,
                                        fg_json_path)

    self._data_config = data_config
    self._hive_config = hive_config
    self._output_sep = output_sep
    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
    self._is_rtp = 'rtp' in input_type

    # Both arguments default to None; treat that as "no columns" rather than
    # failing with a TypeError when iterating.
    if all_cols is None:
      all_cols = []
    if all_col_types is None:
      all_col_types = []
    self._all_cols = [x.strip() for x in all_cols if x != '']
    self._all_col_types = [x.strip() for x in all_col_types if x != '']
    self._record_defaults = [
        self._get_defaults(col_name, col_type)
        for col_name, col_type in zip(self._all_cols, self._all_col_types)
    ]

  def _get_reserved_cols(self, reserved_cols):
    """Turn the reserved-columns spec into a list of column names.

    Args:
      reserved_cols: comma separated column names, or the literal
        ``'ALL_COLUMNS'`` to select every input column.

    Returns:
      List of column names.
    """
    if reserved_cols == 'ALL_COLUMNS':
      return self._all_cols
    return [x.strip() for x in reserved_cols.split(',') if x != '']

  def _parse_line(self, line):
    """Decode delimited text into a dict mapping column name to tensor.

    Args:
      line: string tensor holding the raw delimited text.

    Returns:
      Dict from column name to the decoded field tensor.
    """
    if self._is_rtp:
      field_delim = self._data_config.rtp_separator
    else:
      field_delim = self._data_config.separator
    fields = tf.decode_csv(
        line,
        field_delim=field_delim,
        record_defaults=self._record_defaults,
        name='decode_csv')
    # decode_csv yields one tensor per record default, in column order.
    return dict(zip(self._all_cols, fields))

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
    """Build a dataset of batched raw text lines for one worker.

    Args:
      input_path: input table spec, resolved to a storage location through
        ``HiveUtils.get_table_location``.
      num_parallel_calls: upper bound on the number of files read in parallel.
      batch_size: number of lines per batch.
      slice_num: total number of workers.
      slice_id: index of the current worker.

    Returns:
      A ``tf.data.Dataset`` of batched text lines.

    Raises:
      AssertionError: if no file is found under the table location.
    """
    self._hive_util = HiveUtils(
        data_config=self._data_config, hive_config=self._hive_config)
    self._input_hdfs_path = self._hive_util.get_table_location(input_path)
    file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
    assert len(file_paths) > 0, 'match no files with %s' % input_path

    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    parallel_num = min(num_parallel_calls, len(file_paths))
    dataset = dataset.interleave(
        tf.data.TextLineDataset,
        cycle_length=parallel_num,
        num_parallel_calls=parallel_num)
    # Sharding happens after interleave, i.e. on lines rather than on files.
    dataset = dataset.shard(slice_num, slice_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=64)
    return dataset

  def get_table_info(self, output_path):
    """Split an output spec into table name and optional partition.

    Args:
      output_path: either ``table`` or ``table/partition_name=partition_val``.

    Returns:
      Tuple ``(table_name, partition_name, partition_val)``; the partition
      entries are ``None`` unless the spec has exactly one ``/``.

    Raises:
      ValueError: if the partition part does not contain ``=``.
    """
    parts = output_path.split('/')
    if len(parts) != 2:
      return output_path, None, None
    table_name, partition = parts
    # Split on the first '=' only so the value itself may contain '='.
    partition_name, partition_val = partition.split('=', 1)
    return table_name, partition_name, partition_val

  def _get_writer(self, output_path, slice_id):
    """Open the staging file that receives this worker's predictions.

    Also sets ``self._hdfs_path`` to the staging directory, creating it when
    missing.

    Args:
      output_path: output table spec, see ``get_table_info``.
      slice_id: index of the current worker; used in the file name.

    Returns:
      A writable ``gfile.GFile`` for ``part-<slice_id>.csv``.

    Raises:
      AssertionError: if the output table or partition already exists.
    """
    table_name, partition_name, partition_val = self.get_table_info(output_path)
    is_exist = self._hive_util.is_table_or_partition_exist(
        table_name, partition_name, partition_val)
    assert not is_exist, '%s is already exists. Please drop it.' % output_path

    # Dots in the table spec become directory levels of the staging path.
    output_path = output_path.replace('.', '/')
    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
        self._hive_config.host, output_path)
    if not gfile.Exists(self._hdfs_path):
      gfile.MakeDirs(self._hdfs_path)
    res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
    return gfile.GFile(res_path, 'w')

  def _write_lines(self, table_writer, outputs):
    """Write rows as text: fields joined by ``output_sep``, one row per line.

    Args:
      table_writer: writer returned by ``_get_writer``.
      outputs: iterable of rows, each an iterable of field values.
    """
    rows = [self._output_sep.join(str(v) for v in row) for row in outputs]
    table_writer.write('\n'.join(rows) + '\n')

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
    """Collect the values to write: model outputs first, then reserved columns.

    The order matches the schema built in ``load_to_table``.

    Args:
      reserved_cols: keys of the input columns to keep.
      output_cols: keys of the model outputs to keep.
      all_vals: mapping from input column key to its values.
      outputs: mapping from output key to its values.

    Returns:
      List of values, outputs followed by reserved columns.
    """
    output_vals = [outputs[x] for x in output_cols]
    reserved_vals = [all_vals[k] for k in reserved_cols]
    return output_vals + reserved_vals

  def load_to_table(self, output_path, slice_num, slice_id):
    """Load the staged prediction files into the output Hive table.

    Each worker first writes an empty ``SUCCESS-<slice_id>`` marker into the
    staging directory. Only the worker with ``slice_id == 0`` goes on: it
    waits until the markers of all ``slice_num`` workers exist, creates the
    target table if it does not exist yet and issues a ``LOAD DATA INPATH``
    statement for the staged files.

    Args:
      output_path: output table spec, see ``get_table_info``.
      slice_num: total number of workers.
      slice_id: index of the current worker.

    Raises:
      AssertionError: if a reserved column is not listed in the input columns.
    """
    marker_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
    # The context manager closes the marker file even when the write fails.
    with gfile.GFile(marker_path, 'w') as success_writer:
      success_writer.write('')

    if slice_id != 0:
      return

    # Poll every 10 seconds until every worker has published its marker.
    for worker_id in range(slice_num):
      marker_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % worker_id)
      while not gfile.Exists(marker_path):
        time.sleep(10)

    table_name, partition_name, partition_val = self.get_table_info(output_path)

    # Column definitions: model outputs first, then the reserved input columns.
    col_defs = []
    for output_col_name in self._output_cols:
      tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
      col_type = tf_utils.get_col_type(tf_type)
      col_defs.append(output_col_name + ' ' + col_type)

    for reserved_col_name in self._reserved_cols:
      assert reserved_col_name in self._all_cols, \
          'Column: %s not exists.' % reserved_col_name
      idx = self._all_cols.index(reserved_col_name)
      col_defs.append(reserved_col_name + ' ' + self._all_col_types[idx])
    schema = ','.join(col_defs)

    if partition_name and partition_val:
      sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
          (table_name, schema, partition_name)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
          (self._hdfs_path, table_name, partition_name, partition_val)
      self._hive_util.run_sql(sql)
    else:
      sql = 'create table if not exists %s (%s)' % \
          (table_name, schema)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
          (self._hdfs_path, table_name)
      self._hive_util.run_sql(sql)

  @property
  def out_of_range_exception(self):
    """Exception class that is returned as the out-of-range signal."""
    return tf.errors.OutOfRangeError
@@ -0,0 +1,70 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import tensorflow as tf
8
+
9
+ from easy_rec.python.inference.predictor import Predictor
10
+
11
+
12
class ODPSPredictor(Predictor):
  """Predictor that reads via ``TableRecordDataset`` and writes via ``common_io``."""

  def __init__(self,
               model_path,
               fg_json_path=None,
               profiling_file=None,
               all_cols='',
               all_col_types=''):
    """Initialize the predictor.

    Args:
      model_path: forwarded to ``Predictor.__init__``.
      fg_json_path: forwarded to ``Predictor.__init__``.
      profiling_file: forwarded to ``Predictor.__init__``.
      all_cols: comma separated input column names.
      all_col_types: comma separated column types, paired positionally with
        ``all_cols``.
    """
    super(ODPSPredictor, self).__init__(model_path, profiling_file,
                                        fg_json_path)
    col_names = [name.strip() for name in all_cols.split(',') if name != '']
    col_types = [
        type_name.strip()
        for type_name in all_col_types.split(',')
        if type_name != ''
    ]
    self._all_cols = col_names
    self._all_col_types = col_types

    # One default per (name, type) pair, in column order.
    defaults = []
    for col_name, col_type in zip(col_names, col_types):
      defaults.append(self._get_defaults(col_name, col_type))
    self._record_defaults = defaults

  def _get_reserved_cols(self, reserved_cols):
    """Split a comma separated column spec into a list of stripped names."""
    names = []
    for raw_name in reserved_cols.split(','):
      if raw_name != '':
        names.append(raw_name.strip())
    return names

  def _parse_line(self, *fields):
    """Map the positional field values to their column names.

    Args:
      *fields: field values in the order of the input columns.

    Returns:
      Dict from column name to field value.
    """
    return {self._all_cols[pos]: value for pos, value in enumerate(fields)}

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
    """Build a batched, prefetched dataset over the input tables.

    Args:
      input_path: comma separated table paths.
      num_parallel_calls: not used by this implementation.
      batch_size: number of records per batch.
      slice_num: total number of workers, passed as ``slice_count``.
      slice_id: index of the current worker.

    Returns:
      A ``tf.data.TableRecordDataset`` that is batched and prefetched.
    """
    selected = ','.join(self._all_cols)
    tables = input_path.split(',')
    records = tf.data.TableRecordDataset(
        tables,
        record_defaults=self._record_defaults,
        slice_id=slice_id,
        slice_count=slice_num,
        selected_cols=selected)
    return records.batch(batch_size).prefetch(buffer_size=64)

  def _get_writer(self, output_path, slice_id):
    """Create a ``common_io`` table writer for this worker's slice."""
    # Imported inside the method, so the module itself loads without common_io.
    import common_io
    return common_io.table.TableWriter(output_path, slice_id=slice_id)

  def _write_lines(self, table_writer, outputs):
    """Write all rows, addressing every column of the first row by index.

    Args:
      table_writer: writer returned by ``_get_writer``.
      outputs: non-empty sequence of rows.

    Raises:
      AssertionError: if ``outputs`` is empty.
    """
    assert len(outputs) > 0
    num_cols = len(outputs[0])
    col_indices = list(range(num_cols))
    table_writer.write(outputs, col_indices, allow_type_cast=False)

  @property
  def out_of_range_exception(self):
    """Tuple of the exception classes returned as the out-of-range signal."""
    return (tf.python_io.OutOfRangeException, tf.errors.OutOfRangeError)

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
    """Collect the values to write: reserved columns first, then outputs.

    Args:
      reserved_cols: keys of the input columns to keep.
      output_cols: keys of the model outputs to keep.
      all_vals: mapping from input column key to its values.
      outputs: mapping from output key to its values.

    Returns:
      List of values, reserved columns followed by outputs.
    """
    kept_inputs = [all_vals[key] for key in reserved_cols]
    predictions = [outputs[key] for key in output_cols]
    return kept_inputs + predictions