easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,147 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import logging
8
+ import os
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import tensorflow as tf
13
+ from tensorflow.python.platform import gfile
14
+
15
+ from easy_rec.python.inference.predictor import Predictor
16
+ from easy_rec.python.input.parquet_input import ParquetInput
17
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
18
+ from easy_rec.python.utils import config_util
19
+ from easy_rec.python.utils import input_utils
20
+
21
# Loading the ops in libload_embed.so is best-effort: any failure is only
# logged so that importing this module still succeeds.
try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  # logging formats the argument with %s, yielding the same message text
  logging.warning('load libload_embed.so failed: %s', ex)
28
+
29
+
30
class ParquetPredictor(Predictor):
  """Predictor that reads its input through ParquetInput.

  Results are written as text files named part-<slice_id>.csv, with the
  columns of every row joined by the configured output separator.
  """

  def __init__(self,
               model_path,
               data_config,
               ds_vector_recall=False,
               fg_json_path=None,
               profiling_file=None,
               selected_cols=None,
               output_sep=chr(1),
               pipeline_config=None):
    """Store the predictor settings.

    Args:
      model_path: forwarded to Predictor.__init__.
      data_config: dataset config; when the name of its input_type contains
        'rtp' the rtp_separator field is used as input separator, otherwise
        the separator field.
      ds_vector_recall: when true, selected_cols is kept as a list of strings
        instead of being converted to integers.
      fg_json_path: forwarded to Predictor.__init__.
      profiling_file: forwarded to Predictor.__init__.
      selected_cols: comma separated column list; must not be None when
        ds_vector_recall is true, because it is split unconditionally then.
      output_sep: separator placed between the columns of the output file.
      pipeline_config: pipeline config used to build the parquet input.
    """
    super(ParquetPredictor, self).__init__(model_path, profiling_file,
                                           fg_json_path)
    self._output_sep = output_sep
    self._ds_vector_recall = ds_vector_recall
    self.pipeline_config = pipeline_config

    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
    self._is_rtp = 'rtp' in input_type
    if self._is_rtp:
      self._input_sep = data_config.rtp_separator
    else:
      self._input_sep = data_config.separator

    if selected_cols and not ds_vector_recall:
      self._selected_cols = [int(x) for x in selected_cols.split(',')]
    elif ds_vector_recall:
      self._selected_cols = selected_cols.split(',')
    else:
      self._selected_cols = None

  def _parse_line(self, line):
    """Flatten the 'feature' entries of line into a new dict.

    The 'reserve' entry, when present, is copied as is and stays nested under
    the 'reserve' key.
    """
    features = line['feature']
    out_dict = {key: features[key] for key in features}
    if 'reserve' in line:
      out_dict['reserve'] = line['reserve']
    return out_dict

  def _get_reserved_cols(self, reserved_cols):
    """Return the reserved columns resolved by _get_dataset.

    The reserved_cols argument is ignored: the columns were already parsed in
    _get_dataset.
    """
    return self._reserved_cols

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
    """Build the prediction dataset from parquet files.

    Reads self._reserved_args, which is not assigned in this class and is
    expected to be set before this method is called.

    Args:
      input_path: comma separated file patterns; the first file matching the
        first pattern is inspected to resolve the reserved column types.
      num_parallel_calls: not used by this implementation.
      batch_size: written into pipeline_config.data_config.batch_size.
      slice_num: passed to ParquetInput as task_num.
      slice_id: passed to ParquetInput as task_index.

    Returns:
      The result of ParquetInput._build for the PREDICT mode.
    """
    feature_configs = config_util.get_compatible_feature_configs(
        self.pipeline_config)

    kwargs = {}
    if self._reserved_args:
      parquet_file = gfile.Glob(input_path.split(',')[0])[0]
      if self._reserved_args == 'ALL_COLUMNS':
        # gfile not supported, read_parquet requires random access
        self._reserved_cols = list(pd.read_parquet(parquet_file).columns)
        self._all_fields = self._reserved_cols
      else:
        self._reserved_cols = [
            x.strip() for x in self._reserved_args.split(',') if x.strip()
        ]
      kwargs['reserve_fields'] = self._reserved_cols
      kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
          self._reserved_cols, parquet_file)
      logging.info('reserve_fields=%s reserve_types=%s',
                   ','.join(self._reserved_cols),
                   ','.join([str(x) for x in kwargs['reserve_types']]))
    else:
      self._reserved_cols = []
    self.pipeline_config.data_config.batch_size = batch_size

    kwargs['is_predictor'] = True
    parquet_input = ParquetInput(
        self.pipeline_config.data_config,
        feature_configs,
        input_path,
        task_index=slice_id,
        task_num=slice_num,
        pipeline_config=self.pipeline_config,
        **kwargs)
    return parquet_input._build(tf.estimator.ModeKeys.PREDICT, {})

  def _get_writer(self, output_path, slice_id):
    """Open part-<slice_id>.csv under output_path and write the header line.

    The header joins self._output_cols (not assigned in this class) and the
    reserved columns with the output separator. The directory is created when
    it does not exist. The returned writer is left open for the caller.
    """
    if not gfile.Exists(output_path):
      gfile.MakeDirs(output_path)
    res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
    table_writer = gfile.GFile(res_path, 'w')
    header = self._output_sep.join(self._output_cols + self._reserved_cols)
    table_writer.write(header + '\n')
    return table_writer

  def _write_lines(self, table_writer, outputs):
    """Write one line per row of outputs, columns joined by the separator."""
    rows = [self._output_sep.join([str(i) for i in row]) for row in outputs]
    table_writer.write('\n'.join(rows) + '\n')

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
    """Collect the model outputs followed by the reserved column values.

    Args:
      reserved_cols: names looked up in all_vals['reserve'].
      output_cols: not used by this implementation.
      all_vals: mapping holding the reserved values under the 'reserve' key.
      outputs: mapping of model outputs; its values come first in the result,
        in the mapping's iteration order.

    Returns:
      A list with every value of outputs followed by one entry per reserved
      column. Columns with object dtype are returned as lists whose bytes
      elements are decoded as utf-8.
    """
    reserve_vals = [outputs[name] for name in outputs]
    for col in reserved_cols:
      col_val = all_vals['reserve'][col]
      # The np.object alias was removed in NumPy 1.24; the builtin object
      # compares equal to the object dtype on every NumPy version.
      if col_val.dtype == object:
        # only bytes need decoding, other elements are passed through
        col_val = [
            v.decode('utf-8') if isinstance(v, bytes) else v for v in col_val
        ]
      reserve_vals.append(col_val)
    return reserve_vals

  @property
  def out_of_range_exception(self):
    """Exception class signalling that the input dataset is exhausted."""
    return tf.errors.OutOfRangeError
