easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/loss/circle_loss.py
@@ -0,0 +1,82 @@
+ # coding=utf-8
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import tensorflow as tf
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def circle_loss(embeddings,
+                 labels,
+                 sessions=None,
+                 margin=0.25,
+                 gamma=32,
+                 embed_normed=False):
+   """Paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization.
+
+   Link: http://arxiv.org/pdf/2002.10857.pdf
+
+   Args:
+     embeddings: A `Tensor` with shape [batch_size, embedding_size]. The embedding of each sample.
+     labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+     sessions: a `Tensor` with shape [batch_size]. session ids of each sample.
+     margin: the margin between positive similarity and negative similarity
+     gamma: parameter of circle loss
+     embed_normed: bool, whether the input embeddings are l2-normalized
+   """
+   norm_embeddings = embeddings if embed_normed else tf.nn.l2_normalize(
+       embeddings, axis=-1)
+   pair_wise_cosine_matrix = tf.matmul(
+       norm_embeddings, norm_embeddings, transpose_b=True)
+
+   positive_mask = get_anchor_positive_triplet_mask(labels, sessions)
+   negative_mask = 1 - positive_mask - tf.eye(tf.shape(labels)[0])
+
+   delta_p = 1 - margin
+   delta_n = margin
+
+   ap = tf.nn.relu(-tf.stop_gradient(pair_wise_cosine_matrix * positive_mask) +
+                   1 + margin)
+   an = tf.nn.relu(
+       tf.stop_gradient(pair_wise_cosine_matrix * negative_mask) + margin)
+
+   logit_p = -ap * (pair_wise_cosine_matrix -
+                    delta_p) * gamma * positive_mask - (1 - positive_mask) * 1e12
+   logit_n = an * (pair_wise_cosine_matrix -
+                   delta_n) * gamma * negative_mask - (1 - negative_mask) * 1e12
+
+   joint_neg_loss = tf.reduce_logsumexp(logit_n, axis=-1)
+   joint_pos_loss = tf.reduce_logsumexp(logit_p, axis=-1)
+   loss = tf.nn.softplus(joint_neg_loss + joint_pos_loss)
+   return tf.reduce_mean(loss)
+
+
+ def get_anchor_positive_triplet_mask(labels, sessions=None):
+   """Return a 2D mask where mask[a, p] is 1.0 iff a and p are distinct and have the same session and label.
+
+   Args:
+     labels: a `Tensor` with shape [batch_size]
+     sessions: a `Tensor` with shape [batch_size]
+
+   Returns:
+     mask: tf.float32 `Tensor` with shape [batch_size, batch_size]
+   """
+   # Check that i and j are distinct
+   indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
+   indices_not_equal = tf.logical_not(indices_equal)
+
+   # Check if labels[i] == labels[j]
+   # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
+   labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
+
+   # Check if sessions[i] == sessions[j]
+   if sessions is None or sessions is labels:
+     class_equal = labels_equal
+   else:
+     sessions_equal = tf.equal(
+         tf.expand_dims(sessions, 0), tf.expand_dims(sessions, 1))
+     class_equal = tf.logical_and(sessions_equal, labels_equal)
+
+   # Combine the two masks
+   mask = tf.logical_and(indices_not_equal, class_equal)
+   return tf.cast(mask, tf.float32)
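
For orientation, a minimal usage sketch (not part of the package; tensor values are illustrative). Samples that share a session and a label form the positive pairs:

import tensorflow as tf
from easy_rec.python.loss.circle_loss import circle_loss

embeddings = tf.random.normal([8, 16])            # [batch_size, embedding_size]
labels = tf.constant([1, 0, 1, 1, 0, 0, 1, 0])    # e.g. click / no click
sessions = tf.constant([1, 1, 1, 1, 2, 2, 2, 2])  # session id per sample
loss = circle_loss(embeddings, labels, sessions=sessions, margin=0.25, gamma=32)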
easy_rec/python/loss/contrastive_loss.py
@@ -0,0 +1,79 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import tensorflow as tf
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def l2_loss(x1, x2):
+   """Compute the mean squared euclidean distance between two batches of embeddings."""
+   distance = tf.reduce_sum(tf.square(x1 - x2), axis=-1)
+   return tf.reduce_mean(distance)
+
+
+ def info_nce_loss(query, positive, temperature=0.1):
+   """Calculates the InfoNCE loss for self-supervised learning.
+
+   This contrastive loss enforces the embeddings of similar (positive) samples to be close
+   and those of different (negative) samples to be distant.
+   A query embedding is compared with one positive key and with one or more negative keys.
+
+   References:
+     https://arxiv.org/abs/1807.03748v2
+     https://arxiv.org/abs/2010.05113
+   """
+   # Check input dimensionality.
+   if query.shape.ndims != 2:
+     raise ValueError('<query> must have 2 dimensions.')
+   if positive.shape.ndims != 2:
+     raise ValueError('<positive> must have 2 dimensions.')
+   # Embedding vectors should have the same number of components.
+   if query.shape[-1] != positive.shape[-1]:
+     raise ValueError(
+         'Vectors of <query> and <positive> should have the same number of components.'
+     )
+
+   # Negative keys are implicitly off-diagonal positive keys.
+
+   # Cosine similarity between all combinations
+   logits = tf.matmul(query, positive, transpose_b=True)
+   logits /= temperature
+
+   # Positive keys are the entries on the diagonal
+   batch_size = tf.shape(query)[0]
+   labels = tf.range(batch_size)
+
+   return tf.losses.sparse_softmax_cross_entropy(labels, logits)
+
+
+ def get_mask_matrix(batch_size):
+   # [2B, 2B] boolean mask that is False on the diagonal of each [B, B]
+   # quadrant, i.e. it masks out self-similarities and the positive pairs.
+   mat = tf.ones((batch_size, batch_size), dtype=tf.bool)
+   diag = tf.zeros([batch_size], dtype=tf.bool)
+   mask = tf.linalg.set_diag(mat, diag)
+   mask = tf.tile(mask, [2, 2])
+   return mask
+
+
+ def nce_loss(z_i, z_j, temperature=1.0):
+   """Contrastive nce loss for homogeneous embeddings.
+
+   Refer to the paper: Contrastive Learning for Sequential Recommendation.
+   """
+   batch_size = tf.shape(z_i)[0]
+   N = 2 * batch_size
+   z = tf.concat((z_i, z_j), axis=0)
+   sim = tf.matmul(z, tf.transpose(z)) / temperature
+   sim_i_j = tf.matrix_diag_part(
+       tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
+   sim_j_i = tf.matrix_diag_part(
+       tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
+   positive_samples = tf.reshape(tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
+   mask = get_mask_matrix(batch_size)
+   negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))
+
+   labels = tf.zeros(N, dtype=tf.int32)
+   logits = tf.concat((positive_samples, negative_samples), axis=1)
+
+   loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+   return loss
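
A minimal sketch of how nce_loss might be called (illustrative values; assumes the TF1-style graph mode these modules target, since tf.losses is graph-only in TF2). z_i and z_j hold two augmented views of the same batch, so row i of one is the positive of row i of the other and every remaining row in the concatenated 2B batch serves as a negative:

import tensorflow as tf
from easy_rec.python.loss.contrastive_loss import nce_loss

z_i = tf.nn.l2_normalize(tf.random.normal([32, 64]), axis=-1)  # view 1, [B, D]
z_j = tf.nn.l2_normalize(tf.random.normal([32, 64]), axis=-1)  # view 2, [B, D]
loss = nce_loss(z_i, z_j, temperature=0.1)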
easy_rec/python/loss/f1_reweight_loss.py
@@ -0,0 +1,38 @@
+ # coding=utf-8
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+
+ import tensorflow as tf
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def f1_reweight_sigmoid_cross_entropy(labels,
+                                       logits,
+                                       beta_square,
+                                       label_smoothing=0,
+                                       weights=None):
+   """Refer to the paper: Adaptive Scaling for Sparse Detection in Information Extraction."""
+   probs = tf.nn.sigmoid(logits)
+   if len(logits.shape.as_list()) == 1:
+     logits = tf.expand_dims(logits, -1)
+   if len(labels.shape.as_list()) == 1:
+     labels = tf.expand_dims(labels, -1)
+   labels = tf.to_float(labels)
+   batch_size = tf.shape(labels)[0]
+   batch_size_float = tf.to_float(batch_size)
+   num_pos = tf.reduce_sum(labels, axis=0)
+   num_neg = batch_size_float - num_pos
+   tp = tf.reduce_sum(probs, axis=0)  # soft count of predicted positives
+   tn = batch_size_float - tp
+   neg_weight = tp / (beta_square * num_pos + num_neg - tn + 1e-8)
+   neg_weight_tile = tf.tile(tf.expand_dims(neg_weight, 0), [batch_size, 1])
+   final_weights = tf.where(
+       tf.equal(labels, 1.0), tf.ones_like(labels), neg_weight_tile)
+   if weights is not None:
+     weights = tf.cast(weights, tf.float32)
+     if len(weights.shape.as_list()) == 1:
+       weights = tf.expand_dims(weights, -1)
+     final_weights *= weights
+   return tf.losses.sigmoid_cross_entropy(
+       labels, logits, final_weights, label_smoothing=label_smoothing)
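
As a quick illustrative sketch (values made up; graph mode assumed, as tf.losses is graph-only in TF2): positives keep weight 1 while every negative is down- or up-weighted by the batch-dependent factor above, trading precision against recall according to beta_square:

import tensorflow as tf
from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy

labels = tf.constant([[1.0], [0.0], [0.0], [0.0]])    # sparse positives
logits = tf.constant([[2.0], [-1.0], [0.5], [-2.0]])  # raw network outputs
loss = f1_reweight_sigmoid_cross_entropy(labels, logits, beta_square=1.0)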
easy_rec/python/loss/focal_loss.py
@@ -0,0 +1,93 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import logging
+
+ import tensorflow as tf
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def sigmoid_focal_loss_with_logits(labels,
+                                    logits,
+                                    gamma=2.0,
+                                    alpha=None,
+                                    ohem_ratio=1.0,
+                                    sample_weights=None,
+                                    label_smoothing=0,
+                                    name=''):
+   """Implements the focal loss function.
+
+   Focal loss was first introduced in the RetinaNet paper
+   (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
+   classification when you have highly imbalanced classes. It down-weights
+   well-classified examples and focuses on hard examples. The loss value is
+   much higher for a sample misclassified by the classifier than for a
+   well-classified example. One of the best use cases of focal loss is object
+   detection, where the imbalance between the background class and the other
+   classes is extremely high.
+
+   Args:
+     labels: `[batch_size]` target integer labels in `{0, 1}`.
+     logits: Float `[batch_size]` logits outputs of the network.
+     gamma: modulating factor.
+     alpha: balancing factor.
+     ohem_ratio: the percent of hard examples to be mined
+     sample_weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `losses` dimension).
+     label_smoothing: If greater than `0` then smooth the labels.
+     name: the name of loss
+
+   Returns:
+     Weighted loss float `Tensor`, reduced to a scalar.
+
+   Raises:
+     ValueError: If the shape of `sample_weights` is invalid or the value of
+       `gamma` is less than zero.
+   """
+   loss_name = name if name else 'focal_loss'
+   assert 0 < ohem_ratio <= 1.0, loss_name + ' ohem_ratio must be in (0, 1]'
+   if gamma and gamma < 0:
+     raise ValueError('Value of gamma should be greater than or equal to zero')
+   logging.info(
+       '[{}] gamma: {}, alpha: {}, ohem_ratio: {}, label smoothing: {}'.format(
+           loss_name, gamma, alpha, ohem_ratio, label_smoothing))
+
+   y_true = tf.cast(labels, logits.dtype)
+
+   # convert the predictions into probabilities
+   y_pred = tf.nn.sigmoid(logits)
+   epsilon = 1e-7
+   y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
+   p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
+   weights = tf.pow((1 - p_t), gamma)
+
+   if alpha is not None:
+     alpha_factor = y_true * alpha + ((1 - alpha) * (1 - y_true))
+     weights *= alpha_factor
+
+   if sample_weights is not None:
+     if tf.is_numeric_tensor(sample_weights):
+       logging.info('[%s] use sample weight' % loss_name)
+       weights *= tf.cast(sample_weights, tf.float32)
+     elif sample_weights != 1.0:
+       logging.info('[%s] use sample weight: %f' % (loss_name, sample_weights))
+       weights *= sample_weights
+
+   if ohem_ratio == 1.0:
+     return tf.losses.sigmoid_cross_entropy(
+         y_true, logits, weights=weights, label_smoothing=label_smoothing)
+
+   # online hard example mining: keep only the top ohem_ratio fraction of losses
+   losses = tf.losses.sigmoid_cross_entropy(
+       y_true,
+       logits,
+       weights=weights,
+       label_smoothing=label_smoothing,
+       reduction=tf.losses.Reduction.NONE)
+   k = tf.to_float(tf.size(losses)) * tf.convert_to_tensor(ohem_ratio)
+   k = tf.to_int32(tf.math.rint(k))
+   topk = tf.nn.top_k(losses, k)
+   losses = tf.boolean_mask(topk.values, topk.values > 0)
+   return tf.reduce_mean(losses)
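
To make the modulating factor concrete: with gamma = 2, a well-classified sample with p_t = 0.9 receives weight (1 - 0.9)^2 = 0.01, while a hard sample with p_t = 0.3 receives 0.49, roughly fifty times more. A minimal invocation sketch (illustrative values; graph mode assumed):

import tensorflow as tf
from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits

labels = tf.constant([1, 0, 1, 0])
logits = tf.constant([2.2, -3.0, -0.8, 0.4])  # two easy and two hard samples
loss = sigmoid_focal_loss_with_logits(
    labels, logits, gamma=2.0, alpha=0.25, ohem_ratio=0.75)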
easy_rec/python/loss/jrc_loss.py
@@ -0,0 +1,128 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import logging
+
+ import numpy as np
+ import tensorflow as tf
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def jrc_loss(labels,
+              logits,
+              session_ids,
+              alpha=0.5,
+              loss_weight_strategy='fixed',
+              sample_weights=1.0,
+              same_label_loss=True,
+              name=''):
+   """Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model.
+
+   https://arxiv.org/abs/2208.06164
+
+   Args:
+     labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session.
+     logits: a `Tensor` with shape [batch_size, 2]. e.g. the value of the last neuron before activation.
+     session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to group
+       samples into sessions (e.g. user_id), as in the GAUC metric.
+     alpha: the weight to balance the ranking loss and the calibration loss
+     loss_weight_strategy: str, the strategy for balancing between ce_loss and ge_loss
+     sample_weights: Coefficients for the loss. This must be scalar or broadcastable to
+       `labels` (i.e. same rank and each dimension is either 1 or the same).
+     same_label_loss: whether to enable ge_loss for samples with the same label in a session.
+     name: the name of loss
+   """
+   loss_name = name if name else 'jrc_loss'
+   logging.info('[{}] alpha: {}, loss_weight_strategy: {}'.format(
+       loss_name, alpha, loss_weight_strategy))
+
+   ce_loss = tf.losses.sparse_softmax_cross_entropy(
+       labels, logits, weights=sample_weights)
+
+   labels = tf.expand_dims(labels, 1)  # [B, 1]
+   labels = tf.concat([1 - labels, labels], axis=1)  # [B, 2]
+
+   batch_size = tf.shape(logits)[0]
+
+   # Mask: shape [B, B], mask[i, j]=1 indicates that the i-th sample
+   # and the j-th sample are in the same context
+   mask = tf.equal(
+       tf.expand_dims(session_ids, 1), tf.expand_dims(session_ids, 0))
+   mask = tf.to_float(mask)
+
+   # Tile logits and labels: [B, 2] -> [B, B, 2]
+   logits = tf.tile(tf.expand_dims(logits, 1), [1, batch_size, 1])
+   y = tf.tile(tf.expand_dims(labels, 1), [1, batch_size, 1])
+
+   # Set logits that are not in the same context to -inf
+   mask3d = tf.expand_dims(mask, 2)
+   y = tf.to_float(y) * mask3d
+   logits = logits + (1 - mask3d) * -1e9
+   y_neg, y_pos = y[:, :, 0], y[:, :, 1]
+   l_neg, l_pos = logits[:, :, 0], logits[:, :, 1]
+
+   if tf.is_numeric_tensor(sample_weights):
+     logging.info('[%s] use sample weight' % loss_name)
+     weights = tf.expand_dims(tf.cast(sample_weights, tf.float32), 0)
+     pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1]))
+     y_pos *= pairwise_weights
+     y_neg *= pairwise_weights
+
+   # Compute the list-wise generative loss -log p(x|y, z)
+   if same_label_loss:
+     logging.info('[%s] enable same_label_loss' % loss_name)
+     loss_pos = -tf.reduce_sum(y_pos * tf.nn.log_softmax(l_pos, axis=0), axis=0)
+     loss_neg = -tf.reduce_sum(y_neg * tf.nn.log_softmax(l_neg, axis=0), axis=0)
+     ge_loss = tf.reduce_mean(
+         (loss_pos + loss_neg) / tf.reduce_sum(mask, axis=0))
+   else:
+     logging.info('[%s] disable same_label_loss' % loss_name)
+     diag = tf.one_hot(tf.range(batch_size), batch_size)
+     l_pos = l_pos + (1 - diag) * y_pos * -1e9
+     l_neg = l_neg + (1 - diag) * y_neg * -1e9
+     loss_pos = -tf.linalg.diag_part(y_pos * tf.nn.log_softmax(l_pos, axis=0))
+     loss_neg = -tf.linalg.diag_part(y_neg * tf.nn.log_softmax(l_neg, axis=0))
+     ge_loss = tf.reduce_mean(loss_pos + loss_neg)
+
+   tf.summary.scalar('loss/%s_ce' % loss_name, ce_loss)
+   tf.summary.scalar('loss/%s_ge' % loss_name, ge_loss)
+
+   # The final JRC loss combines the ranking (ce) and calibration (ge) terms
+   if loss_weight_strategy == 'fixed':
+     loss = alpha * ce_loss + (1 - alpha) * ge_loss
+   elif loss_weight_strategy == 'random_uniform':
+     weight = tf.random_uniform([])
+     loss = weight * ce_loss + (1 - weight) * ge_loss
+     tf.summary.scalar('loss/%s_ce_weight' % loss_name, weight)
+     tf.summary.scalar('loss/%s_ge_weight' % loss_name, 1 - weight)
+   elif loss_weight_strategy == 'random_normal':
+     weights = tf.random_normal([2])
+     loss_weight = tf.nn.softmax(weights)
+     loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+     tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+     tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+   elif loss_weight_strategy == 'random_bernoulli':
+     bern = tf.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
+     weights = bern.sample(2)
+     loss_weight = tf.cond(
+         tf.equal(tf.reduce_sum(weights), 1), lambda: weights,
+         lambda: tf.convert_to_tensor([0.5, 0.5]))
+     loss = loss_weight[0] * ce_loss + loss_weight[1] * ge_loss
+     tf.summary.scalar('loss/%s_ce_weight' % loss_name, loss_weight[0])
+     tf.summary.scalar('loss/%s_ge_weight' % loss_name, loss_weight[1])
+   elif loss_weight_strategy == 'uncertainty':
+     uncertainty1 = tf.Variable(
+         0, name='%s_ranking_loss_weight' % loss_name, dtype=tf.float32)
+     tf.summary.scalar('loss/%s_ranking_uncertainty' % loss_name, uncertainty1)
+     uncertainty2 = tf.Variable(
+         0, name='%s_calibration_loss_weight' % loss_name, dtype=tf.float32)
+     tf.summary.scalar('loss/%s_calibration_uncertainty' % loss_name,
+                       uncertainty2)
+     loss = tf.exp(-uncertainty1) * ce_loss + 0.5 * uncertainty1
+     loss += tf.exp(-uncertainty2) * ge_loss + 0.5 * uncertainty2
+   else:
+     raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' %
+                      loss_weight_strategy)
+   if np.isscalar(sample_weights) and sample_weights != 1.0:
+     return loss * sample_weights
+   return loss
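
A minimal invocation sketch (illustrative values; assumes the graph-mode estimator setting this package targets, since the summary and tf.losses ops are graph-only in TF2). Each sample carries a two-logit output and a session id, and the loss jointly optimizes in-session ranking and per-sample calibration:

import tensorflow as tf
from easy_rec.python.loss.jrc_loss import jrc_loss

labels = tf.constant([1, 0, 0, 1])       # click labels
logits = tf.random.normal([4, 2])        # [batch_size, 2]
session_ids = tf.constant([7, 7, 8, 8])  # e.g. user_id per sample
loss = jrc_loss(labels, logits, session_ids, alpha=0.5)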
easy_rec/python/loss/listwise_loss.py
@@ -0,0 +1,161 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import logging
+
+ import tensorflow as tf
+
+ from easy_rec.python.utils.load_class import load_by_path
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
+   mask = tf.equal(x, session_ids)
+   logits = tf.boolean_mask(logits, mask)
+   labels = tf.boolean_mask(labels, mask)
+   y = tf.nn.softmax(labels) if label_is_logits else labels
+   y_hat = tf.nn.log_softmax(logits)
+   return -tf.reduce_sum(y * y_hat)
+
+
+ def _list_prob_loss(x, labels, logits, session_ids):
+   mask = tf.equal(x, session_ids)
+   logits = tf.boolean_mask(logits, mask)
+   labels = tf.boolean_mask(labels, mask)
+   y = labels / tf.reduce_sum(labels)
+   y_hat = tf.nn.log_softmax(logits)
+   return -tf.reduce_sum(y * y_hat)
+
+
+ def listwise_rank_loss(labels,
+                        logits,
+                        session_ids,
+                        transform_fn=None,
+                        temperature=1.0,
+                        label_is_logits=False,
+                        scale_logits=False,
+                        weights=1.0,
+                        name='listwise_loss'):
+   r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+   Definition:
+   $$
+   \mathcal{L}(\{y\}, \{s\}) =
+     -\sum_i y_i \log\left(\frac{\exp(s_i)}{\sum_j \exp(s_j)}\right)
+   $$
+
+   Args:
+     labels: A `Tensor` of the same shape as `logits` representing graded
+       relevance.
+     logits: A `Tensor` with shape [batch_size].
+     session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to group
+       samples into sessions (e.g. user_id), as in the GAUC metric.
+     transform_fn: an affine transformation function of the labels
+     temperature: (Optional) The temperature to use for scaling the logits.
+     label_is_logits: Whether `labels` is expected to be a logits tensor.
+       By default, we consider that `labels` encodes a probability distribution.
+     scale_logits: Whether to scale the logits.
+     weights: sample weights
+     name: the name of loss
+   """
+   loss_name = name if name else 'listwise_rank_loss'
+   logging.info('[{}] temperature: {}, scale logits: {}'.format(
+       loss_name, temperature, scale_logits))
+   labels = tf.to_float(labels)
+   if scale_logits:
+     with tf.variable_scope(loss_name):
+       w = tf.get_variable(
+           'scale_w',
+           dtype=tf.float32,
+           shape=(1,),
+           initializer=tf.ones_initializer())
+       b = tf.get_variable(
+           'scale_b',
+           dtype=tf.float32,
+           shape=(1,),
+           initializer=tf.zeros_initializer())
+     logits = logits * tf.abs(w) + b
+   if temperature != 1.0:
+     logits /= temperature
+     if label_is_logits:
+       labels /= temperature
+   if transform_fn is not None:
+     trans_fn = load_by_path(transform_fn)
+     labels = trans_fn(labels)
+
+   sessions, _ = tf.unique(tf.squeeze(session_ids))
+   tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+   losses = tf.map_fn(
+       lambda x: _list_wise_loss(x, labels, logits, session_ids,
+                                 label_is_logits),
+       sessions,
+       dtype=tf.float32)
+   if tf.is_numeric_tensor(weights):
+     logging.error('[%s] use unsupported sample weight' % loss_name)
+     return tf.reduce_mean(losses)
+   else:
+     return tf.reduce_mean(losses) * weights
+
+
+ def listwise_distill_loss(labels,
+                           logits,
+                           session_ids,
+                           transform_fn=None,
+                           temperature=1.0,
+                           label_clip_max_value=512,
+                           scale_logits=False,
+                           weights=1.0,
+                           name='listwise_distill_loss'):
+   r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+   Definition:
+   $$
+   \mathcal{L}(\{y\}, \{s\}) =
+     -\sum_i y_i \log\left(\frac{\exp(s_i)}{\sum_j \exp(s_j)}\right)
+   $$
+
+   Args:
+     labels: A `Tensor` of the same shape as `logits` representing the rank positions given by a base model.
+     logits: A `Tensor` with shape [batch_size].
+     session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to group
+       samples into sessions (e.g. user_id), as in the GAUC metric.
+     transform_fn: a transformation function of the labels.
+     temperature: (Optional) The temperature to use for scaling the logits.
+     label_clip_max_value: clip the labels to this value.
+     scale_logits: Whether to scale the logits.
+     weights: sample weights
+     name: the name of loss
+   """
+   loss_name = name if name else 'listwise_distill_loss'
+   logging.info('[{}] temperature: {}'.format(loss_name, temperature))
+   labels = tf.to_float(labels)  # supposed to be positions of a teacher model
+   labels = tf.clip_by_value(labels, 1, label_clip_max_value)
+   if transform_fn is not None:
+     trans_fn = load_by_path(transform_fn)
+     labels = trans_fn(labels)
+   else:
+     labels = tf.log1p(label_clip_max_value) - tf.log(labels)
+
+   if scale_logits:
+     with tf.variable_scope(loss_name):
+       w = tf.get_variable(
+           'scale_w',
+           dtype=tf.float32,
+           shape=(1,),
+           initializer=tf.ones_initializer())
+       b = tf.get_variable(
+           'scale_b',
+           dtype=tf.float32,
+           shape=(1,),
+           initializer=tf.zeros_initializer())
+     logits = logits * tf.abs(w) + b
+   if temperature != 1.0:
+     logits /= temperature
+
+   sessions, _ = tf.unique(tf.squeeze(session_ids))
+   tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+   losses = tf.map_fn(
+       lambda x: _list_prob_loss(x, labels, logits, session_ids),
+       sessions,
+       dtype=tf.float32)
+   if tf.is_numeric_tensor(weights):
+     logging.error('[%s] use unsupported sample weight' % loss_name)
+     return tf.reduce_mean(losses)
+   else:
+     return tf.reduce_mean(losses) * weights
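
An illustrative sketch (made-up values; graph mode assumed because of the summary ops): graded relevance labels are grouped by session_ids, and each session contributes one softmax cross-entropy term:

import tensorflow as tf
from easy_rec.python.loss.listwise_loss import listwise_rank_loss

labels = tf.constant([3.0, 1.0, 0.0, 2.0, 0.0])    # graded relevance
logits = tf.constant([1.2, 0.3, -0.5, 0.8, -1.0])  # [batch_size]
session_ids = tf.constant([1, 1, 1, 2, 2])         # query / user group ids
loss = listwise_rank_loss(labels, logits, session_ids, temperature=1.0)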
easy_rec/python/loss/multi_similarity.py
@@ -0,0 +1,68 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import tensorflow as tf
+
+ from easy_rec.python.loss.circle_loss import get_anchor_positive_triplet_mask
+ from easy_rec.python.utils.shape_utils import get_shape_list
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ def ms_loss(embeddings,
+             labels,
+             session_ids=None,
+             alpha=2.0,
+             beta=50.0,
+             lamb=1.0,
+             eps=0.1,
+             ms_mining=False,
+             embed_normed=False):
+   """Paper: Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning.
+
+   Link: http://openaccess.thecvf.com/content_CVPR_2019/papers/
+   Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf
+   """
+   # make sure the embeddings are l2-normalized
+   if not embed_normed:
+     embeddings = tf.nn.l2_normalize(embeddings, axis=1)
+   # flatten labels to [batch_size] so the mask helper receives rank-1 input
+   labels = tf.reshape(labels, [-1])
+
+   embed_shape = get_shape_list(embeddings)
+   batch_size = embed_shape[0]
+
+   mask_pos = get_anchor_positive_triplet_mask(labels, session_ids)
+   mask_neg = 1 - mask_pos - tf.eye(batch_size)
+
+   sim_mat = tf.matmul(
+       embeddings, embeddings, transpose_a=False, transpose_b=True)
+   sim_mat = tf.maximum(sim_mat, 0.0)
+
+   pos_mat = tf.multiply(sim_mat, mask_pos)
+   neg_mat = tf.multiply(sim_mat, mask_neg)
+
+   if ms_mining:
+     max_val = tf.reduce_max(neg_mat, axis=1, keepdims=True)
+     tmp_max_val = tf.reduce_max(pos_mat, axis=1, keepdims=True)
+     min_val = tf.reduce_min(
+         tf.multiply(sim_mat - tmp_max_val, mask_pos), axis=1,
+         keepdims=True) + tmp_max_val
+
+     max_val = tf.tile(max_val, [1, batch_size])
+     min_val = tf.tile(min_val, [1, batch_size])
+
+     mask_pos = tf.where(pos_mat < max_val + eps, mask_pos,
+                         tf.zeros_like(mask_pos))
+     mask_neg = tf.where(neg_mat > min_val - eps, mask_neg,
+                         tf.zeros_like(mask_neg))
+
+   pos_exp = tf.exp(-alpha * (pos_mat - lamb))
+   pos_exp = tf.where(mask_pos > 0.0, pos_exp, tf.zeros_like(pos_exp))
+
+   neg_exp = tf.exp(beta * (neg_mat - lamb))
+   neg_exp = tf.where(mask_neg > 0.0, neg_exp, tf.zeros_like(neg_exp))
+
+   pos_term = tf.log(1.0 + tf.reduce_sum(pos_exp, axis=1)) / alpha
+   neg_term = tf.log(1.0 + tf.reduce_sum(neg_exp, axis=1)) / beta
+
+   loss = tf.reduce_mean(pos_term + neg_term)
+   return loss
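
A minimal usage sketch (illustrative values), treating samples with the same label as positive pairs:

import tensorflow as tf
from easy_rec.python.loss.multi_similarity import ms_loss

embeddings = tf.random.normal([16, 32])  # ms_loss l2-normalizes internally
labels = tf.constant([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
loss = ms_loss(embeddings, labels, alpha=2.0, beta=50.0, ms_mining=True)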