easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,101 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.input.input import Input
6
+ from easy_rec.python.utils import odps_util
7
+
8
+ try:
9
+ import pai
10
+ except Exception:
11
+ pass
12
+
13
+
14
class OdpsInput(Input):
    """Queue-based ODPS table input built on ``tf.TableRecordReader``.

    Records are read as CSV strings, decoded with ``tf.decode_csv`` and
    then passed through the preprocessing / feature / label helpers
    inherited from ``Input``.
    """

    def __init__(self,
                 data_config,
                 feature_config,
                 input_path,
                 task_index=0,
                 task_num=1,
                 check_mode=False,
                 pipeline_config=None):
        """Forward all arguments unchanged to ``Input.__init__``."""
        super(OdpsInput, self).__init__(data_config, feature_config,
                                        input_path, task_index, task_num,
                                        check_mode, pipeline_config)

    def _build(self, mode, params):
        """Build the input tensors for the given estimator mode.

        Args:
          mode: a ``tf.estimator.ModeKeys`` value.
          params: not used by this implementation.

        Returns:
          ``(features, labels)`` when ``mode`` is not PREDICT, otherwise
          only ``features``.

        Raises:
          AssertionError: if the input path list is empty.
        """
        # check data_config are consistent with odps tables
        odps_util.check_input_field_and_types(self._data_config)

        is_train = mode == tf.estimator.ModeKeys.TRAIN
        selected_cols = ','.join(self._input_fields)

        if self._data_config.chief_redundant and is_train:
            # With chief_redundant the table is sliced over task_num - 1
            # slices and task indices are shifted down by one; the max()
            # calls keep both values valid when there is a single task.
            slice_count = max(self._task_num - 1, 1)
            slice_id = max(self._task_index - 1, 0)
        else:
            slice_count = self._task_num
            slice_id = self._task_index
        reader = tf.TableRecordReader(
            csv_delimiter=self._data_config.separator,
            selected_cols=selected_cols,
            slice_count=slice_count,
            slice_id=slice_id)

        # isinstance (rather than an exact type comparison) so that list
        # subclasses are accepted as already-split paths.
        if not isinstance(self._input_path, list):
            self._input_path = self._input_path.split(',')
        assert len(
            self._input_path) > 0, 'match no files with %s' % self._input_path

        if is_train:
            if self._data_config.pai_worker_queue:
                # The module-level import of pai is best-effort and may have
                # been skipped silently; importing here surfaces the original
                # import error instead of a NameError on ``pai``.
                import pai
                work_queue = pai.data.WorkQueue(
                    self._input_path,
                    num_epochs=self.num_epochs,
                    shuffle=self._data_config.shuffle,
                    num_slices=(self._data_config.pai_worker_slice_num *
                                self._task_num))
                work_queue.add_summary()
                file_queue = work_queue.input_producer()
                # NOTE(review): this replacement reader is created without
                # csv_delimiter / selected_cols / slicing arguments, exactly
                # as in the original code - confirm that is intended.
                reader = tf.TableRecordReader()
            else:
                file_queue = tf.train.string_input_producer(
                    self._input_path,
                    num_epochs=self.num_epochs,
                    capacity=1000,
                    shuffle=self._data_config.shuffle)
        else:
            # evaluation / prediction: a single ordered pass over the input
            file_queue = tf.train.string_input_producer(
                self._input_path, num_epochs=1, capacity=1000, shuffle=False)
        _, value = reader.read_up_to(file_queue, self._batch_size)

        record_defaults = [
            self.get_type_defaults(t, v)
            for t, v in zip(self._input_field_types,
                            self._input_field_defaults)
        ]
        fields = tf.decode_csv(
            value,
            record_defaults=record_defaults,
            field_delim=self._data_config.separator,
            name='decode_csv')

        # keep only the effective feature columns plus the label columns
        inputs = {
            self._input_fields[x]: fields[x] for x in self._effective_fids
        }
        for x in self._label_fids:
            inputs[self._input_fields[x]] = fields[x]

        fields = self._preprocess(inputs)

        features = self._get_features(fields)
        if mode != tf.estimator.ModeKeys.PREDICT:
            labels = self._get_labels(fields)
            return features, labels
        return features