@@ -0,0 +1,147 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import logging
8
+ import os
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import tensorflow as tf
13
+ from tensorflow.python.platform import gfile
14
+
15
+ from easy_rec.python.inference.predictor import Predictor
16
+ from easy_rec.python.input.parquet_input_v2 import ParquetInputV2
17
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
18
+ from easy_rec.python.utils import config_util
19
+ from easy_rec.python.utils import input_utils
20
+
21
# Loading the ops in libload_embed.so is best-effort: any failure is only
# logged so that importing this module still succeeds.
try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  # logging formats the argument with %s, yielding the same message text
  logging.warning('load libload_embed.so failed: %s', ex)
28
+
29
+
30
class ParquetPredictorV2(Predictor):
  """Predictor that reads its input through ParquetInputV2.

  Results are written as text files named part-<slice_id>.csv, with the
  columns of every row joined by the configured output separator.
  """

  def __init__(self,
               model_path,
               data_config,
               ds_vector_recall=False,
               fg_json_path=None,
               profiling_file=None,
               selected_cols=None,
               output_sep=chr(1),
               pipeline_config=None):
    """Store the predictor settings.

    Args:
      model_path: forwarded to Predictor.__init__.
      data_config: dataset config; when the name of its input_type contains
        'rtp' the rtp_separator field is used as input separator, otherwise
        the separator field.
      ds_vector_recall: when true, selected_cols is kept as a list of strings
        instead of being converted to integers.
      fg_json_path: forwarded to Predictor.__init__.
      profiling_file: forwarded to Predictor.__init__.
      selected_cols: comma separated column list; must not be None when
        ds_vector_recall is true, because it is split unconditionally then.
      output_sep: separator placed between the columns of the output file.
      pipeline_config: pipeline config used to build the parquet input.
    """
    super(ParquetPredictorV2, self).__init__(model_path, profiling_file,
                                             fg_json_path)
    self._output_sep = output_sep
    self._ds_vector_recall = ds_vector_recall
    self.pipeline_config = pipeline_config

    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
    self._is_rtp = 'rtp' in input_type
    if self._is_rtp:
      self._input_sep = data_config.rtp_separator
    else:
      self._input_sep = data_config.separator

    if selected_cols and not ds_vector_recall:
      self._selected_cols = [int(x) for x in selected_cols.split(',')]
    elif ds_vector_recall:
      self._selected_cols = selected_cols.split(',')
    else:
      self._selected_cols = None

  def _parse_line(self, line):
    """Flatten the 'feature' entries of line into a new dict.

    The 'reserve' entry, when present, is copied as is and stays nested under
    the 'reserve' key.
    """
    features = line['feature']
    out_dict = {key: features[key] for key in features}
    if 'reserve' in line:
      out_dict['reserve'] = line['reserve']
    return out_dict

  def _get_reserved_cols(self, reserved_cols):
    """Return the reserved columns resolved by _get_dataset.

    The reserved_cols argument is ignored: the columns were already parsed in
    _get_dataset.
    """
    return self._reserved_cols

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
    """Build the prediction dataset from parquet files.

    Reads self._reserved_args, which is not assigned in this class and is
    expected to be set before this method is called.

    Args:
      input_path: comma separated file patterns; the first file matching the
        first pattern is inspected to resolve the reserved column types.
      num_parallel_calls: not used by this implementation.
      batch_size: written into pipeline_config.data_config.batch_size.
      slice_num: passed to ParquetInputV2 as task_num.
      slice_id: passed to ParquetInputV2 as task_index.

    Returns:
      The result of ParquetInputV2._build for the PREDICT mode.
    """
    feature_configs = config_util.get_compatible_feature_configs(
        self.pipeline_config)

    kwargs = {}
    if self._reserved_args:
      parquet_file = gfile.Glob(input_path.split(',')[0])[0]
      if self._reserved_args == 'ALL_COLUMNS':
        # gfile not supported, read_parquet requires random access
        self._reserved_cols = list(pd.read_parquet(parquet_file).columns)
        self._all_fields = self._reserved_cols
      else:
        self._reserved_cols = [
            x.strip() for x in self._reserved_args.split(',') if x.strip()
        ]
      kwargs['reserve_fields'] = self._reserved_cols
      kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file(
          self._reserved_cols, parquet_file)
      logging.info('reserve_fields=%s reserve_types=%s',
                   ','.join(self._reserved_cols),
                   ','.join([str(x) for x in kwargs['reserve_types']]))
    else:
      self._reserved_cols = []
    self.pipeline_config.data_config.batch_size = batch_size

    kwargs['is_predictor'] = True
    parquet_input = ParquetInputV2(
        self.pipeline_config.data_config,
        feature_configs,
        input_path,
        task_index=slice_id,
        task_num=slice_num,
        pipeline_config=self.pipeline_config,
        **kwargs)
    return parquet_input._build(tf.estimator.ModeKeys.PREDICT, {})

  def _get_writer(self, output_path, slice_id):
    """Open part-<slice_id>.csv under output_path and write the header line.

    The header joins self._output_cols (not assigned in this class) and the
    reserved columns with the output separator. The directory is created when
    it does not exist. The returned writer is left open for the caller.
    """
    if not gfile.Exists(output_path):
      gfile.MakeDirs(output_path)
    res_path = os.path.join(output_path, 'part-%d.csv' % slice_id)
    table_writer = gfile.GFile(res_path, 'w')
    header = self._output_sep.join(self._output_cols + self._reserved_cols)
    table_writer.write(header + '\n')
    return table_writer

  def _write_lines(self, table_writer, outputs):
    """Write one line per row of outputs, columns joined by the separator."""
    rows = [self._output_sep.join([str(i) for i in row]) for row in outputs]
    table_writer.write('\n'.join(rows) + '\n')

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
    """Collect the model outputs followed by the reserved column values.

    Args:
      reserved_cols: names looked up in all_vals['reserve'].
      output_cols: not used by this implementation.
      all_vals: mapping holding the reserved values under the 'reserve' key.
      outputs: mapping of model outputs; its values come first in the result,
        in the mapping's iteration order.

    Returns:
      A list with every value of outputs followed by one entry per reserved
      column. Columns with object dtype are returned as lists whose bytes
      elements are decoded as utf-8.
    """
    reserve_vals = [outputs[name] for name in outputs]
    for col in reserved_cols:
      col_val = all_vals['reserve'][col]
      # The np.object alias was removed in NumPy 1.24; the builtin object
      # compares equal to the object dtype on every NumPy version.
      if col_val.dtype == object:
        # only bytes need decoding, other elements are passed through
        col_val = [
            v.decode('utf-8') if isinstance(v, bytes) else v for v in col_val
        ]
      reserve_vals.append(col_val)
    return reserve_vals

  @property
  def out_of_range_exception(self):
    """Exception class signalling that the input dataset is exhausted."""
    return tf.errors.OutOfRangeError