easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,307 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+ from tensorflow.python.ops.losses.losses_impl import compute_weighted_loss
7
+
8
+ from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
9
+ from easy_rec.python.utils.shape_utils import get_shape_list
10
+
11
# Switch to the TF1 compatibility API when running under TensorFlow 2.x or
# later. The major version is compared numerically because a plain string
# comparison is lexicographic (e.g. '10.0' >= '2.0' evaluates to False).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
def pairwise_loss(labels,
                  logits,
                  session_ids=None,
                  margin=0,
                  temperature=1.0,
                  weights=1.0,
                  name=''):
  """Deprecated pairwise ranking loss built on sigmoid cross entropy.

  Prefer `pairwise_logistic_loss` or `pairwise_focal_loss`; calling this
  function logs a deprecation warning.

  For every ordered pair (i, j) with `labels[i] > labels[j]` (and, when
  `session_ids` is given, with equal session ids) the value
  `logits[i] - logits[j] - margin` is scored against a pseudo label of 1.

  Args:
    labels: a `Tensor` with shape [batch_size]. e.g. click or not click in
      the session.
    logits: a `Tensor` with shape [batch_size]. e.g. the value of the last
      neuron before activation.
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, e.g. user_id; pairs are only formed inside the same session.
    margin: the margin subtracted from every pairwise logit difference.
    temperature: (Optional) the logits are divided by this value when it is
      not 1.0.
    weights: sample weights. A numeric tensor is expanded so that pair
      (i, j) receives the weight of sample i; any other value is passed to
      the underlying loss unchanged.
    name: the name of the loss, used in log messages and the summary tag.

  Returns:
    The result of `tf.losses.sigmoid_cross_entropy` over the selected pairs.
  """
  logging.warning(
      'The old `pairwise_loss` is being deprecated. '
      'Please use the new `pairwise_logistic_loss` or `pairwise_focal_loss`')
  loss_name = name or 'pairwise_loss'
  logging.info('[{}] margin: {}, temperature: {}'.format(
      loss_name, margin, temperature))

  if temperature != 1.0:
    logits /= temperature

  # score_diff[i, j] = logits[i] - logits[j] - margin
  row_scores = tf.expand_dims(logits, -1)
  col_scores = tf.expand_dims(logits, 0)
  score_diff = tf.math.subtract(row_scores, col_scores) - margin

  # Keep only the pairs whose first element has the strictly larger label.
  valid_pairs = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    valid_pairs = tf.logical_and(valid_pairs, same_session)

  score_diff = tf.boolean_mask(score_diff, valid_pairs)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(score_diff))

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Column vector tiled along axis 1: entry (i, j) holds weights[i].
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pair_weights = tf.boolean_mask(
        tf.tile(weights, tf.stack([1, batch_size])), valid_pairs)
  else:
    pair_weights = weights

  # Every retained pair is a "positive" example for the cross entropy.
  return tf.losses.sigmoid_cross_entropy(
      tf.ones_like(score_diff), score_diff, weights=pair_weights)