@@ -0,0 +1,110 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.input.input import Input
8
+ from easy_rec.python.utils import odps_util
9
+
10
+ try:
11
+ import pai
12
+ except Exception:
13
+ pass
14
+
15
+
16
class OdpsInputV2(Input):
    """ODPS table input built on ``tf.data.TableRecordDataset``."""

    def __init__(self,
                 data_config,
                 feature_config,
                 input_path,
                 task_index=0,
                 task_num=1,
                 check_mode=False,
                 pipeline_config=None):
        """Forward all arguments unchanged to ``Input.__init__``."""
        super(OdpsInputV2, self).__init__(data_config, feature_config,
                                          input_path, task_index, task_num,
                                          check_mode, pipeline_config)

    def _parse_table(self, *fields):
        """Map positional column values to a dict keyed by field name.

        Only the effective feature columns and the label columns are kept.

        Args:
          *fields: one value per selected input column, in the order of
            ``self._input_fields``.

        Returns:
          dict from input field name to the corresponding column value.
        """
        fields = list(fields)
        inputs = {
            self._input_fields[x]: fields[x] for x in self._effective_fids
        }
        for x in self._label_fids:
            inputs[self._input_fields[x]] = fields[x]
        return inputs

    def _build(self, mode, params):
        """Build the ``tf.data`` pipeline for the given estimator mode.

        Args:
          mode: a ``tf.estimator.ModeKeys`` value.
          params: not used by this implementation.

        Returns:
          A dataset yielding ``(features, labels)`` when ``mode`` is not
          PREDICT, otherwise a dataset yielding only ``features``.

        Raises:
          AssertionError: if the input path list is empty.
        """
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are accepted as already-split paths.
        if not isinstance(self._input_path, list):
            self._input_path = self._input_path.split(',')
        assert len(
            self._input_path) > 0, 'match no files with %s' % self._input_path
        # check data_config are consistent with odps tables
        odps_util.check_input_field_and_types(self._data_config)

        is_train = mode == tf.estimator.ModeKeys.TRAIN
        selected_cols = ','.join(self._input_fields)
        record_defaults = [
            self.get_type_defaults(x, v)
            for x, v in zip(self._input_field_types,
                            self._input_field_defaults)
        ]

        if self._data_config.pai_worker_queue and is_train:
            logging.info('pai_worker_slice_num = %d',
                         self._data_config.pai_worker_slice_num)
            # The module-level import of pai is best-effort and may have
            # been skipped silently; importing here surfaces the original
            # import error instead of a NameError on ``pai``.
            import pai
            work_queue = pai.data.WorkQueue(
                self._input_path,
                num_epochs=self.num_epochs,
                shuffle=self._data_config.shuffle,
                num_slices=(self._data_config.pai_worker_slice_num *
                            self._task_num))
            que_paths = work_queue.input_dataset()
            dataset = tf.data.TableRecordDataset(
                que_paths,
                record_defaults=record_defaults,
                selected_cols=selected_cols)
        else:
            if self._data_config.chief_redundant and is_train:
                # With chief_redundant the table is sliced over
                # task_num - 1 slices and task indices are shifted down by
                # one; max() keeps both valid when there is a single task.
                slice_id = max(self._task_index - 1, 0)
                slice_count = max(self._task_num - 1, 1)
            else:
                slice_id = self._task_index
                slice_count = self._task_num
            dataset = tf.data.TableRecordDataset(
                self._input_path,
                record_defaults=record_defaults,
                selected_cols=selected_cols,
                slice_id=slice_id,
                slice_count=slice_count)

        if is_train:
            if self._data_config.shuffle:
                dataset = dataset.shuffle(
                    self._data_config.shuffle_buffer_size,
                    seed=2020,
                    reshuffle_each_iteration=True)
            dataset = dataset.repeat(self.num_epochs)
        else:
            dataset = dataset.repeat(1)

        dataset = dataset.batch(batch_size=self._data_config.batch_size)

        dataset = dataset.map(
            self._parse_table,
            num_parallel_calls=self._data_config.num_parallel_calls)

        # preprocess is necessary to transform data
        # so that they could be feed into FeatureColumns
        dataset = dataset.map(
            map_func=self._preprocess,
            num_parallel_calls=self._data_config.num_parallel_calls)

        dataset = dataset.prefetch(buffer_size=self._prefetch_size)

        if mode != tf.estimator.ModeKeys.PREDICT:
            dataset = dataset.map(
                lambda x: (self._get_features(x), self._get_labels(x)))
        else:
            dataset = dataset.map(lambda x: (self._get_features(x)))
        return dataset
