easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of easy-cs-rec-custommodel has been flagged as potentially problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/core/sampler.py (new file)
@@ -0,0 +1,844 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ from __future__ import division
+ from __future__ import print_function
+
+ import json
+ import logging
+ import math
+ import os
+ import sys
+ import threading
+
+ import numpy as np
+ import six
+ import tensorflow as tf
+
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+ from easy_rec.python.utils import ds_util
+ from easy_rec.python.utils.config_util import process_multi_file_input_path
+ from easy_rec.python.utils.tf_utils import get_tf_type
+
+ if tf.__version__.startswith('1.'):
+   from tensorflow.python.platform import gfile
+ else:
+   import tensorflow.io.gfile as gfile
+
+
+ # patch graph-learn string_attrs for utf-8
+ @property
+ def string_attrs(self):  # NOQA
+   self._init()
+   return self._string_attrs
+
+
+ # pyre-ignore [56]
+ @string_attrs.setter
+ # pyre-ignore [2, 3]
+ def string_attrs(self, string_attrs):  # NOQA
+   self._string_attrs = self._reshape(string_attrs, expand_shape=True)
+   self._inited = True
+
+
+ try:
+   import graphlearn as gl
+   from graphlearn.python.data.values import Values
+   Values.string_attrs = string_attrs
+ except Exception:
+   logging.info(
+       'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
+   )
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def _get_gl_type(field_type):
+   type_map = {
+       DatasetConfig.INT32: 'int',
+       DatasetConfig.INT64: 'int',
+       DatasetConfig.STRING: 'string',
+       DatasetConfig.BOOL: 'int',
+       DatasetConfig.FLOAT: 'float',
+       DatasetConfig.DOUBLE: 'float'
+   }
+   assert field_type in type_map, 'invalid type: %s' % field_type
+   return type_map[field_type]
+
+
+ def _get_np_type(field_type):
+   type_map = {
+       DatasetConfig.INT32: np.int32,
+       DatasetConfig.INT64: np.int64,
+       DatasetConfig.STRING: str,
+       DatasetConfig.BOOL: bool,
+       DatasetConfig.FLOAT: np.float32,
+       DatasetConfig.DOUBLE: np.double
+   }
+   assert field_type in type_map, 'invalid type: %s' % field_type
+   return type_map[field_type]
+
+
+ class BaseSampler(object):
+   _instance_lock = threading.Lock()
+
+   def __init__(self, fields, num_sample, num_eval_sample=None):
+     self._g = None
+     self._sampler = None
+     self._num_sample = num_sample
+     self._num_eval_sample = num_eval_sample if num_eval_sample is not None else num_sample
+     self._build_field_types(fields)
+     self._log_first_n = 5
+     self._is_on_ds = ds_util.is_on_ds()
+
+   def set_eval_num_sample(self):
+     print('set_eval_num_sample: %d %d' %
+           (self._num_sample, self._num_eval_sample))
+     self._num_sample = self._num_eval_sample
+
+   def _init_graph(self):
+     if 'TF_CONFIG' in os.environ:
+       tf_config = json.loads(os.environ['TF_CONFIG'])
+       if 'ps' in tf_config['cluster']:
+         # ps mode
+         if 'worker' in tf_config['cluster']:
+           task_count = len(tf_config['cluster']['worker']) + 2
+         else:
+           task_count = 2
+         if self._is_on_ds:
+           gl.set_tracker_mode(0)
+           server_hosts = [
+               host.split(':')[0] + ':888' + str(i)
+               for i, host in enumerate(tf_config['cluster']['ps'])
+           ]
+           cluster = {
+               'server': ','.join(server_hosts),
+               'client_count': task_count
+           }
+         else:
+           ps_count = len(tf_config['cluster']['ps'])
+           cluster = {'server_count': ps_count, 'client_count': task_count}
+         if tf_config['task']['type'] in ['chief', 'master']:
+           self._g.init(cluster=cluster, job_name='client', task_index=0)
+         elif tf_config['task']['type'] == 'worker':
+           self._g.init(
+               cluster=cluster,
+               job_name='client',
+               task_index=tf_config['task']['index'] + 2)
+         # TODO(hongsheng.jhs): check cluster has evaluator or not?
+         elif tf_config['task']['type'] == 'evaluator':
+           self._g.init(
+               cluster=cluster,
+               job_name='client',
+               task_index=tf_config['task']['index'] + 1)
+           if self._num_eval_sample is not None and self._num_eval_sample > 0:
+             self._num_sample = self._num_eval_sample
+         elif tf_config['task']['type'] == 'ps':
+           self._g.init(
+               cluster=cluster,
+               job_name='server',
+               task_index=tf_config['task']['index'])
+       else:
+         # worker mode
+         task_count = len(tf_config['cluster']['worker']) + 1
+         if not self._is_on_ds:
+           if tf_config['task']['type'] in ['chief', 'master']:
+             self._g.init(task_index=0, task_count=task_count)
+           elif tf_config['task']['type'] == 'worker':
+             self._g.init(
+                 task_index=tf_config['task']['index'] + 1,
+                 task_count=task_count)
+         else:
+           gl.set_tracker_mode(0)
+           if tf_config['cluster'].get('chief', ''):
+             chief_host = tf_config['cluster']['chief'][0].split(
+                 ':')[0] + ':8880'
+           else:
+             chief_host = tf_config['cluster']['master'][0].split(
+                 ':')[0] + ':8880'
+           worker_hosts = [chief_host] + [
+               host.split(':')[0] + ':888' + str(i)
+               for i, host in enumerate(tf_config['cluster']['worker'])
+           ]
+
+           if tf_config['task']['type'] in ['chief', 'master']:
+             self._g.init(
+                 task_index=0,
+                 task_count=task_count,
+                 hosts=','.join(worker_hosts))
+           elif tf_config['task']['type'] == 'worker':
+             self._g.init(
+                 task_index=tf_config['task']['index'] + 1,
+                 task_count=task_count,
+                 hosts=','.join(worker_hosts))
+
+         # TODO(hongsheng.jhs): check cluster has evaluator or not?
+     else:
+       # local mode
+       self._g.init()
+
+   def _build_field_types(self, fields):
+     self._attr_names = []
+     self._attr_types = []
+     self._attr_gl_types = []
+     self._attr_np_types = []
+     self._attr_tf_types = []
+     for i, field in enumerate(fields):
+       self._attr_names.append(field.input_name)
+       self._attr_types.append(field.input_type)
+       self._attr_gl_types.append(_get_gl_type(field.input_type))
+       self._attr_np_types.append(_get_np_type(field.input_type))
+       self._attr_tf_types.append(get_tf_type(field.input_type))
+
+   @classmethod
+   def instance(cls, *args, **kwargs):
+     with cls._instance_lock:
+       if not hasattr(cls, '_instance'):
+         cls._instance = cls(*args, **kwargs)
+     return cls._instance
+
+   def __del__(self):
+     if self._g is not None:
+       self._g.close()
+
+   def _parse_nodes(self, nodes):
+     if self._log_first_n > 0:
+       logging.info('num_example=%d num_eval_example=%d node_num=%d' %
+                    (self._num_sample, self._num_eval_sample, len(nodes.ids)))
+       self._log_first_n -= 1
+     features = []
+     int_idx = 0
+     float_idx = 0
+     string_idx = 0
+     for attr_gl_type, attr_np_type in zip(self._attr_gl_types,
+                                           self._attr_np_types):
+       if attr_gl_type == 'int':
+         feature = nodes.int_attrs[:, :, int_idx]
+         int_idx += 1
+       elif attr_gl_type == 'float':
+         feature = nodes.float_attrs[:, :, float_idx]
+         float_idx += 1
+       elif attr_gl_type == 'string':
+         feature = nodes.string_attrs[:, :, string_idx]
+         if int(sys.version_info[0]) == 3:
+           feature = np.char.decode(feature.astype(np.string_), 'utf-8')
+         string_idx += 1
+       else:
+         raise ValueError('Unknown attr type %s' % attr_gl_type)
+       feature = np.reshape(feature,
+                            [-1])[:self._num_sample].astype(attr_np_type)
+       if attr_gl_type == 'string':
+         feature = feature.tolist()
+       features.append(feature)
+     return features
+
+   def _parse_sparse_nodes(self, nodes):
+     features = []
+     int_idx = 0
+     float_idx = 0
+     string_idx = 0
+     for attr_gl_type, attr_np_type in zip(self._attr_gl_types,
+                                           self._attr_np_types):
+       if attr_gl_type == 'int':
+         feature = nodes.int_attrs[:, int_idx]
+         int_idx += 1
+       elif attr_gl_type == 'float':
+         feature = nodes.float_attrs[:, float_idx]
+         float_idx += 1
+       elif attr_gl_type == 'string':
+         feature = nodes.string_attrs[:, string_idx]
+         string_idx += 1
+       else:
+         raise ValueError('Unknown attr type %s' % attr_gl_type)
+       feature = feature.astype(attr_np_type)
+       if attr_gl_type == 'string':
+         feature = feature.tolist()
+       features.append(feature)
+     return features, nodes.indices
+
+
+ class NegativeSampler(BaseSampler):
+   """Negative Sampler.
+
+   Weighted random sampling of items not in batch.
+
+   Args:
+     data_path: item feature data path. id:int64 | weight:float | attrs:string.
+     fields: item input fields.
+     num_sample: number of negative samples.
+     batch_size: mini-batch size.
+     attr_delimiter: delimiter of feature string.
+     num_eval_sample: number of negative samples for evaluator.
+   """
+
+   def __init__(self,
+                data_path,
+                fields,
+                num_sample,
+                batch_size,
+                attr_delimiter=':',
+                num_eval_sample=None):
+     super(NegativeSampler, self).__init__(fields, num_sample, num_eval_sample)
+     self._batch_size = batch_size
+     self._g = gl.Graph().node(
+         tf.compat.as_str(data_path),
+         node_type='item',
+         decoder=gl.Decoder(
+             attr_types=self._attr_gl_types,
+             weighted=True,
+             attr_delimiter=attr_delimiter))
+     self._init_graph()
+
+     expand_factor = int(math.ceil(self._num_sample / batch_size))
+     self._sampler = self._g.negative_sampler(
+         'item', expand_factor, strategy='node_weight')
+
+   def _get_impl(self, ids):
+     ids = np.array(ids, dtype=np.int64)
+     ids = np.pad(ids, (0, self._batch_size - len(ids)), 'edge')
+     nodes = self._sampler.get(ids)
+     features = self._parse_nodes(nodes)
+     return features
+
+   def get(self, ids):
+     """Sampling method.
+
+     Args:
+       ids: item id tensor.
+
+     Returns:
+       Negative sampled feature dict.
+     """
+     sampled_values = tf.py_func(self._get_impl, [ids], self._attr_tf_types)
+     result_dict = {}
+     for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
+       v.set_shape([self._num_sample])
+       result_dict[k] = v
+     return result_dict
+
+
+ class NegativeSamplerInMemory(BaseSampler):
+   """Negative Sampler (in memory).
+
+   Loads all item features into memory and samples items not in batch.
+
+   Args:
+     data_path: item feature data path. id:int64 | weight:float | attrs:string.
+     fields: item input fields.
+     num_sample: number of negative samples.
+     batch_size: mini-batch size.
+     attr_delimiter: delimiter of feature string.
+     num_eval_sample: number of negative samples for evaluator.
+   """
+
+   def __init__(self,
+                data_path,
+                fields,
+                num_sample,
+                batch_size,
+                attr_delimiter=':',
+                num_eval_sample=None):
+     super(NegativeSamplerInMemory, self).__init__(fields, num_sample,
+                                                   num_eval_sample)
+     self._batch_size = batch_size
+
+     self._item_ids = []
+     self._cols = [[] for x in fields]
+
+     if six.PY2 and isinstance(attr_delimiter, type(u'')):
+       attr_delimiter = attr_delimiter.encode('utf-8')
+     if data_path.startswith('odps://'):
+       self._load_table(data_path, attr_delimiter)
+     else:
+       self._load_data(data_path, attr_delimiter)
+
+     print('NegativeSamplerInMemory: total_row_num = %d' % len(self._cols[0]))
+     for col_id in range(len(self._attr_np_types)):
+       np_type = self._attr_np_types[col_id]
+       print('\tcol_id[%d], dtype=%s' % (col_id, self._attr_gl_types[col_id]))
+       if np_type != str:
+         self._cols[col_id] = np.array(self._cols[col_id], dtype=np_type)
+       else:
+         self._cols[col_id] = np.asarray(
+             self._cols[col_id], order='C', dtype=object)
+
+   def _load_table(self, data_path, attr_delimiter):
+     import common_io
+     reader = common_io.table.TableReader(data_path)
+     schema = reader.get_schema()
+     item_id_col = 0
+     fea_id_col = 2
+     for tid in range(len(schema)):
+       if schema[tid][0].startswith('feature'):
+         fea_id_col = tid
+         break
+     for tid in range(len(schema)):
+       if schema[tid][0].startswith('id'):
+         item_id_col = tid
+         break
+     print('NegativeSamplerInMemory: feature_id_col = %d, item_id_col = %d' %
+           (fea_id_col, item_id_col))
+     while True:
+       try:
+         row_arr = reader.read(num_records=1024, allow_smaller_final_batch=True)
+         for row in row_arr:
+           # item_id, weight, feature
+           self._item_ids.append(int(row[item_id_col]))
+           col_vals = row[fea_id_col].split(attr_delimiter)
+           assert len(col_vals) == len(
+               self._cols), 'invalid row[%d %d]: %s %s' % (len(
+                   col_vals), len(self._cols), row[item_id_col], row[fea_id_col])
+           for col_id in range(len(col_vals)):
+             self._cols[col_id].append(col_vals[col_id])
+       except common_io.exception.OutOfRangeException:
+         reader.close()
+         break
+
+   def _load_data(self, data_path, attr_delimiter):
+     item_id_col = 0
+     fea_id_col = 2
+     print('NegativeSamplerInMemory: load sample feature from %s' % data_path)
+     with gfile.GFile(data_path, 'r') as fin:
+       for line_id, line_str in enumerate(fin):
+         line_str = line_str.strip()
+         cols = line_str.split('\t')
+         if line_id == 0:
+           schema = [x.split(':') for x in cols]
+           for tid in range(len(schema)):
+             if schema[tid][0].startswith('id'):
+               item_id_col = tid
+             if schema[tid][0].startswith('feature'):
+               fea_id_col = tid
+           print('feature_id_col = %d, item_id_col = %d' %
+                 (fea_id_col, item_id_col))
+         else:
+           self._item_ids.append(int(cols[item_id_col]))
+           fea_vals = cols[fea_id_col].split(attr_delimiter)
+           assert len(fea_vals) == len(
+               self._cols), 'invalid row[%d][%d %d]:%s %s' % (
+                   line_id, len(fea_vals), len(
+                       self._cols), cols[item_id_col], cols[fea_id_col])
+           for col_id in range(len(fea_vals)):
+             self._cols[col_id].append(fea_vals[col_id])
+
+   def _get_impl(self, ids):
+     if type(ids[0]) != int:
+       ids = [int(x) for x in ids]
+     assert self._num_sample > 0, 'invalid num_sample: %d' % self._num_sample
+
+     indices = np.random.choice(
+         len(self._item_ids),
+         size=self._num_sample + self._batch_size,
+         replace=False)
+
+     sel_ids = []
+     for tid in indices:
+       rid = self._item_ids[tid]
+       if rid not in ids:
+         sel_ids.append(tid)
+       if len(sel_ids) >= self._num_sample and self._num_sample > 0:
+         break
+
+     features = []
+     for col_id in range(len(self._cols)):
+       tmp_col = self._cols[col_id]
+       np_type = self._attr_np_types[col_id]
+       if np_type != str:
+         sel_feas = tmp_col[sel_ids]
+         features.append(sel_feas)
+       else:
+         features.append(
+             np.asarray([tmp_col[x] for x in sel_ids], order='C', dtype=object))
+     return features
+
+   def get(self, ids):
+     """Sampling method.
+
+     Args:
+       ids: item id tensor.
+
+     Returns:
+       Negative sampled feature dict.
+     """
+     all_attr_types = list(self._attr_tf_types)
+     if self._num_sample <= 0:
+       all_attr_types.append(tf.float32)
+     sampled_values = tf.py_func(self._get_impl, [ids], all_attr_types)
+     result_dict = {}
+     for k, v in zip(self._attr_names, sampled_values):
+       result_dict[k] = v
+     return result_dict
+
+
+ class NegativeSamplerV2(BaseSampler):
+   """Negative Sampler V2.
+
+   Weighted random sampling of items that do not have a positive edge with the user.
+
+   Args:
+     user_data_path: user node data path. id:int64 | weight:float.
+     item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
+     edge_data_path: positive edge data path. userid:int64 | itemid:int64 | weight:float.
+     fields: item input fields.
+     num_sample: number of negative samples.
+     batch_size: mini-batch size.
+     attr_delimiter: delimiter of feature string.
+     num_eval_sample: number of negative samples for evaluator.
+   """
+
+   def __init__(self,
+                user_data_path,
+                item_data_path,
+                edge_data_path,
+                fields,
+                num_sample,
+                batch_size,
+                attr_delimiter=':',
+                num_eval_sample=None):
+     super(NegativeSamplerV2, self).__init__(fields, num_sample, num_eval_sample)
+     self._batch_size = batch_size
+     self._g = gl.Graph() \
+         .node(tf.compat.as_str(user_data_path),
+               node_type='user',
+               decoder=gl.Decoder(weighted=True)) \
+         .node(tf.compat.as_str(item_data_path),
+               node_type='item',
+               decoder=gl.Decoder(
+                   attr_types=self._attr_gl_types,
+                   weighted=True,
+                   attr_delimiter=attr_delimiter)) \
+         .edge(tf.compat.as_str(edge_data_path),
+               edge_type=('user', 'item', 'edge'),
+               decoder=gl.Decoder(weighted=True))
+     self._init_graph()
+
+     expand_factor = int(math.ceil(self._num_sample / batch_size))
+     self._sampler = self._g.negative_sampler(
+         'edge', expand_factor, strategy='random', conditional=True)
+
+   def _get_impl(self, src_ids, dst_ids):
+     src_ids = np.array(src_ids, dtype=np.int64)
+     src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), 'edge')
+     dst_ids = np.array(dst_ids, dtype=np.int64)
+     dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
+     nodes = self._sampler.get(src_ids, dst_ids)
+     features = self._parse_nodes(nodes)
+     return features
+
+   def get(self, src_ids, dst_ids):
+     """Sampling method.
+
+     Args:
+       src_ids: user id tensor.
+       dst_ids: item id tensor.
+
+     Returns:
+       Negative sampled feature dict.
+     """
+     sampled_values = tf.py_func(self._get_impl, [src_ids, dst_ids],
+                                 self._attr_tf_types)
+     result_dict = {}
+     for k, t, v in zip(self._attr_names, self._attr_tf_types, sampled_values):
+       v.set_shape([self._num_sample])
+       result_dict[k] = v
+     return result_dict
+
+
+ class HardNegativeSampler(BaseSampler):
+   """HardNegativeSampler.
+
+   Weighted random sampling of items not in batch as negative samples, plus the
+   destination nodes of hard_neg_edge as hard negative samples.
+
+   Args:
+     user_data_path: user node data path. id:int64 | weight:float.
+     item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
+     hard_neg_edge_data_path: hard negative edge data path. userid:int64 | itemid:int64 | weight:float.
+     fields: item input fields.
+     num_sample: number of negative samples.
+     num_hard_sample: maximum number of hard negative samples.
+     batch_size: mini-batch size.
+     attr_delimiter: delimiter of feature string.
+     num_eval_sample: number of negative samples for evaluator.
+   """
+
+   def __init__(self,
+                user_data_path,
+                item_data_path,
+                hard_neg_edge_data_path,
+                fields,
+                num_sample,
+                num_hard_sample,
+                batch_size,
+                attr_delimiter=':',
+                num_eval_sample=None):
+     super(HardNegativeSampler, self).__init__(fields, num_sample,
+                                               num_eval_sample)
+     self._batch_size = batch_size
+     self._g = gl.Graph() \
+         .node(tf.compat.as_str(user_data_path),
+               node_type='user',
+               decoder=gl.Decoder(weighted=True)) \
+         .node(tf.compat.as_str(item_data_path),
+               node_type='item',
+               decoder=gl.Decoder(
+                   attr_types=self._attr_gl_types,
+                   weighted=True,
+                   attr_delimiter=attr_delimiter)) \
+         .edge(tf.compat.as_str(hard_neg_edge_data_path),
+               edge_type=('user', 'item', 'hard_neg_edge'),
+               decoder=gl.Decoder(weighted=True))
+     self._init_graph()
+
+     expand_factor = int(math.ceil(self._num_sample / batch_size))
+     self._neg_sampler = self._g.negative_sampler(
+         'item', expand_factor, strategy='node_weight')
+     self._hard_neg_sampler = self._g.neighbor_sampler(['hard_neg_edge'],
+                                                       num_hard_sample,
+                                                       strategy='full')
+
+   def _get_impl(self, src_ids, dst_ids):
+     src_ids = np.array(src_ids, dtype=np.int64)
+     dst_ids = np.array(dst_ids, dtype=np.int64)
+     dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
+     nodes = self._neg_sampler.get(dst_ids)
+     neg_features = self._parse_nodes(nodes)
+     sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
+     hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)
+
+     results = []
+     for i, v in enumerate(hard_neg_features):
+       if type(v) == list:
+         results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
+       else:
+         results.append(np.concatenate([neg_features[i], v], axis=0))
+     results.append(hard_neg_indices)
+     return results
+
+   def get(self, src_ids, dst_ids):
+     """Sampling method.
+
+     Args:
+       src_ids: user id tensor.
+       dst_ids: item id tensor.
+
+     Returns:
+       Sampled feature dict. The first num_sample entries of each tensor are the
+       negative samples; the remainder are the hard negative samples.
+     """
+     output_types = self._attr_tf_types + [tf.int64]
+     output_values = tf.py_func(self._get_impl, [src_ids, dst_ids], output_types)
+     result_dict = {}
+     for k, t, v in zip(self._attr_names, self._attr_tf_types,
+                        output_values[:-1]):
+       v.set_shape([None])
+       result_dict[k] = v
+
+     hard_neg_indices = output_values[-1]
+     hard_neg_indices.set_shape([None, 2])
+     result_dict['hard_neg_indices'] = hard_neg_indices
+     return result_dict
+
+
+ class HardNegativeSamplerV2(BaseSampler):
+   """HardNegativeSampler V2.
+
+   Weighted random sampling of items that do not have a positive edge with the
+   user, plus the destination nodes of hard_neg_edge as hard negative samples.
+
+   Args:
+     user_data_path: user node data path. id:int64 | weight:float.
+     item_data_path: item feature data path. id:int64 | weight:float | attrs:string.
+     edge_data_path: positive edge data path. userid:int64 | itemid:int64 | weight:float.
+     hard_neg_edge_data_path: hard negative edge data path. userid:int64 | itemid:int64 | weight:float.
+     fields: item input fields.
+     num_sample: number of negative samples.
+     num_hard_sample: maximum number of hard negative samples.
+     batch_size: mini-batch size.
+     attr_delimiter: delimiter of feature string.
+     num_eval_sample: number of negative samples for evaluator.
+   """
+
+   def __init__(self,
+                user_data_path,
+                item_data_path,
+                edge_data_path,
+                hard_neg_edge_data_path,
+                fields,
+                num_sample,
+                num_hard_sample,
+                batch_size,
+                attr_delimiter=':',
+                num_eval_sample=None):
+     super(HardNegativeSamplerV2, self).__init__(fields, num_sample,
+                                                 num_eval_sample)
+     self._batch_size = batch_size
+     self._g = gl.Graph() \
+         .node(tf.compat.as_str(user_data_path),
+               node_type='user',
+               decoder=gl.Decoder(weighted=True)) \
+         .node(tf.compat.as_str(item_data_path),
+               node_type='item',
+               decoder=gl.Decoder(
+                   attr_types=self._attr_gl_types,
+                   weighted=True,
+                   attr_delimiter=attr_delimiter)) \
+         .edge(tf.compat.as_str(edge_data_path),
+               edge_type=('user', 'item', 'edge'),
+               decoder=gl.Decoder(weighted=True)) \
+         .edge(tf.compat.as_str(hard_neg_edge_data_path),
+               edge_type=('user', 'item', 'hard_neg_edge'),
+               decoder=gl.Decoder(weighted=True))
+     self._init_graph()
+
+     expand_factor = int(math.ceil(self._num_sample / batch_size))
+     self._neg_sampler = self._g.negative_sampler(
+         'edge', expand_factor, strategy='random', conditional=True)
+     self._hard_neg_sampler = self._g.neighbor_sampler(['hard_neg_edge'],
+                                                       num_hard_sample,
+                                                       strategy='full')
+
+   def _get_impl(self, src_ids, dst_ids):
+     src_ids = np.array(src_ids, dtype=np.int64)
+     src_ids_padded = np.pad(src_ids, (0, self._batch_size - len(src_ids)),
+                             'edge')
+     dst_ids = np.array(dst_ids, dtype=np.int64)
+     dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), 'edge')
+     nodes = self._neg_sampler.get(src_ids_padded, dst_ids)
+     neg_features = self._parse_nodes(nodes)
+     sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
+     hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)
+
+     results = []
+     for i, v in enumerate(hard_neg_features):
+       if type(v) == list:
+         results.append(np.asarray(neg_features[i] + v, order='C', dtype=object))
+       else:
+         results.append(np.concatenate([neg_features[i], v], axis=0))
+     results.append(hard_neg_indices)
+     return results
+
+   def get(self, src_ids, dst_ids):
+     """Sampling method.
+
+     Args:
+       src_ids: user id tensor.
+       dst_ids: item id tensor.
+
+     Returns:
+       Sampled feature dict. The first num_sample entries of each tensor are the
+       negative samples; the remainder are the hard negative samples.
+     """
+     output_types = self._attr_tf_types + [tf.int64]
+     output_values = tf.py_func(self._get_impl, [src_ids, dst_ids], output_types)
+     result_dict = {}
+     for k, t, v in zip(self._attr_names, self._attr_tf_types,
+                        output_values[:-1]):
+       v.set_shape([None])
+       result_dict[k] = v
+
+     hard_neg_indices = output_values[-1]
+     hard_neg_indices.set_shape([None, 2])
+     result_dict['hard_neg_indices'] = hard_neg_indices
+     return result_dict
+
+
+ def build(data_config):
+   if not data_config.HasField('sampler'):
+     return None
+   sampler_type = data_config.WhichOneof('sampler')
+   print('sampler_type = %s' % sampler_type)
+   sampler_config = getattr(data_config, sampler_type)
+
+   if ds_util.is_on_ds():
+     gl.set_field_delimiter(sampler_config.field_delimiter)
+
+   if sampler_type == 'negative_sampler':
+     input_fields = {f.input_name: f for f in data_config.input_fields}
+     attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+     input_path = process_multi_file_input_path(sampler_config.input_path)
+     return NegativeSampler.instance(
+         data_path=input_path,
+         fields=attr_fields,
+         num_sample=sampler_config.num_sample,
+         batch_size=data_config.batch_size,
+         attr_delimiter=sampler_config.attr_delimiter,
+         num_eval_sample=sampler_config.num_eval_sample)
+   elif sampler_type == 'negative_sampler_in_memory':
+     input_fields = {f.input_name: f for f in data_config.input_fields}
+     attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+     input_path = process_multi_file_input_path(sampler_config.input_path)
+     return NegativeSamplerInMemory.instance(
+         data_path=input_path,
+         fields=attr_fields,
+         num_sample=sampler_config.num_sample,
+         batch_size=data_config.batch_size,
+         attr_delimiter=sampler_config.attr_delimiter,
+         num_eval_sample=sampler_config.num_eval_sample)
+   elif sampler_type == 'negative_sampler_v2':
+     input_fields = {f.input_name: f for f in data_config.input_fields}
+     attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+     user_input_path = process_multi_file_input_path(
+         sampler_config.user_input_path)
+     item_input_path = process_multi_file_input_path(
+         sampler_config.item_input_path)
+     pos_edge_input_path = process_multi_file_input_path(
+         sampler_config.pos_edge_input_path)
+     return NegativeSamplerV2.instance(
+         user_data_path=user_input_path,
+         item_data_path=item_input_path,
+         edge_data_path=pos_edge_input_path,
+         fields=attr_fields,
+         num_sample=sampler_config.num_sample,
+         batch_size=data_config.batch_size,
+         attr_delimiter=sampler_config.attr_delimiter,
+         num_eval_sample=sampler_config.num_eval_sample)
+   elif sampler_type == 'hard_negative_sampler':
+     input_fields = {f.input_name: f for f in data_config.input_fields}
+     attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+     user_input_path = process_multi_file_input_path(
+         sampler_config.user_input_path)
+     item_input_path = process_multi_file_input_path(
+         sampler_config.item_input_path)
+     hard_neg_edge_input_path = process_multi_file_input_path(
+         sampler_config.hard_neg_edge_input_path)
+     return HardNegativeSampler.instance(
+         user_data_path=user_input_path,
+         item_data_path=item_input_path,
+         hard_neg_edge_data_path=hard_neg_edge_input_path,
+         fields=attr_fields,
+         num_sample=sampler_config.num_sample,
+         num_hard_sample=sampler_config.num_hard_sample,
+         batch_size=data_config.batch_size,
+         attr_delimiter=sampler_config.attr_delimiter,
+         num_eval_sample=sampler_config.num_eval_sample)
+   elif sampler_type == 'hard_negative_sampler_v2':
+     input_fields = {f.input_name: f for f in data_config.input_fields}
+     attr_fields = [input_fields[name] for name in sampler_config.attr_fields]
+
+     user_input_path = process_multi_file_input_path(
+         sampler_config.user_input_path)
+     item_input_path = process_multi_file_input_path(
+         sampler_config.item_input_path)
+     pos_edge_input_path = process_multi_file_input_path(
+         sampler_config.pos_edge_input_path)
+     hard_neg_edge_input_path = process_multi_file_input_path(
+         sampler_config.hard_neg_edge_input_path)
+     return HardNegativeSamplerV2.instance(
+         user_data_path=user_input_path,
+         item_data_path=item_input_path,
+         edge_data_path=pos_edge_input_path,
+         hard_neg_edge_data_path=hard_neg_edge_input_path,
+         fields=attr_fields,
+         num_sample=sampler_config.num_sample,
+         num_hard_sample=sampler_config.num_hard_sample,
+         batch_size=data_config.batch_size,
+         attr_delimiter=sampler_config.attr_delimiter,
+         num_eval_sample=sampler_config.num_eval_sample)
+   else:
+     raise ValueError('Unknown sampler %s' % sampler_type)
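
A few parts of the sampler module above are easier to see with concrete values. BaseSampler._init_graph derives the GraphLearn cluster layout entirely from the TF_CONFIG environment variable; a minimal sketch of the ps-mode layout it expects, with hypothetical host names:

    import json
    import os

    # Hypothetical PS-mode cluster; host names are placeholders.
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'ps': ['ps0.example.com:2222'],
            'chief': ['chief0.example.com:2222'],
            'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
        },
        'task': {
            'type': 'worker',
            'index': 0
        },
    })
    # With 'ps' in the cluster, _init_graph takes the ps branch:
    # task_count = len(workers) + 2 = 4 (the +2 presumably accounts for the
    # chief and the evaluator), ps tasks join as GraphLearn servers, and this
    # worker joins as a client with task_index = task['index'] + 2 = 2.

When TF_CONFIG is absent, _init_graph simply calls self._g.init() in local mode.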
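For orientation, a rough usage sketch of NegativeSampler inside an input pipeline; item_fields and item_ids are hypothetical stand-ins for the item field protos and the batch's item-id tensor:

    from easy_rec.python.core import sampler as sampler_lib

    # item_fields: DatasetConfig.Field protos for the item attributes
    # (hypothetical); item_ids: an int64 tensor of the batch's item ids.
    neg_sampler = sampler_lib.NegativeSampler.instance(
        data_path='item_feature.txt',  # id:int64 | weight:float | attrs:string
        fields=item_fields,
        num_sample=256,
        batch_size=1024)
    neg_features = neg_sampler.get(item_ids)
    # neg_features: {attr_name: tensor with static shape [256]}

The sizing arithmetic follows from __init__ and _parse_nodes: expand_factor = ceil(256 / 1024) = 1, so GraphLearn draws one candidate per padded id (1024 in total), and _parse_nodes flattens the result and truncates it to exactly num_sample = 256 rows.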
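NegativeSamplerInMemory._load_data expects a tab-separated file whose first line is a schema of name:type pairs; the column whose name starts with 'id' holds the item id, the one starting with 'feature' holds the attribute string, which is itself split by attr_delimiter (':' by default) into one value per configured field. A hypothetical file with three attribute fields (columns tab-separated):

    id:int64	weight:float	feature:string
    10001	1.0	10001:3.5:phone
    10002	0.8	10002:1.2:laptop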
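Finally, build() dispatches on the sampler oneof of DatasetConfig. A minimal sketch of configuring the negative_sampler branch through the generated proto (paths and values are hypothetical):

    from easy_rec.python.core import sampler as sampler_lib
    from easy_rec.python.protos.dataset_pb2 import DatasetConfig

    data_config = DatasetConfig()
    data_config.batch_size = 1024
    field = data_config.input_fields.add()
    field.input_name = 'item_id'
    field.input_type = DatasetConfig.INT64

    ns = data_config.negative_sampler  # setting a subfield selects the oneof
    ns.input_path = 'item_feature.txt'
    ns.attr_fields.append('item_id')
    ns.num_sample = 256
    ns.attr_delimiter = ':'

    neg_sampler = sampler_lib.build(data_config)
    # -> NegativeSampler.instance(...) wired from the config fields above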