71
+
72
+
73
def pairwise_focal_loss(labels,
                        logits,
                        session_ids=None,
                        hinge_margin=None,
                        gamma=2,
                        alpha=None,
                        ohem_ratio=1.0,
                        temperature=1.0,
                        weights=1.0,
                        name=''):
  """Pairwise ranking loss scored with a sigmoid focal loss.

  Pairs (i, j) are selected when `labels[i] > labels[j]`; the pairwise logit
  `logits[i] - logits[j]` is then scored against a pseudo label of 1 by
  `sigmoid_focal_loss_with_logits`.

  Args:
    labels: a `Tensor` compared element-against-element to pick the pairs.
    logits: a `Tensor` of scores; divided by `temperature` when it is not 1.0.
    session_ids: (Optional) a `Tensor` of session ids; when given, only pairs
      whose two samples share the same id are kept.
    hinge_margin: (Optional) when given, only pairs whose logit difference is
      smaller than this value are kept.
    gamma: forwarded to `sigmoid_focal_loss_with_logits`.
    alpha: forwarded to `sigmoid_focal_loss_with_logits`.
    ohem_ratio: must be in (0, 1]; forwarded to
      `sigmoid_focal_loss_with_logits`.
    temperature: (Optional) the temperature used for scaling the logits.
    weights: sample weights. A numeric tensor is expanded so that pair
      (i, j) receives the weight of sample i; any other value is forwarded
      unchanged as `sample_weights`.
    name: the name of the loss, used in log messages and the summary tag.

  Returns:
    The value returned by `sigmoid_focal_loss_with_logits`.
  """
  loss_name = name or 'pairwise_focal_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] hinge margin: {}, gamma: {}, alpha: {}, ohem_ratio: {}, temperature: {}'
      .format(loss_name, hinge_margin, gamma, alpha, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature
  # score_diff[i, j] = logits[i] - logits[j]
  score_diff = tf.expand_dims(logits, -1) - tf.expand_dims(logits, 0)

  # Keep only the pairs whose first element has the strictly larger label.
  valid_pairs = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if hinge_margin is not None:
    # Drop pairs that are already separated by at least the margin.
    within_margin = tf.less(score_diff, hinge_margin)
    valid_pairs = tf.logical_and(valid_pairs, within_margin)
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    valid_pairs = tf.logical_and(valid_pairs, same_session)

  score_diff = tf.boolean_mask(score_diff, valid_pairs)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(score_diff))

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Column vector tiled along axis 1: entry (i, j) holds weights[i].
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pair_weights = tf.boolean_mask(
        tf.tile(weights, tf.stack([1, batch_size])), valid_pairs)
  else:
    pair_weights = weights

  # Every retained pair is treated as a positive example.
  return sigmoid_focal_loss_with_logits(
      tf.ones_like(score_diff),
      score_diff,
      gamma=gamma,
      alpha=alpha,
      ohem_ratio=ohem_ratio,
      sample_weights=pair_weights)