@@ -0,0 +1,132 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import sys
6
+
7
+ import tensorflow as tf
8
+
9
+ from easy_rec.python.input.input import Input
10
+ from easy_rec.python.utils import odps_util
11
+ from easy_rec.python.utils.tf_utils import get_tf_type
12
+
13
+ try:
14
+ import common_io
15
+ except Exception:
16
+ common_io = None
17
+
18
+
19
class OdpsInputV3(Input):
    """Common IO based interface, could run at local or on data science."""

    def __init__(self,
                 data_config,
                 feature_config,
                 input_path,
                 task_index=0,
                 task_num=1,
                 check_mode=False,
                 pipeline_config=None):
        """Initialise the base ``Input`` and verify common_io is available.

        The process exits with status 1 (after logging an error) when the
        ``common_io`` package could not be imported.
        """
        super(OdpsInputV3, self).__init__(data_config, feature_config,
                                          input_path, task_index, task_num,
                                          check_mode, pipeline_config)
        # number of passes over the input tables started so far
        self._num_epoch = 0
        if common_io is None:
            logging.error('''
        please install common_io pip install
        https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/common_io-0.4.2%2Btunnel-py2.py3-none-any.whl'''
                          )
            sys.exit(1)

    def _parse_table(self, *fields):
        """Map positional column values to a dict keyed by field name.

        Only the effective feature columns and the label columns are kept.

        Args:
          *fields: one value per selected input column, in the order of
            ``self._input_fields``.

        Returns:
          dict from input field name to the corresponding column value.
        """
        fields = list(fields)
        inputs = {
            self._input_fields[x]: fields[x] for x in self._effective_fids
        }
        for x in self._label_fids:
            inputs[self._input_fields[x]] = fields[x]
        return inputs

    def _read_batch(self, reader, batch_defaults, num_rows):
        """Read up to ``num_rows`` records and return them column-wise.

        Args:
          reader: an open table reader exposing ``read(n)``.
          batch_defaults: per-column lists pre-filled with default values.
          num_rows: number of rows to request from the reader.

        Returns:
          tuple with one list per column, each of length ``num_rows``;
          positions whose value is '', 'NULL' or None (and rows the reader
          did not return) keep the column default.
        """
        # slicing copies, so the shared default lists are never mutated
        batch = [col[:num_rows] for col in batch_defaults]
        num_cols = len(batch)
        for row_id, record in enumerate(reader.read(num_rows)):
            for col_id in range(num_cols):
                value = record[col_id]
                if value not in ['', 'NULL', None]:
                    batch[col_id][row_id] = value
        return tuple(batch)

    def _odps_read(self):
        """Generate column-wise batches from every input table.

        Each call performs one pass (epoch) over ``self._input_path``. For
        every table, full batches of ``data_config.batch_size`` rows are
        yielded first, followed by one smaller batch with the remaining
        rows, if any.

        Yields:
          tuple with one list per input field holding the batch values.
        """
        # Remember the id of this pass so that the start and finish log
        # lines report the same epoch number.
        epoch_id = self._num_epoch
        self._num_epoch += 1
        logging.info('start epoch[%d]', epoch_id)

        # isinstance (rather than an exact type comparison) so that list
        # subclasses are accepted as already-split paths.
        if not isinstance(self._input_path, list):
            self._input_path = self._input_path.split(',')
        assert len(
            self._input_path) > 0, 'match no files with %s' % self._input_path

        # check data_config are consistent with odps tables
        odps_util.check_input_field_and_types(self._data_config)

        record_defaults = [
            self.get_type_defaults(x, v)
            for x, v in zip(self._input_field_types,
                            self._input_field_defaults)
        ]
        batch_size = self._data_config.batch_size
        # identical for every table, so built once outside the loop
        batch_defaults = [[x] * batch_size for x in record_defaults]

        selected_cols = ','.join(self._input_fields)
        for table_path in self._input_path:
            reader = common_io.table.TableReader(
                table_path,
                selected_cols=selected_cols,
                slice_id=self._task_index,
                slice_count=self._task_num)
            try:
                total_records_num = reader.get_row_count()
                # integer division avoids float rounding on large tables
                batch_num = total_records_num // batch_size
                res_num = total_records_num - batch_num * batch_size
                for _ in range(batch_num):
                    yield self._read_batch(reader, batch_defaults,
                                           batch_size)
                if res_num > 0:
                    yield self._read_batch(reader, batch_defaults, res_num)
            finally:
                # also runs when reading fails or the consumer closes the
                # generator early, so the reader is always released
                reader.close()
        logging.info('finish epoch[%d]', epoch_id)

    def _build(self, mode, params):
        """Build the ``tf.data`` pipeline for the given estimator mode.

        Args:
          mode: a ``tf.estimator.ModeKeys`` value.
          params: not used by this implementation.

        Returns:
          A dataset yielding ``(features, labels)`` when ``mode`` is not
          PREDICT, otherwise a dataset yielding only ``features``.
        """
        # one 1-D tensor of unknown length per input field
        list_type = tuple(get_tf_type(x) for x in self._input_field_types)
        list_shapes = tuple(tf.TensorShape([None]) for _ in list_type)

        # read odps tables
        dataset = tf.data.Dataset.from_generator(
            self._odps_read, output_types=list_type,
            output_shapes=list_shapes)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # The generator already yields whole batches, so this shuffles
            # batches rather than individual rows.
            # NOTE(review): data_config.shuffle is not consulted here -
            # confirm that unconditional shuffling is intended.
            dataset = dataset.shuffle(
                self._data_config.shuffle_buffer_size,
                seed=2020,
                reshuffle_each_iteration=True)
            dataset = dataset.repeat(self.num_epochs)
        else:
            dataset = dataset.repeat(1)

        dataset = dataset.map(
            self._parse_table,
            num_parallel_calls=self._data_config.num_parallel_calls)

        # preprocess is necessary to transform data
        # so that they could be feed into FeatureColumns
        dataset = dataset.map(
            map_func=self._preprocess,
            num_parallel_calls=self._data_config.num_parallel_calls)

        dataset = dataset.prefetch(buffer_size=self._prefetch_size)

        if mode != tf.estimator.ModeKeys.PREDICT:
            dataset = dataset.map(
                lambda x: (self._get_features(x), self._get_labels(x)))
        else:
            dataset = dataset.map(lambda x: (self._get_features(x)))
        return dataset
