easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/input/datahub_input.py
@@ -0,0 +1,320 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import json
+ import logging
+ import traceback
+
+ import tensorflow as tf
+ from tensorflow.python.framework import dtypes
+
+ from easy_rec.python.input.input import Input
+ from easy_rec.python.utils import odps_util
+ from easy_rec.python.utils.config_util import parse_time
+
+ if tf.__version__.startswith('1.'):
+   from tensorflow.python.platform import gfile
+ else:
+   import tensorflow.io.gfile as gfile
+
+ try:
+   import common_io
+ except Exception:
+   common_io = None
+
+ try:
+   from datahub import DataHub
+   from datahub.exceptions import DatahubException
+   from datahub.models import RecordType
+   from datahub.models import CursorType
+   import urllib3
+   urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+   logging.getLogger('datahub.account').setLevel(logging.INFO)
+ except Exception:
+   logging.warning(
+       'DataHub is not installed[%s]. You can install it by: pip install pydatahub'
+       % traceback.format_exc())
+   DataHub = None
+
+
+ class DataHubInput(Input):
+   """DataHubInput is used for online train."""
+
+   def __init__(self,
+                data_config,
+                feature_config,
+                datahub_config,
+                task_index=0,
+                task_num=1,
+                check_mode=False,
+                pipeline_config=None):
+     super(DataHubInput,
+           self).__init__(data_config, feature_config, '', task_index, task_num,
+                          check_mode, pipeline_config)
+     if DataHub is None:
+       logging.error('please install datahub: '
+                     'pip install pydatahub; Python 3.6 recommended')
+     try:
+       self._num_epoch = 0
+       self._datahub_config = datahub_config
+       if self._datahub_config is not None:
+         akId = self._datahub_config.akId
+         akSecret = self._datahub_config.akSecret
+         endpoint = self._datahub_config.endpoint
+         if not isinstance(akId, str):
+           akId = akId.encode('utf-8')
+           akSecret = akSecret.encode('utf-8')
+           endpoint = endpoint.encode('utf-8')
+         self._datahub = DataHub(akId, akSecret, endpoint)
+       else:
+         self._datahub = None
+     except Exception as ex:
+       logging.info('exception in init datahub: %s' % str(ex))
+       pass
+     self._offset_dict = {}
+     if datahub_config:
+       shard_result = self._datahub.list_shard(self._datahub_config.project,
+                                               self._datahub_config.topic)
+       shards = shard_result.shards
+       self._all_shards = shards
+       self._shards = [
+           shards[i] for i in range(len(shards)) if (i % task_num) == task_index
+       ]
+       logging.info('all shards: %s' % str(self._shards))
+
+       offset_type = datahub_config.WhichOneof('offset')
+       if offset_type == 'offset_time':
+         ts = parse_time(datahub_config.offset_time) * 1000
+         for x in self._all_shards:
+           ks = str(x.shard_id)
+           cursor_result = self._datahub.get_cursor(self._datahub_config.project,
+                                                    self._datahub_config.topic,
+                                                    ks, CursorType.SYSTEM_TIME,
+                                                    ts)
+           logging.info('shard[%s] cursor = %s' % (ks, cursor_result))
+           self._offset_dict[ks] = cursor_result.cursor
+       elif offset_type == 'offset_info':
+         self._offset_dict = json.loads(self._datahub_config.offset_info)
+       else:
+         self._offset_dict = {}
+
+       self._dh_field_names = []
+       self._dh_field_types = []
+       topic_info = self._datahub.get_topic(
+           project_name=self._datahub_config.project,
+           topic_name=self._datahub_config.topic)
+       for field in topic_info.record_schema.field_list:
+         self._dh_field_names.append(field.name)
+         self._dh_field_types.append(field.type.value)
+
+       assert len(
+           self._feature_fields) > 0, 'data_config.feature_fields are not set.'
+
+       for x in self._feature_fields:
+         assert x in self._dh_field_names, 'feature_field[%s] is not in datahub' % x
+
+       # feature column ids in datahub schema
+       self._dh_fea_ids = [
+           self._dh_field_names.index(x) for x in self._feature_fields
+       ]
+
+       for x in self._label_fields:
+         assert x in self._dh_field_names, 'label_field[%s] is not in datahub' % x
+
+       if self._data_config.HasField('sample_weight'):
+         x = self._data_config.sample_weight
+         assert x in self._dh_field_names, 'sample_weight[%s] is not in datahub' % x
+
+       self._read_cnt = 32
+
+       if len(self._dh_fea_ids) > 1:
+         self._filter_fea_func = lambda record: ''.join(
+             [record.values[x]
+              for x in self._dh_fea_ids]).split(chr(2))[1] == '-1024'
+       else:
+         dh_fea_id = self._dh_fea_ids[0]
+         self._filter_fea_func = lambda record: record.values[dh_fea_id].split(
+             self._data_config.separator)[1] == '-1024'
+
+   def _parse_record(self, *fields):
+     field_dict = {}
+     fields = list(fields)
+
+     def _dump_offsets():
+       all_offsets = {
+           x.shard_id: self._offset_dict[x.shard_id]
+           for x in self._shards
+           if x.shard_id in self._offset_dict
+       }
+       return json.dumps(all_offsets)
+
+     field_dict[Input.DATA_OFFSET] = tf.py_func(_dump_offsets, [], dtypes.string)
+
+     for x in self._label_fields:
+       dh_id = self._dh_field_names.index(x)
+       field_dict[x] = fields[dh_id]
+
+     feature_inputs = self.get_feature_input_fields()
+     # only for features, labels and sample_weight excluded
+     record_types = [
+         t for x, t in zip(self._input_fields, self._input_field_types)
+         if x in feature_inputs
+     ]
+     feature_num = len(record_types)
+
+     feature_fields = [
+         fields[self._dh_field_names.index(x)] for x in self._feature_fields
+     ]
+     feature = feature_fields[0]
+     for fea_id in range(1, len(feature_fields)):
+       feature = feature + self._data_config.separator + feature_fields[fea_id]
+
+     feature = tf.string_split(
+         feature, self._data_config.separator, skip_empty=False)
+
+     fields = tf.reshape(feature.values, [-1, feature_num])
+
+     for fid in range(feature_num):
+       field_dict[feature_inputs[fid]] = fields[:, fid]
+     return field_dict
+
+   def _preprocess(self, field_dict):
+     output_dict = super(DataHubInput, self)._preprocess(field_dict)
+
+     # append offset fields
+     if Input.DATA_OFFSET in field_dict:
+       output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
+
+     # for _get_features to include DATA_OFFSET
+     if Input.DATA_OFFSET not in self._appended_fields:
+       self._appended_fields.append(Input.DATA_OFFSET)
+
+     return output_dict
+
+   def restore(self, checkpoint_path):
+     if checkpoint_path is None:
+       return
+
+     offset_path = checkpoint_path + '.offset'
+     if not gfile.Exists(offset_path):
+       return
+
+     logging.info('will restore datahub offset from %s' % offset_path)
+     with gfile.GFile(offset_path, 'r') as fin:
+       offset_dict = json.load(fin)
+       for k in offset_dict:
+         v = offset_dict[k]
+         ks = str(k)
+         if ks not in self._offset_dict or v > self._offset_dict[ks]:
+           self._offset_dict[ks] = v
+
+   def _is_data_empty(self, record):
+     is_empty = True
+     for fid in self._dh_fea_ids:
+       if record.values[fid] is not None and len(record.values[fid]) > 0:
+         is_empty = False
+         break
+     return is_empty
+
+   def _dump_record(self, record):
+     feas = []
+     for fid in range(len(record.values)):
+       if fid not in self._dh_fea_ids:
+         feas.append(self._dh_field_names[fid] + ':' + str(record.values[fid]))
+     return ';'.join(feas)
+
+   def _datahub_generator(self):
+     logging.info('start epoch[%d]' % self._num_epoch)
+     self._num_epoch += 1
+
+     try:
+       self._datahub.wait_shards_ready(self._datahub_config.project,
+                                       self._datahub_config.topic)
+       topic_result = self._datahub.get_topic(self._datahub_config.project,
+                                              self._datahub_config.topic)
+       if topic_result.record_type != RecordType.TUPLE:
+         logging.error('datahub topic type(%s) illegal' %
+                       str(topic_result.record_type))
+       record_schema = topic_result.record_schema
+
+       tid = 0
+       while True:
+         shard_id = self._shards[tid].shard_id
+         tid += 1
+         if tid >= len(self._shards):
+           tid = 0
+
+         if shard_id not in self._offset_dict:
+           cursor_result = self._datahub.get_cursor(self._datahub_config.project,
+                                                    self._datahub_config.topic,
+                                                    shard_id, CursorType.OLDEST)
+           cursor = cursor_result.cursor
+         else:
+           cursor = self._offset_dict[shard_id]
+
+         get_result = self._datahub.get_tuple_records(
+             self._datahub_config.project, self._datahub_config.topic, shard_id,
+             record_schema, cursor, self._read_cnt)
+         count = get_result.record_count
+         if count == 0:
+           continue
+         for row_id, record in enumerate(get_result.records):
+           if self._is_data_empty(record):
+             logging.warning('skip empty data record: %s' %
+                             self._dump_record(record))
+             continue
+           if self._filter_fea_func is not None:
+             if self._filter_fea_func(record):
+               logging.warning('filter data record: %s' %
+                               self._dump_record(record))
+               continue
+           yield tuple(list(record.values))
+         if shard_id not in self._offset_dict or get_result.next_cursor > self._offset_dict[
+             shard_id]:
+           self._offset_dict[shard_id] = get_result.next_cursor
+     except DatahubException as ex:
+       logging.error('DatahubException: %s' % str(ex))
+
+   def _build(self, mode, params):
+     if mode == tf.estimator.ModeKeys.TRAIN:
+       assert self._datahub is not None, 'datahub_train_input is not set'
+     elif mode == tf.estimator.ModeKeys.EVAL:
+       assert self._datahub is not None, 'datahub_eval_input is not set'
+
+     # get input types
+     list_types = [
+         odps_util.odps_type_2_tf_type(x) for x in self._dh_field_types
+     ]
+     list_types = tuple(list_types)
+     list_shapes = [
+         tf.TensorShape([]) for x in range(0, len(self._dh_field_types))
+     ]
+     list_shapes = tuple(list_shapes)
+     # read datahub
+     dataset = tf.data.Dataset.from_generator(
+         self._datahub_generator,
+         output_types=list_types,
+         output_shapes=list_shapes)
+     if mode == tf.estimator.ModeKeys.TRAIN:
+       if self._data_config.shuffle:
+         dataset = dataset.shuffle(
+             self._data_config.shuffle_buffer_size,
+             seed=2020,
+             reshuffle_each_iteration=True)
+
+     dataset = dataset.batch(self._data_config.batch_size)
+
+     dataset = dataset.map(
+         self._parse_record,
+         num_parallel_calls=self._data_config.num_parallel_calls)
+     # preprocess is necessary to transform data
+     # so that they could be feed into FeatureColumns
+     dataset = dataset.map(
+         map_func=self._preprocess,
+         num_parallel_calls=self._data_config.num_parallel_calls)
+     dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+     if mode != tf.estimator.ModeKeys.PREDICT:
+       dataset = dataset.map(lambda x:
+                             (self._get_features(x), self._get_labels(x)))
+     else:
+       dataset = dataset.map(lambda x: (self._get_features(x)))
+     return dataset
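The pattern above is worth calling out: DataHubInput feeds an infinite, stateful Python generator into tf.data.Dataset.from_generator and stitches the current shard offsets into every batch via tf.py_func, which is what lets restore() resume from the '.offset' file written next to a checkpoint. Below is a minimal TF1-style sketch of that offset-snapshot idea, assuming a single hypothetical shard; none of these names come from the wheel.

import json

import tensorflow as tf

offsets = {'shard_0': 0}  # hypothetical reader cursor, advanced as records are read


def read_records():
  while offsets['shard_0'] < 100:
    offsets['shard_0'] += 1
    yield 'record-%d' % offsets['shard_0']


def dump_offsets():
  # serialized into each batch, analogous to Input.DATA_OFFSET above
  return json.dumps(offsets)


dataset = tf.data.Dataset.from_generator(
    read_records, output_types=tf.string, output_shapes=tf.TensorShape([]))
dataset = dataset.batch(4)
dataset = dataset.map(lambda rec: {
    'data': rec,
    'offset': tf.py_func(dump_offsets, [], tf.string),
})

Because dump_offsets runs when a batch is produced, the saved cursor can run slightly ahead of the records the model has consumed; restore() above simply keeps the larger of the on-disk and in-memory cursor per shard.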
easy_rec/python/input/dummy_input.py
@@ -0,0 +1,58 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import tensorflow as tf
+
+ from easy_rec.python.input.input import Input
+ from easy_rec.python.utils.tf_utils import get_tf_type
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ class DummyInput(Input):
+   """Dummy memory input.
+
+   Dummy Input is used to debug the performance bottleneck of data pipeline.
+   """
+
+   def __init__(self,
+                data_config,
+                feature_config,
+                input_path,
+                task_index=0,
+                task_num=1,
+                check_mode=False,
+                pipeline_config=None,
+                input_vals={}):
+     super(DummyInput,
+           self).__init__(data_config, feature_config, input_path, task_index,
+                          task_num, check_mode, pipeline_config)
+     self._input_vals = input_vals
+
+   def _build(self, mode, params):
+     """Build fake constant input.
+
+     Args:
+       mode: tf.estimator.ModeKeys.TRAIN / tf.estimator.ModeKeys.EVAL / tf.estimator.ModeKeys.PREDICT
+       params: parameters passed by estimator, currently not used
+
+     Returns:
+       features tensor dict
+       label tensor dict
+     """
+     features = {}
+     for field, field_type, def_val in zip(self._input_fields,
+                                           self._input_field_types,
+                                           self._input_field_defaults):
+       tf_type = get_tf_type(field_type)
+       def_val = self.get_type_defaults(field_type, default_val=def_val)
+
+       if field in self._input_vals:
+         tensor = self._input_vals[field]
+       else:
+         tensor = tf.constant([def_val] * self._batch_size, dtype=tf_type)
+
+       features[field] = tensor
+     parse_dict = self._preprocess(features)
+     return self._get_features(parse_dict), self._get_labels(parse_dict)
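DummyInput's debugging trick deserves a note: every feature is a constant tensor built once in memory, so any remaining step time belongs to the model rather than to file IO or parsing. A hedged, self-contained illustration of benchmarking against constant inputs (TF1-style; the layer and feature names are illustrative, not the package's API):

import time

import tensorflow as tf

batch_size = 1024
features = {
    'user_id': tf.constant(['u0'] * batch_size),                 # no IO, no parsing
    'price': tf.constant([0.0] * batch_size, dtype=tf.float32),
}
logits = tf.layers.dense(tf.reshape(features['price'], [-1, 1]), 1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  start = time.time()
  for _ in range(100):
    sess.run(logits)
  print('100 steps on constant input: %.3fs' % (time.time() - start))

If this runs much faster than the real pipeline, the reader and parsing stages are the bottleneck; if not, the model dominates.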
easy_rec/python/input/hive_input.py
@@ -0,0 +1,123 @@
+ # -*- coding: utf-8 -*-
+ import logging
+ import os
+
+ import tensorflow as tf
+
+ from easy_rec.python.input.input import Input
+ from easy_rec.python.utils.hive_utils import HiveUtils
+
+
+ class HiveInput(Input):
+   """Common IO based interface, could run at local or on data science."""
+
+   def __init__(self,
+                data_config,
+                feature_config,
+                input_path,
+                task_index=0,
+                task_num=1,
+                check_mode=False,
+                pipeline_config=None):
+     super(HiveInput,
+           self).__init__(data_config, feature_config, input_path, task_index,
+                          task_num, check_mode, pipeline_config)
+     if input_path is None:
+       return
+     self._data_config = data_config
+     self._feature_config = feature_config
+     self._hive_config = input_path
+
+     hive_util = HiveUtils(
+         data_config=self._data_config, hive_config=self._hive_config)
+     self._input_hdfs_path = hive_util.get_table_location(
+         self._hive_config.table_name)
+     self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols(
+         self._hive_config.table_name)
+
+   def _parse_csv(self, line):
+     record_defaults = []
+     for field_name in self._input_table_col_names:
+       if field_name in self._input_fields:
+         tid = self._input_fields.index(field_name)
+         record_defaults.append(
+             self.get_type_defaults(self._input_field_types[tid],
+                                    self._input_field_defaults[tid]))
+       else:
+         record_defaults.append('')
+
+     tmp_fields = tf.decode_csv(
+         line,
+         field_delim=self._data_config.separator,
+         record_defaults=record_defaults,
+         name='decode_csv')
+
+     fields = []
+     for x in self._input_fields:
+       assert x in self._input_table_col_names, 'Column %s not in Table %s.' % (
+           x, self._hive_config.table_name)
+       fields.append(tmp_fields[self._input_table_col_names.index(x)])
+
+     # filter only valid fields
+     inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+     for x in self._label_fids:
+       inputs[self._input_fields[x]] = fields[x]
+     return inputs
+
+   def _build(self, mode, params):
+     file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
+     assert len(
+         file_paths) > 0, 'match no files with %s' % self._hive_config.table_name
+
+     num_parallel_calls = self._data_config.num_parallel_calls
+     if mode == tf.estimator.ModeKeys.TRAIN:
+       logging.info('train files[%d]: %s' %
+                    (len(file_paths), ','.join(file_paths)))
+       dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+       if self._data_config.file_shard:
+         dataset = self._safe_shard(dataset)
+
+       if self._data_config.shuffle:
+         # shuffle input files
+         dataset = dataset.shuffle(len(file_paths))
+
+       # too many readers read the same file will cause performance issues
+       # as the same data will be read multiple times
+       parallel_num = min(num_parallel_calls, len(file_paths))
+       dataset = dataset.interleave(
+           lambda x: tf.data.TextLineDataset(x),
+           cycle_length=parallel_num,
+           num_parallel_calls=parallel_num)
+
+       if not self._data_config.file_shard:
+         dataset = self._safe_shard(dataset)
+
+       if self._data_config.shuffle:
+         dataset = dataset.shuffle(
+             self._data_config.shuffle_buffer_size,
+             seed=2020,
+             reshuffle_each_iteration=True)
+       dataset = dataset.repeat(self.num_epochs)
+     else:
+       logging.info('eval files[%d]: %s' %
+                    (len(file_paths), ','.join(file_paths)))
+       dataset = tf.data.TextLineDataset(file_paths)
+       dataset = dataset.repeat(1)
+
+     dataset = dataset.batch(self._data_config.batch_size)
+     dataset = dataset.map(
+         self._parse_csv, num_parallel_calls=num_parallel_calls)
+
+     dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+     dataset = dataset.map(
+         map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
+
+     dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+     if mode != tf.estimator.ModeKeys.PREDICT:
+       dataset = dataset.map(lambda x:
+                             (self._get_features(x), self._get_labels(x)))
+     else:
+       dataset = dataset.map(lambda x: (self._get_features(x)))
+     return dataset
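HiveInput's _build shows the two standard ways to shard: split the file list across workers (file_shard) or read everything and shard the record stream (_safe_shard), with interleave fanning reads across files either way. A small sketch of the interleave pattern under assumed part-file names (not taken from the package):

import tensorflow as tf

file_paths = ['part-00000', 'part-00001', 'part-00002']  # placeholder names
parallel_num = min(8, len(file_paths))  # no more concurrent readers than files

dataset = tf.data.Dataset.from_tensor_slices(file_paths)
dataset = dataset.shuffle(len(file_paths))         # shuffle file order first
dataset = dataset.interleave(
    lambda f: tf.data.TextLineDataset(f),          # one line reader per cycled file
    cycle_length=parallel_num,
    num_parallel_calls=parallel_num)

File-level sharding avoids every worker scanning every file, but it only balances well when there are clearly more files than workers; that is why the parquet reader below falls back to record-level sharding when files are scarce.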
easy_rec/python/input/hive_parquet_input.py
@@ -0,0 +1,140 @@
+ # -*- coding: utf-8 -*-
+ import logging
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import tensorflow as tf
+
+ from easy_rec.python.input.input import Input
+ from easy_rec.python.utils.hive_utils import HiveUtils
+ from easy_rec.python.utils.tf_utils import get_tf_type
+
+
+ class HiveParquetInput(Input):
+   """Common IO based interface, could run at local or on data science."""
+
+   def __init__(self,
+                data_config,
+                feature_config,
+                input_path,
+                task_index=0,
+                task_num=1,
+                check_mode=False,
+                pipeline_config=None):
+     super(HiveParquetInput,
+           self).__init__(data_config, feature_config, input_path, task_index,
+                          task_num, check_mode, pipeline_config)
+     if input_path is None:
+       return
+     self._data_config = data_config
+     self._feature_config = feature_config
+     self._hive_config = input_path
+
+     hive_util = HiveUtils(
+         data_config=self._data_config, hive_config=self._hive_config)
+     input_hdfs_path = hive_util.get_table_location(self._hive_config.table_name)
+     self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols(
+         self._hive_config.table_name)
+     self._all_hdfs_path = tf.gfile.Glob(os.path.join(input_hdfs_path, '*'))
+
+     for x in self._input_fields:
+       assert x in self._input_table_col_names, 'Column %s not in Table %s.' % (
+           x, self._hive_config.table_name)
+
+     self._record_defaults = [
+         self.get_type_defaults(t, v)
+         for t, v in zip(self._input_field_types, self._input_field_defaults)
+     ]
+
+   def _file_shard(self, file_paths, task_num, task_index):
+     if self._data_config.chief_redundant:
+       task_num = max(task_num - 1, 1)
+       task_index = max(task_index - 1, 0)
+     task_file_paths = []
+     for idx in range(task_index, len(file_paths), task_num):
+       task_file_paths.append(file_paths[idx])
+     return task_file_paths
+
+   def _parquet_read(self):
+     for input_path in self._input_hdfs_path:
+       if input_path.endswith('SUCCESS'):
+         continue
+       df = pd.read_parquet(input_path, engine='pyarrow')
+       df = df[self._input_fields]
+       df.replace('', np.nan, inplace=True)
+       df.replace('NULL', np.nan, inplace=True)
+       total_records_num = len(df)
+
+       for k, v in zip(self._input_fields, self._record_defaults):
+         df[k].fillna(v, inplace=True)
+
+       for start_idx in range(0, total_records_num,
+                              self._data_config.batch_size):
+         end_idx = min(total_records_num,
+                       start_idx + self._data_config.batch_size)
+         batch_data = df[start_idx:end_idx]
+         inputs = []
+         for k in self._input_fields:
+           inputs.append(batch_data[k].to_numpy())
+         yield tuple(inputs)
+
+   def _parse_csv(self, *fields):
+     # filter only valid fields
+     inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
+     # filter only valid labels
+     for x in self._label_fids:
+       inputs[self._input_fields[x]] = fields[x]
+     return inputs
+
+   def _build(self, mode, params):
+     # get input type
+     list_type = [get_tf_type(x) for x in self._input_field_types]
+     list_type = tuple(list_type)
+     list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
+     list_shapes = tuple(list_shapes)
+
+     if len(self._all_hdfs_path) >= 2 * self._task_num:
+       file_shard = True
+       self._input_hdfs_path = self._file_shard(self._all_hdfs_path,
+                                                self._task_num, self._task_index)
+     else:
+       file_shard = False
+       self._input_hdfs_path = self._all_hdfs_path
+     logging.info('input path: %s' % self._input_hdfs_path)
+     assert len(self._input_hdfs_path
+                ) > 0, 'match no files with %s' % self._hive_config.table_name
+
+     dataset = tf.data.Dataset.from_generator(
+         self._parquet_read, output_types=list_type, output_shapes=list_shapes)
+
+     if not file_shard:
+       dataset = self._safe_shard(dataset)
+
+     if mode == tf.estimator.ModeKeys.TRAIN:
+       dataset = dataset.shuffle(
+           self._data_config.shuffle_buffer_size,
+           seed=2020,
+           reshuffle_each_iteration=True)
+       dataset = dataset.repeat(self.num_epochs)
+     else:
+       dataset = dataset.repeat(1)
+
+     dataset = dataset.map(
+         self._parse_csv,
+         num_parallel_calls=self._data_config.num_parallel_calls)
+
+     # preprocess is necessary to transform data
+     # so that they could be feed into FeatureColumns
+     dataset = dataset.map(
+         map_func=self._preprocess,
+         num_parallel_calls=self._data_config.num_parallel_calls)
+
+     dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+
+     if mode != tf.estimator.ModeKeys.PREDICT:
+       dataset = dataset.map(lambda x:
+                             (self._get_features(x), self._get_labels(x)))
+     else:
+       dataset = dataset.map(lambda x: (self._get_features(x)))
+     return dataset
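The _parquet_read generator is plain pandas batching and stands alone outside TensorFlow. A minimal sketch with a placeholder path and columns (assumptions for illustration, not the package's API):

import numpy as np
import pandas as pd


def parquet_batches(path, columns, defaults, batch_size):
  df = pd.read_parquet(path, engine='pyarrow')[columns]
  df.replace('', np.nan, inplace=True)
  for col, default in zip(columns, defaults):
    df[col].fillna(default, inplace=True)  # missing values -> field defaults
  for start in range(0, len(df), batch_size):
    batch = df[start:start + batch_size]
    yield tuple(batch[col].to_numpy() for col in columns)


for user_ids, prices in parquet_batches('data.parquet', ['user_id', 'price'],
                                        ['', 0.0], 1024):
  pass  # feed the numpy arrays to the consumer of your choice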