126
+
127
+
128
def pairwise_logistic_loss(labels,
                           logits,
                           session_ids=None,
                           temperature=1.0,
                           hinge_margin=None,
                           weights=1.0,
                           ohem_ratio=1.0,
                           use_label_margin=False,
                           name=''):
  r"""Computes the pairwise logistic (RankNet-style) loss.

  Definition:
  $$
  \mathcal{L}(\{y\}, \{s\}) =
  \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j)))
  $$

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, e.g. user_id; pairs are only formed inside the same session.
    temperature: (Optional) The temperature used for scaling the logits
      (and the labels too when `use_label_margin` is set).
    hinge_margin: (Optional) margin subtracted from each pairwise logit
      difference; ignored when `use_label_margin` is set.
    weights: A scalar, or a `Tensor` with shape [batch_size] holding one
      weight per sample; pair (i, j) receives the weight of sample i.
    ohem_ratio: the fraction of hardest pairs to keep, in (0, 1].
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin.
    name: the name of the loss, used in log messages and the summary tag.

  Returns:
    A scalar loss tensor.
  """
  loss_name = name or 'pairwise_logistic_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info('[{}] hinge margin: {}, ohem_ratio: {}, temperature: {}'.format(
      loss_name, hinge_margin, ohem_ratio, temperature))

  if temperature != 1.0:
    logits /= temperature
    if use_label_margin:
      labels /= temperature

  # score_diff[i, j] = logits[i] - logits[j], optionally reduced by a margin.
  score_diff = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  if use_label_margin:
    label_diff = tf.math.subtract(
        tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
    score_diff -= label_diff
  elif hinge_margin is not None:
    score_diff -= hinge_margin

  # Keep only the pairs whose first element has the strictly larger label.
  valid_pairs = tf.greater(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    same_session = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    valid_pairs = tf.logical_and(valid_pairs, same_session)

  score_diff = tf.boolean_mask(score_diff, valid_pairs)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, tf.size(score_diff))

  # Numerically stable form of log(1 + exp(-score_diff)).
  hinge_part = tf.nn.relu(-score_diff)
  smooth_part = tf.math.log1p(tf.exp(-tf.abs(score_diff)))
  pair_losses = hinge_part + smooth_part

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Column vector tiled along axis 1: entry (i, j) holds weights[i].
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pair_weights = tf.boolean_mask(
        tf.tile(weights, tf.stack([1, batch_size])), valid_pairs)
  else:
    pair_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(pair_losses, pair_weights)

  # Online hard example mining: average only the largest weighted losses.
  weighted = compute_weighted_loss(
      pair_losses, pair_weights, reduction=tf.losses.Reduction.NONE)
  num_kept = tf.to_float(tf.size(weighted)) * tf.convert_to_tensor(ohem_ratio)
  num_kept = tf.to_int32(tf.math.rint(num_kept))
  hardest = tf.nn.top_k(weighted, num_kept).values
  # Zero-valued entries are excluded from the mean.
  return tf.reduce_mean(tf.boolean_mask(hardest, hardest > 0))
211
+
212
+
213
def pairwise_hinge_loss(labels,
                        logits,
                        session_ids=None,
                        temperature=1.0,
                        margin=1.0,
                        weights=1.0,
                        ohem_ratio=1.0,
                        label_is_logits=True,
                        use_label_margin=True,
                        use_exponent=False,
                        name=''):
  r"""Computes pairwise hinge loss between `labels` and `logits`.

  Definition:
  $$
  \mathcal{L}(\{y\}, \{s\}) =
  \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j))
  $$

  The constant 1 above is replaced by `margin`, or by `y_i - y_j` when
  `use_label_margin` is set.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size].
    session_ids: a `Tensor` with shape [batch_size]. Session ids of each
      sample, e.g. user_id; pairs are only formed inside the same session.
    temperature: (Optional) The temperature used for scaling the logits
      (and the labels too when `label_is_logits` is set).
    margin: the margin between positive and negative logits; only used when
      `use_label_margin` is False.
    weights: A scalar, or a `Tensor` with shape [batch_size] holding one
      weight per sample; pair (i, j) receives the weight of sample i.
    ohem_ratio: the fraction of hardest pairs to keep, in (0, 1].
    label_is_logits: Whether `labels` is expected to be a logits tensor.
    use_label_margin: whether to use the diff `label[i]-label[j]` as margin.
    use_exponent: whether to use exponential difference. Labels and logits
      are each passed through a sigmoid first, and the per-pair loss becomes
      `max(0, exp(diff) - 1)`.
    name: the name of the loss, used in log messages and the summary tag.

  Returns:
    A scalar loss tensor.
  """
  loss_name = name if name else 'pairwise_hinge_loss'
  assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
  logging.info(
      '[{}] margin: {}, ohem_ratio: {}, temperature: {}, use_exponent: {}, label_is_logits: {}, use_label_margin: {}'
      .format(loss_name, margin, ohem_ratio, temperature, use_exponent,
              label_is_logits, use_label_margin))

  if temperature != 1.0:
    logits /= temperature
    if label_is_logits:
      labels /= temperature
  if use_exponent:
    labels = tf.nn.sigmoid(labels)
    # Squash the predictions themselves. The previous code applied the
    # sigmoid to `labels` here, which discarded the model output and made
    # the loss independent of `logits`.
    logits = tf.nn.sigmoid(logits)

  # pairwise_x[i, j] = x[i] - x[j]
  pairwise_logits = tf.math.subtract(
      tf.expand_dims(logits, -1), tf.expand_dims(logits, 0))
  pairwise_labels = tf.math.subtract(
      tf.expand_dims(labels, -1), tf.expand_dims(labels, 0))

  # Keep only the pairs whose first element has the strictly larger label.
  pairwise_mask = tf.greater(pairwise_labels, 0)
  if session_ids is not None:
    logging.info('[%s] use session ids' % loss_name)
    group_equal = tf.equal(
        tf.expand_dims(session_ids, -1), tf.expand_dims(session_ids, 0))
    pairwise_mask = tf.logical_and(pairwise_mask, group_equal)

  pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
  pairwise_labels = tf.boolean_mask(pairwise_labels, pairwise_mask)
  num_pair = tf.size(pairwise_logits)
  tf.summary.scalar('loss/%s_num_of_pairs' % loss_name, num_pair)

  # Amount by which each pair falls short of its required margin.
  if use_label_margin:
    diff = pairwise_labels - pairwise_logits
  else:
    diff = margin - pairwise_logits
  if use_exponent:
    # exp(88) is about 1.65e38, just below the float32 maximum of
    # 3.4028235e+38, so clipping here keeps tf.exp from overflowing.
    threshold = 88.0
    safe_diff = tf.clip_by_value(diff, -threshold, threshold)
    losses = tf.nn.relu(tf.exp(safe_diff) - 1.0)
  else:
    losses = tf.nn.relu(diff)

  if tf.is_numeric_tensor(weights):
    logging.info('[%s] use sample weight' % loss_name)
    # Column vector tiled along axis 1: entry (i, j) holds weights[i].
    weights = tf.expand_dims(tf.cast(weights, tf.float32), -1)
    batch_size, _ = get_shape_list(weights, 2)
    pairwise_weights = tf.tile(weights, tf.stack([1, batch_size]))
    pairwise_weights = tf.boolean_mask(pairwise_weights, pairwise_mask)
  else:
    pairwise_weights = weights

  if ohem_ratio == 1.0:
    return compute_weighted_loss(losses, pairwise_weights)

  # Online hard example mining: average only the largest weighted losses.
  losses = compute_weighted_loss(
      losses, pairwise_weights, reduction=tf.losses.Reduction.NONE)
  k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
  k = tf.to_int32(tf.math.rint(k))
  topk = tf.nn.top_k(losses, k)
  # Zero-valued entries are excluded from the mean.
  losses = tf.boolean_mask(topk.values, topk.values > 0)
  return tf.reduce_mean(losses)
@@ -0,0 +1,110 @@
1
+ # coding=utf-8
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ if tf.__version__ >= '2.0':
6
+ tf = tf.compat.v1
7
+
8
+
9
def support_vector_guided_softmax_loss(pos_score,
                                       neg_scores,
                                       margin=0,
                                       t=1,
                                       smooth=1.0,
                                       threshold=0,
                                       weights=1.0):
    """Compute the support vector guided softmax (SV-Softmax) loss.

    Reference: "Support Vector Guided Softmax Loss for Face Recognition"
    (https://arxiv.org/abs/1812.11317).

    The positive score is first reduced by `margin`. A negative score counts as
    "hard" when the margin-adjusted positive score minus that negative score is
    below `threshold`; hard negatives are replaced by `score * t + t - 1`, all
    other negatives are kept unchanged. The adjusted positive and negative
    scores are concatenated along axis 1 (positive first), optionally scaled by
    `smooth`, and fed to a sparse softmax cross entropy whose target is class 0.

    Args:
      pos_score: scores of the positive items; concatenated with `neg_scores`
        along axis 1, so it must be at least 2-D.
      neg_scores: scores of the negative items, same rank as `pos_score`.
      margin: value subtracted from `pos_score`.
      t: re-weighting coefficient applied to hard negatives.
      smooth: multiplier applied to the logits when it differs from 1.0.
      threshold: minimal gap between the adjusted positive score and a
        negative score for that negative not to be treated as hard.
      weights: forwarded to `tf.losses.sparse_softmax_cross_entropy`.

    Returns:
      The loss produced by `tf.losses.sparse_softmax_cross_entropy`, with NaN
      replaced by zero.
    """
    margined_pos = pos_score - margin

    # I_k of the paper: 1.0 for hard negatives, 0.0 for the remaining ones.
    gap_is_large_enough = tf.greater_equal(margined_pos - neg_scores, threshold)
    hard_indicator = tf.cast(tf.logical_not(gap_is_large_enough), tf.float32)

    boosted_neg = neg_scores * t + t - 1
    adjusted_neg = hard_indicator * boosted_neg + (
        1 - hard_indicator) * neg_scores

    logits = tf.concat([margined_pos, adjusted_neg], axis=1)
    if smooth != 1.0:
        logits = logits * smooth

    # The positive item always occupies column 0 of the logits.
    target_index = tf.zeros_like(pos_score, dtype=tf.int32)
    loss = tf.losses.sparse_softmax_cross_entropy(
        target_index, logits, weights=weights)

    # set rank loss to zero if a batch has no positive sample.
    return tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss)