@@ -0,0 +1,187 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.input.input import Input
8
+ from easy_rec.python.ops.gen_str_avx_op import str_split_by_chr
9
+ from easy_rec.python.utils.check_utils import check_split
10
+ from easy_rec.python.utils.input_utils import string_to_number
11
+
12
+ try:
13
+ import pai
14
+ except Exception:
15
+ pass
16
+
17
+
18
class OdpsRTPInput(Input):
  """RTPInput for parsing rtp fg new input format on odps.

  Our new format(csv in table) of rtp output:
     label0, item_id, ..., user_id, features
  The last column packs all generated features into one string; the
  individual features are split apart with data_config.separator.
  The features column and labels are specified by data_config.selected_cols,
  columns are selected by names in the table
  such as: clk,features, the last selected column is features, the first
  selected columns are labels
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    """Forward all arguments to Input and log the resolved field names."""
    super(OdpsRTPInput,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)
    logging.info('input_fields: %s label_fields: %s',
                 ','.join(self._input_fields), ','.join(self._label_fields))

  def _parse_table(self, *fields):
    """Split the packed feature column of one batch into named tensors.

    Args:
      *fields: the tensors of one batch, one per selected table column; the
        last one is the packed feature string column.

    Returns:
      A dict mapping each effective field name and each label field name
      to its tensor.
    """
    fields = list(fields)
    # Everything before the packed feature column: the labels first, then
    # any other selected columns.
    labels = fields[:-1]

    selected_cols = self._data_config.selected_cols \
        if self._data_config.selected_cols else None
    non_feature_cols = self._label_fields
    if selected_cols:
      cols = [c.strip() for c in selected_cols.split(',')]
      non_feature_cols = cols[:-1]
    # only for features, labels and sample_weight excluded
    record_types = [
        t for x, t in zip(self._input_fields, self._input_field_types)
        if x not in non_feature_cols
    ]
    record_defaults = [
        self.get_type_defaults(t, v)
        for x, t, v in zip(self._input_fields, self._input_field_types,
                           self._input_field_defaults)
        if x not in non_feature_cols
    ]

    feature_num = len(record_types)
    # assume that the last field is the generated feature column
    print('field_delim = %s, feature_num = %d' %
          (self._data_config.separator, feature_num))
    logging.info('field_delim = %s, input_field_name = %d',
                 self._data_config.separator, len(record_types))

    # In check mode the split is gated on check_split being evaluated on
    # the raw feature column first.
    check_list = [
        tf.py_func(
            check_split,
            [fields[-1], self._data_config.separator,
             len(record_types)],
            Tout=tf.bool)
    ] if self._check_mode else []
    with tf.control_dependencies(check_list):
      fields = str_split_by_chr(
          fields[-1], self._data_config.separator, skip_empty=False)
    # The flat split values are viewed as feature_num columns per row.
    tmp_fields = tf.reshape(fields.values, [-1, feature_num])
    # Non-label selected columns come first, parsed features follow.
    fields = labels[len(self._label_fields):]
    for i in range(feature_num):
      field = string_to_number(tmp_fields[:, i], record_types[i],
                               record_defaults[i], i)
      fields.append(field)

    field_keys = [x for x in self._input_fields if x not in self._label_fields]
    effective_fids = [field_keys.index(x) for x in self._effective_fields]
    inputs = {field_keys[x]: fields[x] for x in effective_fids}

    for x in range(len(self._label_fields)):
      inputs[self._label_fields[x]] = labels[x]
    print('effective field num = %d, input_num = %d' %
          (len(fields), len(inputs)))
    return inputs

  def _build(self, mode, params):
    """Build the tf.data pipeline reading from the odps table(s).

    Args:
      mode: a tf.estimator.ModeKeys value. TRAIN enables the optional
        shuffle, the multi-epoch repeat and the optional pai work queue;
        PREDICT yields features only.
      params: not used by this method.

    Returns:
      A tf.data.Dataset yielding (features, labels) pairs, or only the
      features when mode is PREDICT.

    Raises:
      ImportError: if data_config.pai_worker_queue is set for training
        but the pai package cannot be imported.
    """
    # A comma separated string of table paths is accepted as well as a list.
    if not isinstance(self._input_path, list):
      self._input_path = self._input_path.split(',')
    assert len(
        self._input_path) > 0, 'match no files with %s' % self._input_path

    selected_cols = self._data_config.selected_cols \
        if self._data_config.selected_cols else None
    if selected_cols:
      cols = [c.strip() for c in selected_cols.split(',')]
      # Every selected column except the last (the packed features).
      leading_cols = cols[:-1]
      record_defaults = [
          self.get_type_defaults(t, v)
          for x, t, v in zip(self._input_fields, self._input_field_types,
                             self._input_field_defaults)
          if x in leading_cols
      ]
      print('selected_cols: %s; defaults num: %d' %
            (','.join(cols), len(record_defaults)))
    else:
      record_defaults = [
          self.get_type_defaults(t, v)
          for x, t, v in zip(self._input_fields, self._input_field_types,
                             self._input_field_defaults)
          if x in self._label_fields
      ]
    # the actual features are in one single column
    record_defaults.append(
        self._data_config.separator.join([
            str(self.get_type_defaults(t, v))
            for x, t, v in zip(self._input_fields, self._input_field_types,
                               self._input_field_defaults)
            if x not in self._label_fields
        ]))

    if self._data_config.pai_worker_queue and \
        mode == tf.estimator.ModeKeys.TRAIN:
      # The file-level import of pai is best-effort; importing again here
      # surfaces the real import failure instead of a bare NameError.
      import pai
      logging.info('pai_worker_slice_num = %d',
                   self._data_config.pai_worker_slice_num)
      work_queue = pai.data.WorkQueue(
          self._input_path,
          num_epochs=self.num_epochs,
          shuffle=self._data_config.shuffle,
          num_slices=self._data_config.pai_worker_slice_num * self._task_num)
      que_paths = work_queue.input_dataset()
      dataset = tf.data.TableRecordDataset(
          que_paths,
          record_defaults=record_defaults,
          selected_cols=selected_cols)
    else:
      dataset = tf.data.TableRecordDataset(
          self._input_path,
          record_defaults=record_defaults,
          selected_cols=selected_cols,
          slice_id=self._task_index,
          slice_count=self._task_num)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if self._data_config.shuffle:
        dataset = dataset.shuffle(
            self._data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
      dataset = dataset.repeat(self.num_epochs)
    else:
      dataset = dataset.repeat(1)

    dataset = dataset.batch(batch_size=self._data_config.batch_size)

    dataset = dataset.map(
        self._parse_table,
        num_parallel_calls=self._data_config.num_parallel_calls)

    # preprocess is necessary to transform data
    # so that they could be feed into FeatureColumns
    dataset = dataset.map(
        map_func=self._preprocess,
        num_parallel_calls=self._data_config.num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset
@@ -0,0 +1,104 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+ import logging
5
+
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
9
+
10
+ if tf.__version__.startswith('1.'):
11
+ from tensorflow.python.platform import gfile
12
+ else:
13
+ import tensorflow.io.gfile as gfile
14
+ try:
15
+ import pai
16
+ import rtp_fg
17
+ except Exception:
18
+ pai = None
19
+ rtp_fg = None
20
+
21
+
22
class OdpsRTPInputV2(OdpsRTPInput):
  """RTPInput for parsing rtp fg new input format on odps.

  Our new format(csv in table) of rtp output:
     label0, item_id, ..., user_id, features
  Where features is in default RTP-tensorflow format.
  The features column and labels are specified by data_config.selected_cols,
  columns are selected by names in the table
  such as: clk,features, the last selected column is features, the first
  selected columns are labels
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               fg_json_path=None,
               pipeline_config=None):
    """Initialize the input and load the fg (feature generation) config.

    Args:
      fg_json_path: path of the fg json config; a single leading '!' is
        stripped before the file is opened. The remaining arguments are
        forwarded to OdpsRTPInput unchanged.

    Raises:
      ValueError: if fg_json_path is None.
    """
    super(OdpsRTPInputV2,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)
    # Validate before touching the value, so that a missing path is
    # reported as ValueError rather than as an AttributeError on None.
    if fg_json_path is None:
      raise ValueError('fg_json_path is not set')
    if fg_json_path.startswith('!'):
      fg_json_path = fg_json_path[1:]
    self._fg_config_path = fg_json_path
    logging.info('fg config path: {}'.format(self._fg_config_path))
    with gfile.GFile(self._fg_config_path, 'r') as f:
      self._fg_config = json.load(f)

  def _parse_table(self, *fields):
    """Parse one batch with rtp_fg and attach the label columns.

    Args:
      *fields: the tensors of one batch, one per selected table column; the
        last one is the generated feature column.

    Returns:
      A dict with the parsed features that are both non-label input fields
      and effective fields, plus one entry per label field.

    Raises:
      NotImplementedError: if rtp_fg is not installed.
    """
    self.check_rtp()

    fields = list(fields)
    labels = fields[:-1]

    # assume that the last field is the generated feature column
    features = rtp_fg.parse_genreated_fg(self._fg_config, fields[-1])

    # Build a filtered copy instead of deleting keys while iterating, which
    # raises RuntimeError on Python 3.
    wanted = set(self._effective_fields).intersection(
        x for x in self._input_fields if x not in self._label_fields)
    inputs = {key: value for key, value in features.items() if key in wanted}

    for idx, label_name in enumerate(self._label_fields):
      inputs[label_name] = labels[idx]
    return inputs

  def create_placeholders(self, *args, **kwargs):
    """Create serving placeholders with rtp_fg.

    Returns:
      A tuple of the placeholder dict (a single string placeholder named
      'features') and the preprocessed feature dict.

    Raises:
      NotImplementedError: if rtp_fg is not installed.
    """
    self.check_rtp()
    self._mode = tf.estimator.ModeKeys.PREDICT
    inputs_placeholder = tf.placeholder(tf.string, [None], name='features')
    print('[OdpsRTPInputV2] building placeholders.')
    print('[OdpsRTPInputV2] fg_config: {}'.format(self._fg_config))
    features = rtp_fg.parse_genreated_fg(self._fg_config, inputs_placeholder)
    print('[OdpsRTPInputV2] built features: {}'.format(features.keys()))
    features = self._preprocess(features)
    print('[OdpsRTPInputV2] processed features: {}'.format(features.keys()))
    return {'features': inputs_placeholder}, features['feature']

  def create_multi_placeholders(self, *args, **kwargs):
    """Create serving multi-placeholders with rtp_fg."""
    raise NotImplementedError(
        'create_multi_placeholders is not supported for OdpsRTPInputV2')

  def check_rtp(self):
    """Raise NotImplementedError when the rtp_fg module is unavailable."""
    if rtp_fg is None:
      raise NotImplementedError(
          'OdpsRTPInputV2 cannot run without rtp_fg, which is not installed')

  def _pre_build(self, mode, params):
    """Disable graph shape optimization where the graph supports it."""
    try:
      # Prevent TF from replacing the shape tensor to a constant tensor. This
      # will cause the batch size being fixed. And RTP will be not able to
      # recognize the input shape.
      tf.get_default_graph().set_shape_optimize(False)
    except AttributeError as e:
      # The message needs a placeholder for the exception; passing it as a
      # bare extra argument makes the logging module fail to format it.
      logging.warning('failed to disable shape optimization: %s', e)