easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,118 @@
1
+ from collections import deque
2
+
3
+
4
+ def _process_multi_expr(expr):
5
+ expr = expr.strip()
6
+ size = len(expr)
7
+ idx = 0
8
+ two_expr = ['>=', '<=', '==']
9
+ expr_list = []
10
+ while (idx < size):
11
+ if idx + 2 <= size and expr[idx:idx + 2] in two_expr:
12
+ expr_list.append(expr[idx:idx + 2])
13
+ idx += 2
14
+ else:
15
+ expr_list.append(expr[idx])
16
+ idx += 1
17
+ return expr_list
18
+
19
+
20
+ def _process_enum(enum, input_names, prefix=''):
21
+ enum = enum.strip()
22
+ if enum in input_names:
23
+ enum = "parsed_dict['%s']" % (prefix + enum)
24
+ return enum
25
+
26
+
27
def _get_expression_list(expression, input_names, prefix=''):
    """Tokenize an expression into comparison/logical operators and operands.

    The expression is first cut at every arithmetic, comparison, logical and
    parenthesis character. Operands that name an input are rewritten by
    ``_process_enum``. Afterwards, pieces that are not one of
    ``( ) >= <= == > < & |`` are glued back onto the preceding piece, so an
    arithmetic sub-expression such as ``a+1`` ends up as a single operand.

    Args:
      expression: the expression string.
      input_names: collection of known input names.
      prefix: prefix forwarded to ``_process_enum``.

    Returns:
      a list of non-empty token strings.
    """
    split_ops = ('+', '-', '*', '/', '(', ')', '>', '>=', '<', '<=', '==',
                 '=', '&', '|')
    pieces = []
    name_buf = ''
    op_buf = ''

    for ch in expression:
        if ch not in split_ops:
            # operand character: flush any pending operator run first
            name_buf += ch
            if op_buf:
                pieces.extend(_process_multi_expr(op_buf))
                op_buf = ''
            continue
        # operator character: flush any pending operand first
        if name_buf:
            pieces.append(_process_enum(name_buf, input_names, prefix=prefix))
            name_buf = ''
        op_buf += ch
    if name_buf:
        pieces.append(_process_enum(name_buf, input_names, prefix=prefix))
    if op_buf:
        pieces.extend(_process_multi_expr(op_buf))

    # Only these operators act as token boundaries in the final list;
    # anything else is concatenated onto the previous non-boundary token.
    boundary_ops = ('(', ')', '>=', '<=', '==', '>', '<', '&', '|')
    merged = ['']
    for piece in pieces:
        if piece not in boundary_ops and merged[-1] not in boundary_ops:
            merged[-1] += piece
        else:
            merged.append(piece)
    return [piece for piece in merged if piece]
62
+
63
+
64
+ def _solve(enum, sign, stack):
65
+ if len(stack) == 0 or enum == '' or sign == '':
66
+ return enum
67
+ op1 = stack.pop()
68
+ op2 = enum
69
+ if sign == '>':
70
+ result = 'tf.greater(%s, %s)' % (op1, op2)
71
+ elif sign == '>=':
72
+ result = 'tf.greater_equal(%s, %s)' % (op1, op2)
73
+ elif sign == '<':
74
+ result = 'tf.less(%s, %s)' % (op1, op2)
75
+ elif sign == '<=':
76
+ result = 'tf.less_equal(%s, %s)' % (op1, op2)
77
+ elif sign == '==':
78
+ result = 'tf.equal(%s, %s)' % (op1, op2)
79
+ elif sign == '&':
80
+ result = '%s & %s' % (op1, op2)
81
+ elif sign == '|':
82
+ result = '%s | %s' % (op1, op2)
83
+ else:
84
+ assert False
85
+ return result
86
+
87
+
88
def _expression_eval(expr_list):
    """Fold a token list into a single expression string.

    Tokens are processed left to right. A stack holds operands waiting for
    their right-hand side, plus the operator that was pending when a '('
    was opened, so that a parenthesized group can be combined with what
    preceded it once the matching ')' is reached.

    Args:
      expr_list: tokens as produced by ``_get_expression_list``.

    Returns:
      the combined expression text.
    """
    operators = ('>', '>=', '<', '<=', '==', '&', '|', '(', ')')
    pending = deque()
    cur_sign = ''
    cur_operand = ''
    for token in expr_list:
        if token == ' ':
            continue
        if token not in operators:
            cur_operand = token
            continue
        if token == '(':
            # remember the operator in effect before the group starts
            pending.append(cur_sign)
            cur_sign = ''
            continue
        partial = _solve(cur_operand, cur_sign, pending)
        if token == ')':
            # restore the operator saved at '(' and apply it to the group
            cur_sign = pending.pop()
            cur_operand = _solve(partial, cur_sign, pending)
            cur_sign = ''
        else:
            cur_operand = ''
            cur_sign = token
            pending.append(partial)
    return _solve(cur_operand, cur_sign, pending)