31
+
32
+
33
def softmax_loss_with_negative_mining(user_emb,
                                      item_emb,
                                      labels,
                                      num_negative_samples=4,
                                      embed_normed=False,
                                      weights=1.0,
                                      gamma=1.0,
                                      margin=0,
                                      t=1,
                                      seed=None):
    """Softmax loss over one positive item and in-batch sampled negatives.

    Unless `embed_normed` is set, both embeddings are L2-normalized so that
    their dot product is a cosine similarity. For every row of `user_emb` the
    item in the same row of `item_emb` is the positive candidate; negative
    candidates are created on the fly by rolling `item_emb` along the batch
    axis by a random offset drawn from [1, batch_size), once per negative
    sample. Only rows whose label is greater than 0 contribute to the loss,
    which is computed by `support_vector_guided_softmax_loss` on the
    user/candidate similarity scores.

    Args:
      user_emb: A `Tensor` with shape [batch_size, embedding_size]. The embedding of user.
      item_emb: A `Tensor` with shape [batch_size, embedding_size]. The embedding of item.
      labels: A `Tensor` with shape [batch_size], e.g. click or not click in the
        session. Its values must be 0 or 1.
      num_negative_samples: number of negative samples, should be in range [1, batch_size).
      embed_normed: bool, whether the input embeddings are already L2-normalized.
      weights: coefficient for the loss. A scalar simply scales the loss; a
        `Tensor` of shape [batch_size] is filtered to the positive rows and
        weights each remaining sample.
      gamma: smooth coefficient of softmax.
      margin: the margin between positive pair and negative pair.
      t: coefficient of support vector guided softmax loss.
      seed: A Python integer used to seed the random shifts.
        See `tf.set_random_seed` for behavior.

    Returns:
      support vector guided softmax loss of positive labels
    """
    assert 0 < num_negative_samples, '`num_negative_samples` should be greater than 0'

    batch_size = tf.shape(item_emb)[0]
    enough_rows = tf.assert_less(
        num_negative_samples,
        batch_size,
        message='`num_negative_samples` should be less than batch_size')
    # Build every op under the runtime check so it is evaluated with the loss.
    with tf.control_dependencies([enough_rows]):
        if not embed_normed:
            user_emb = tf.nn.l2_normalize(user_emb, axis=-1)
            item_emb = tf.nn.l2_normalize(item_emb, axis=-1)

        # Column 0 holds the true item; each further column is the item batch
        # rolled by a random non-zero offset, i.e. an item from another row.
        candidates = [item_emb]
        candidates.extend(
            tf.roll(
                item_emb,
                tf.random_uniform(
                    [], 1, batch_size, dtype=tf.int32, seed=seed),
                axis=0) for _ in range(num_negative_samples))
        # shape: (batch_size, num_negative_samples + 1, vec_dim)
        stacked_candidates = tf.stack(candidates, axis=1)

        is_positive = tf.greater(labels, 0)
        pos_user_emb = tf.boolean_mask(user_emb, is_positive)
        pos_candidates = tf.boolean_mask(stacked_candidates, is_positive)
        if isinstance(weights, tf.Tensor):
            weights = tf.boolean_mask(weights, is_positive)

        # shape: (num_of_pos_label_in_batch_size, num_negative_samples + 1)
        similarity = tf.keras.backend.batch_dot(
            pos_user_emb, pos_candidates, axes=(1, 2))
        true_item_score = tf.slice(similarity, [0, 0], [-1, 1])
        sampled_item_scores = tf.slice(similarity, [0, 1], [-1, -1])

        loss = support_vector_guided_softmax_loss(
            true_item_score,
            sampled_item_scores,
            margin=margin,
            t=t,
            smooth=gamma,
            weights=weights)
    return loss