113
+
114
+
115
def get_expression(expression, input_names, prefix=''):
    """Translate a filter expression into expression text over ``parsed_dict``.

    Args:
      expression: the expression string.
      input_names: collection of known input names.
      prefix: prefix added to input names inside ``parsed_dict`` keys.

    Returns:
      the expression text built by tokenizing and then folding the tokens.
    """
    tokens = _get_expression_list(expression, input_names, prefix=prefix)
    return _expression_eval(tokens)
@@ -0,0 +1,53 @@
1
+ import json
2
+ import logging
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
7
+ from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
8
+ from easy_rec.python.utils.config_util import get_compatible_feature_configs
9
+
10
+ from easy_rec.python.utils.convert_rtp_fg import load_input_field_and_feature_config # NOQA
11
+
12
+ if tf.__version__ >= '2.0':
13
+ tf = tf.compat.v1
14
+
15
+
16
def load_fg_json_to_config(pipeline_config):
    """Replace input fields and feature configs with those from fg json.

    Args:
      pipeline_config: pipeline config whose ``fg_json_path`` names the
        feature-generation json file.

    Returns:
      None when ``fg_json_path`` is empty or already starts with '!'
      (the marker written by a previous call); otherwise the updated
      ``pipeline_config``.
    """
    json_path = pipeline_config.fg_json_path
    # empty path: nothing to load; leading '!': already loaded
    if not json_path or json_path.startswith('!'):
        return

    label_fields = pipeline_config.data_config.label_fields
    with tf.gfile.GFile(json_path, 'r') as fin:
        rtp_fg = json.load(fin)
    fg_config = load_input_field_and_feature_config(
        rtp_fg, label_fields=label_fields)

    pipeline_config.data_config.ClearField('input_fields')
    pipeline_config.ClearField('feature_configs')
    # feature_config.features is deliberately NOT cleared, so that extra
    # features which are not defined in fg.json can still be configured.

    for src_field in fg_config.data_config.input_fields:
        field_copy = DatasetConfig.Field()
        field_copy.CopyFrom(src_field)
        pipeline_config.data_config.input_fields.append(field_copy)

    for src_feature in get_compatible_feature_configs(fg_config):
        feature_copy = FeatureConfig()
        feature_copy.CopyFrom(src_feature)
        pipeline_config.feature_config.features.append(feature_copy)
    logging.info('data_config and feature_config has been replaced by fg_json.')

    # signal that it is already loaded
    pipeline_config.fg_json_path = '!' + json_path
    return pipeline_config
@@ -0,0 +1,220 @@
1
+ import logging
2
+
3
+ import graphlearn as gl
4
+ import numpy as np
5
+ import tensorflow as tf
6
+
7
+ if tf.__version__ >= '2.0':
8
+ tf = tf.compat.v1
9
+
10
+
11
def load_graph(i_emb_table, emb_dim, knn_metric, timeout, knn_strict):
    """Declare the item embedding table as a GL node source with a knn index.

    The resulting graph is used to look up embeddings and run knn search.

    Args:
      i_emb_table: item embedding table passed to ``Graph.node``.
      emb_dim: number of comma-delimited float attributes per node.
      knn_metric: value passed to ``gl.set_knn_metric``.
      timeout: value passed to ``gl.set_timeout``.
      knn_strict: when true, probe 5 lists during search instead of 2.

    Returns:
      the graph object returned by ``gl.Graph().node(...)``.
    """
    gl.set_knn_metric(knn_metric)
    gl.set_timeout(timeout)

    index_option = gl.IndexOption()
    index_option.name = 'knn'
    # Both modes use an 'ivfflat' index with 5 lists; only nprobe differs.
    # (A "flat" index was tried for the strict mode and left disabled.)
    index_option.index_type = 'ivfflat'
    index_option.nlist = 5
    index_option.nprobe = 5 if knn_strict else 2

    item_decoder = gl.Decoder(
        attr_types=['float'] * emb_dim, attr_delimiter=',')
    return gl.Graph().node(
        i_emb_table, node_type='i', decoder=item_decoder, option=index_option)