@@ -0,0 +1,76 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Zero-inflated lognormal loss for lifetime value prediction."""
4
+ import tensorflow as tf
5
+ import tensorflow_probability as tfp
6
+
7
+ tfd = tfp.distributions
8
+
9
+ if tf.__version__ >= '2.0':
10
+ tf = tf.compat.v1
11
+
12
+
13
def zero_inflated_lognormal_pred(logits):
    """Turn zero-inflated lognormal logits into predictions.

    The last axis of `logits` is read as: column 0 -> logit of the probability
    of a positive value, column 1 -> `loc` of the lognormal, columns 2 and
    onwards -> pre-softplus `scale` of the lognormal. The predicted mean is
    the positive probability times the lognormal mean
    `exp(loc + 0.5 * scale ** 2)`.

    Arguments:
      logits: [batch_size, 3] tensor of logits.

    Returns:
      positive_probs: [batch_size, 1] tensor of positive probability.
      preds: [batch_size, 1] tensor of predicted mean.
    """
    backend = tf.keras.backend
    logits = tf.convert_to_tensor(logits, dtype=tf.float32)

    positive_logit = logits[..., :1]
    positive_probs = backend.sigmoid(positive_logit)

    loc = logits[..., 1:2]
    raw_scale = logits[..., 2:]
    scale = backend.softplus(raw_scale)

    lognormal_mean = backend.exp(loc + 0.5 * backend.square(scale))
    preds = positive_probs * lognormal_mean
    return positive_probs, preds
31
+
32
+
33
def zero_inflated_lognormal_loss(labels, logits, name=''):
    """Compute the zero-inflated lognormal (ZILN) loss.

    The loss is the sum of two terms: a binary cross entropy on whether the
    label is positive (using logit column 0), and the negative lognormal
    log-likelihood of the label (using `loc` from column 1 and a softplus
    `scale` from the remaining columns), counted only where the label is
    positive. Both terms are averaged and written as scalar summaries.

    Usage with tf.keras API:

    ```python
    model = tf.keras.Model(inputs, outputs)
    model.compile('sgd', loss=zero_inflated_lognormal_loss)
    ```

    Arguments:
      labels: True targets, tensor of shape [batch_size, 1]. A rank-1 tensor
        is expanded to [batch_size, 1].
      logits: Logits of output layer, tensor of shape [batch_size, 3].
      name: the name of loss, used in the summary tags; defaults to
        'ziln_loss' when empty.

    Returns:
      Zero inflated lognormal loss value.
    """
    backend = tf.keras.backend
    loss_name = name or 'ziln_loss'

    labels = tf.cast(labels, dtype=tf.float32)
    if labels.shape.ndims == 1:
        labels = tf.expand_dims(labels, 1)  # [B, 1]
    positive = tf.cast(labels > 0, tf.float32)

    logits = tf.convert_to_tensor(logits, dtype=tf.float32)
    expected_shape = tf.TensorShape(labels.shape[:-1].as_list() + [3])
    logits.shape.assert_is_compatible_with(expected_shape)

    # Classification part: is the label positive at all?
    positive_logits = logits[..., :1]
    per_sample_bce = backend.binary_crossentropy(
        positive, positive_logits, from_logits=True)
    classification_loss = backend.mean(per_sample_bce)
    tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss)

    # Regression part: lognormal likelihood of the positive labels.
    loc = logits[..., 1:2]
    soft_scale = backend.softplus(logits[..., 2:])
    # Keep the scale away from zero.
    scale = tf.math.maximum(soft_scale, tf.math.sqrt(backend.epsilon()))

    # Non-positive labels are swapped for 1 before log_prob is evaluated;
    # their contribution is zeroed by the `positive` factor below.
    safe_labels = positive * labels + (1 - positive) * backend.ones_like(labels)
    log_likelihood = tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels)
    regression_loss = -backend.mean(positive * log_likelihood)
    tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss)

    return classification_loss + regression_loss