35
+
36
+
37
def batch_hitrate(src_ids,
                  recall_ids,
                  recall_distances,
                  gt_items,
                  num_interests,
                  mask=None):
    """Compute hitrate of a batch of src ids.

    Args:
      src_ids: trigger id, a numpy array.
      recall_ids: recalled ids by src_ids, a numpy array indexed as
        [src, interest, k].
      recall_distances: corresponding distances of recalled ids, a numpy
        array indexed like recall_ids.
      gt_items: batch of ground truth item ids list, a list of list.
      num_interests: max number of interests.
      mask: optional array indexed as [src, interest]; a falsy entry means
        that interest is skipped. None means every interest is used.

    Returns:
      hitrates: hitrate of src_ids, a list. Records without ground truth
        items are skipped, so this may be shorter than src_ids.
      bad_cases: recalled ids that are not ground truth, a list of list.
      bad_dists: smallest distance seen for each bad case, a list of list.
      hits: total hit counts of a batch of src ids, a scalar.
      gt_count: total ground truth items num of a batch of src ids, a scalar.
    """
    hitrates = []
    bad_cases = []
    bad_dists = []
    hits = 0.0
    gt_count = 0.0
    for idx, src_id in enumerate(src_ids):
        gt_items_size = len(gt_items[idx])
        if gt_items_size == 0:  # just skip invalid record.
            print('Id {:d} has no related items sequence, just skip.'.format(src_id))
            continue

        # set lookup instead of scanning the ground truth list per recalled id
        gt_set = set(gt_items[idx])
        hit_ids = set()
        bad_case = {}
        for interest_id in range(num_interests):
            # mask defaults to None: treat that as "all interests valid"
            # instead of failing on None[idx, interest_id].
            if mask is not None and not mask[idx, interest_id]:
                continue
            interest_dists = recall_distances[idx][interest_id]
            for k, item_id in enumerate(recall_ids[idx][interest_id]):
                if item_id in gt_set:
                    hit_ids.add(item_id)
                    continue
                # keep the smallest distance at which a bad case was recalled
                dis = interest_dists[k]
                if item_id not in bad_case or dis < bad_case[item_id]:
                    bad_case[item_id] = dis

        hit_count = float(len(hit_ids))
        hitrates.append(hit_count / gt_items_size)
        hits += hit_count
        gt_count += gt_items_size
        bad_cases.append(list(bad_case))
        bad_dists.append([bad_case[x] for x in bad_case])
    return hitrates, bad_cases, bad_dists, hits, gt_count
95
+
96
+
97
def reduce_hitrate(cluster, hits, count, task_index):
    """Reduce hitrate of all workers.

    Each worker adds its own hit count and ground truth count to shared
    accumulator variables, and the overall hitrate is recomputed from the
    accumulated totals.

    Args:
      cluster: tf cluster.
      hits: total_hits of each worker.
      count: total count of ground truth items of each worker.
      task_index: worker index.

    Returns:
      var_total_hitrate: output of the op assigning the total hitrate.
      var_worker_count: output of the op incrementing the number of workers
        that have completed the calculation of hitrate.
    """
    # NOTE(review): the block nesting below was reconstructed from a listing
    # that had lost its indentation -- confirm that the assign ops are meant
    # to be created inside the tf.device scope.
    with tf.device(
        tf.train.replica_device_setter(
            worker_device='/job:worker/task:%d' % task_index, cluster=cluster)):
        # AUTO_REUSE: get_variable returns the existing variable if the
        # scope already defines one with this name.
        with tf.variable_scope('hitrate_var', reuse=tf.AUTO_REUSE):
            var_worker_count = tf.get_variable(
                'worker_count',
                shape=(),
                dtype=tf.int32,
                initializer=tf.zeros_initializer())
            var_hits = tf.get_variable(
                'hits',
                shape=(),
                dtype=tf.float32,
                initializer=tf.zeros_initializer())
            var_gt_count = tf.get_variable(
                'gt_count',
                shape=(),
                dtype=tf.float32,
                initializer=tf.zeros_initializer())
            # The variable name is spelled 'total_hitate' (sic); it is part
            # of the graph, so it is left untouched here.
            var_total_hitrate = tf.get_variable(
                'total_hitate',
                shape=(),
                dtype=tf.float32,
                initializer=tf.zeros_initializer())

        # From here on the names refer to the outputs of the update ops,
        # not to the variables themselves.
        var_hits = tf.assign_add(var_hits, hits, use_locking=True)
        var_gt_count = tf.assign_add(var_gt_count, count, use_locking=True)
        # log the accumulated totals whenever var_gt_count is evaluated
        var_gt_count = tf.Print(
            var_gt_count, [var_gt_count, var_hits],
            message='var_gt_count/var_hits')
        var_total_hitrate = tf.assign(
            var_total_hitrate, var_hits / var_gt_count, use_locking=True)
        # bump the worker counter only after the total hitrate was written
        with tf.control_dependencies([var_total_hitrate]):
            var_worker_count = tf.assign_add(var_worker_count, 1, use_locking=True)
    return var_total_hitrate, var_worker_count
146
+
147
+
148
def compute_hitrate_batch(g, gt_record, emb_dim, num_interests, top_k):
    """Compute hitrate of one batch.

    Args:
      g: a GL Graph instance.
      gt_record: record list of ground truth. Each record is indexed as:
        [0] src id, [1] comma-separated ground truth item ids,
        [2] user embeddings ('|' between interests, ',' between floats;
        may be empty), [3] number of valid user embeddings.
      emb_dim: embedding dim.
      num_interests: max number of interests.
      top_k: top_k hitrate.

    Returns:
      hits: total hit counts of a batch of src ids, a scalar.
      gt_count: total ground truth items num of a batch of src ids, a scalar.
      src_ids: src ids, a numpy array.
      recall_ids: recall ids, reshaped to [-1, num_interests, k].
      recall_distances: recall distances, reshaped like recall_ids.
      hitrates: hitrate of a batch of src_ids, a list.
      bad_cases: bad cases, a list of list.
      bad_dists: distances of bad cases, a list of list.
    """

    def _to_float_attrs(x):
        # Parse one comma-separated embedding string into a float32 vector.
        # incase user embedding is not present
        if x == '':
            return np.zeros([emb_dim], dtype=np.float32)
        embed = np.array(x.split(','), dtype=np.float32)
        assert len(embed) == emb_dim, 'invalid embed len=%d, x=%s' % (len(embed), x)
        return embed

    def _to_multi_float_attrs(x, userid):
        # Parse '|'-separated embeddings; an empty string yields
        # num_interests zero vectors. userid is only used in the message.
        if x == '':
            arr = [_to_float_attrs(x) for i in range(num_interests)]
        else:
            arr = [_to_float_attrs(sub_x) for sub_x in x.split('|')]
        assert len(arr) == num_interests, 'invalid arr len=%d, x=%s, userid=%s' % (
            len(arr), x, userid)
        return arr

    src_ids = np.array([src_items[0] for src_items in gt_record])
    user_embedding = np.array([
        _to_multi_float_attrs(src_items[2], src_items[0])
        for src_items in gt_record
    ])
    user_emb_num = [src_items[3] for src_items in gt_record]

    print('max(user_emb_num) = %d len(src_ids) = %d' %
          (np.max(user_emb_num), len(src_ids)))

    # a list of list.
    gt_items = [
        list(map(int, src_items[1].split(','))) for src_items in gt_record
    ]

    logging.info('src_nodes.float_attrs.shape=%s' % str(user_embedding.shape))
    # flatten the interest axis so that every interest vector is one query
    user_embedding = user_embedding.reshape([-1, user_embedding.shape[-1]])
    # numpy array
    recall_ids, recall_distances = g.search('i', user_embedding,
                                            gl.KnnOption(k=top_k))
    logging.info('recall_ids.shape=%s' % str(recall_ids.shape))

    def _make_mask(lens):
        # mask[i, j] is 1 for the first lens[i] interests of record i, else 0
        mask = np.ones([len(lens), num_interests], dtype=np.float32)
        for tmp_id, tmp_len in enumerate(lens):
            mask[tmp_id, int(tmp_len):] = 0
        return mask

    mask = _make_mask(user_emb_num)
    # restore the interest axis that was flattened before the search
    recall_ids = recall_ids.reshape([-1, num_interests, recall_ids.shape[-1]])
    recall_distances = recall_distances.reshape(
        [-1, num_interests, recall_distances.shape[-1]])
    hitrates, bad_cases, bad_dists, hits, gt_count = batch_hitrate(
        src_ids, recall_ids, recall_distances, gt_items, num_interests, mask)
    return hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists
@@ -0,0 +1,183 @@
1
+ # -*- coding: utf-8 -*-
2
+ import logging
3
+
4
+ try:
5
+ from pyhive import hive
6
+ from pyhive.exc import ProgrammingError
7
+ except ImportError:
8
+ logging.warning('pyhive is not installed.')
9
+
10
+
11
class TableInfo(object):
  """Description of a table read that can render itself as a SELECT statement.

  Attributes (stored exactly as passed to the constructor):
    tablename: name of the table to select from.
    selected_cols: text placed verbatim between `select` and `from`.
    partition_kv: mapping of partition column -> value; None or empty means
      no `where` clause.
    limit_num: row limit; None or a non-positive number means no `limit`.
  """

  def __init__(self, tablename, selected_cols, partition_kv, limit_num):
    self.tablename = tablename
    self.selected_cols = selected_cols
    self.partition_kv = partition_kv
    self.limit_num = limit_num

  def gen_sql(self):
    """Return the SELECT statement described by this object.

    Partition predicates are combined with `and`, so any number of partition
    columns produces a valid `where` clause.

    Names and values are interpolated verbatim, without quoting or escaping;
    they must come from trusted configuration and already be valid SQL
    fragments.

    Returns:
      The SQL text as a str.
    """
    sql = 'select {}\nfrom {}'.format(self.selected_cols, self.tablename)

    if self.partition_kv:
      # Joining with a bare space (as done previously) yields invalid SQL as
      # soon as there is more than one partition predicate.
      predicates = ' and '.join(
          '{}={}'.format(k, v) for k, v in self.partition_kv.items())
      sql += '\nwhere {}\n'.format(predicates)

    if self.limit_num is not None and self.limit_num > 0:
      sql += ' limit {}'.format(self.limit_num)
    return sql
36
+
37
+
38
class HiveUtils(object):
  """Common IO based interface, could run at local or on data science.

  Every public method opens its own Hive connection through
  `_construct_hive_connect` and closes both cursor and connection when it is
  done, including when an error is raised.
  """

  def __init__(self,
               data_config,
               hive_config,
               selected_cols='',
               record_defaults=None,
               task_index=0,
               task_num=1):
    """Store the configuration used by the read helpers.

    Args:
      data_config: object exposing `num_epochs`.
      hive_config: object exposing `host`, `port`, `username` and `database`
        (read when a connection is created).
      selected_cols: text placed verbatim in the select list of generated SQL.
      record_defaults: per-column default values; defaults to a new empty
        list per instance.
      task_index: index of this worker.
      task_num: total number of workers.
    """
    self._data_config = data_config
    self._hive_config = hive_config

    self._num_epoch = data_config.num_epochs
    self._num_epoch_record = 0
    self._task_index = task_index
    self._task_num = task_num
    self._selected_cols = selected_cols
    # A fresh list per instance avoids sharing one mutable default object
    # between all instances.
    self._record_defaults = [] if record_defaults is None else record_defaults

  def _construct_table_info(self, table_name, limit_num):
    """Split `table/k1=v1/k2=v2` into a TableInfo.

    Args:
      table_name: table name optionally followed by '/'-separated `key=value`
        partition segments, e.g. sample_table/dt=2014-11-23/name=a
      limit_num: row limit forwarded to TableInfo.

    Returns:
      A TableInfo whose partition_kv is a (possibly empty) dict.
    """
    segs = table_name.split('/')
    # Split on the first '=' only so that values may themselves contain '='.
    pairs = [seg.split('=', 1) for seg in segs[1:]]
    partition_kv = {pair[0]: pair[1] for pair in pairs}
    return TableInfo(segs[0].strip(), self._selected_cols, partition_kv,
                     limit_num)

  def _construct_hive_connect(self):
    """Open a new connection using the stored hive_config."""
    return hive.Connection(
        host=self._hive_config.host,
        port=self._hive_config.port,
        username=self._hive_config.username,
        database=self._hive_config.database)

  def _execute_fetch(self, sql, empty_on_no_result=False):
    """Run `sql`, return every fetched row, and always release resources.

    Args:
      sql: statement to execute.
      empty_on_no_result: when True, a ProgrammingError raised while fetching
        is translated into an empty list. Errors raised by `execute` itself
        always propagate.

    Returns:
      The rows returned by `cursor.fetchall()`.
    """
    conn = self._construct_hive_connect()
    try:
      cursor = conn.cursor()
      try:
        cursor.execute(sql)
        if not empty_on_no_result:
          return cursor.fetchall()
        try:
          return cursor.fetchall()
        except ProgrammingError:
          return []
      finally:
        cursor.close()
    finally:
      conn.close()

  def _iter_batches(self, input_path, batch_size, limit_num):
    """Yield non-empty `fetchmany(size=batch_size)` results for input_path.

    The cursor and connection are closed when the data is exhausted, when the
    generator is closed or garbage collected early, and when an error occurs.
    """
    table_info = self._construct_table_info(input_path, limit_num)
    conn = self._construct_hive_connect()
    try:
      cursor = conn.cursor()
      try:
        cursor.execute(table_info.gen_sql())
        while True:
          data = cursor.fetchmany(size=batch_size)
          if len(data) == 0:
            break
          yield data
      finally:
        cursor.close()
    finally:
      conn.close()

  def hive_read_line(self, input_path, limit_num=None):
    """Iterate over the table one row at a time.

    Args:
      input_path: `table[/key=value...]` path.
      limit_num: optional row limit.

    Returns:
      A generator; each item is the result of `fetchmany(size=1)`, i.e. a
      one-row batch rather than a bare row.
    """
    return self._iter_batches(input_path, 1, limit_num)

  def hive_read_lines(self, input_path, batch_size, limit_num=None):
    """Iterate over the table in batches of up to `batch_size` rows.

    Args:
      input_path: `table[/key=value...]` path.
      batch_size: maximum number of rows per yielded batch.
      limit_num: optional row limit.

    Returns:
      A generator of non-empty row batches.
    """
    return self._iter_batches(input_path, batch_size, limit_num)

  def run_sql(self, sql):
    """Execute `sql` and return all rows.

    Returns:
      The fetched rows, or [] when fetching raises ProgrammingError.
    """
    return self._execute_fetch(sql, empty_on_no_result=True)

  def is_table_or_partition_exist(self,
                                  table_name,
                                  partition_name=None,
                                  partition_val=None):
    """Check whether a table, or one partition of it, exists.

    With both `partition_name` and `partition_val` given, the partition is
    looked up with `show partitions`; otherwise the table is probed with
    `desc`. Any error raised by the query is treated as "does not exist".

    Returns:
      True or False.
    """
    if partition_name and partition_val:
      sql = 'show partitions %s partition(%s=%s)' % (table_name, partition_name,
                                                     partition_val)
      try:
        return bool(self.run_sql(sql))
      except Exception:  # best effort: a failing query means "not found"
        return False

    try:
      self.run_sql('desc %s' % table_name)
    except Exception:  # best effort: a failing query means "not found"
      return False
    return True

  def get_table_location(self, input_path):
    """Return the storage location of a table or partition path.

    Args:
      input_path: `table[/key=value...]` path; any number of partition
        segments is accepted.

    Returns:
      `<location>/<partition segments>/` built from the `Location` row of
      `desc formatted`, or None when there is no such row.
    """
    segs = input_path.split('/')
    table_name = segs[0]
    # Every partition segment becomes one sub-directory with a trailing '/'.
    partition = ''.join(seg + '/' for seg in segs[1:])

    rows = self._execute_fetch('desc formatted %s' % table_name)
    for line in rows:
      # The first field is guarded because it may be empty/None on some rows.
      if line[0] and line[0].startswith('Location'):
        return line[1].strip() + '/' + partition
    return None

  def get_all_cols(self, input_path):
    """Return names and types of the table's non-partition columns.

    Rows of `desc <table>` whose first field is empty or starts with '#' are
    skipped, repeated names are kept once, and columns named as partition
    keys in `input_path` are left out.

    Args:
      input_path: `table[/key=value...]` path.

    Returns:
      A tuple (col_names, cols_types) of two parallel lists.
    """
    segs = input_path.split('/')
    pt_names = set(seg.split('=')[0] for seg in segs[1:])
    rows = self._execute_fetch('desc %s' % segs[0])

    col_names = []
    cols_types = []
    seen = set()
    for col in rows:
      col_name = (col[0] or '').strip()
      if not col_name or col_name.startswith('#') or col_name in seen:
        continue
      if col_name in pt_names:
        continue
      seen.add(col_name)
      col_names.append(col_name)
      cols_types.append(col[1].strip())

    return col_names, cols_types