easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,228 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Library of common learning rate schedules."""
17
+
18
+ import numpy as np
19
+ import tensorflow as tf
20
+
21
+ if tf.__version__ >= '2.0':
22
+ tf = tf.compat.v1
23
+
24
+
25
def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with a linear burn-in period.

  While global_step < burnin_steps the learning rate ramps linearly from
  burnin_learning_rate (at step 0) towards learning_rate_base. Afterwards it
  follows tf.train.exponential_decay driven by (global_step - burnin_steps),
  i.e. the decay counter starts at the end of burn-in. The result is never
  lower than min_learning_rate.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate, counted from the end of the burn-in period.
    learning_rate_decay_factor: multiplicative factor by which to decay
      learning rate.
    burnin_learning_rate: learning rate at step 0 of the burn-in ramp. If
      0.0 (which is the default), no ramp is used and the burn-in learning
      rate is simply learning_rate_base.
    burnin_steps: number of burn-in steps. A value <= 0 disables burn-in.
    min_learning_rate: the minimum learning rate (lower clip bound).
    staircase: whether to use staircase decay (decay at discrete
      intervals).

  Returns:
    a (scalar) float tensor representing learning rate
  """
  if burnin_learning_rate == 0 or burnin_steps <= 0:
    # Either no ramp was requested, or the burn-in branch of tf.where below
    # can never be selected (global_step < burnin_steps is false for
    # burnin_steps <= 0, assuming a non-negative global step). Skipping the
    # slope avoids a ZeroDivisionError when burnin_steps == 0.
    burnin_rate = learning_rate_base
  else:
    slope = (learning_rate_base - burnin_learning_rate) / burnin_steps
    burnin_rate = slope * tf.cast(global_step,
                                  tf.float32) + burnin_learning_rate
  # The decay counter is shifted so it starts at zero when burn-in ends.
  post_burnin_learning_rate = tf.train.exponential_decay(
      learning_rate_base,
      global_step - burnin_steps,
      learning_rate_decay_steps,
      learning_rate_decay_factor,
      staircase=staircase)
  return tf.maximum(
      tf.where(
          tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
          burnin_rate, post_burnin_learning_rate),
      min_learning_rate,
      name='learning_rate')
74
+
75
+
76
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, is optionally held at
  learning_rate_base for hold_base_rate_steps, then follows a cosine decay
  that reaches 0 at total_steps. After total_steps the learning rate is 0.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.

  Returns:
    a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if learning_rate_base < warmup_learning_rate:
    raise ValueError('learning_rate_base must be larger '
                     'or equal to warmup_learning_rate.')
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger or equal to warmup_steps.')
  # Length of the cosine phase. When it is <= 0 (e.g. total_steps equals
  # warmup_steps + hold_base_rate_steps) the cosine value is either masked by
  # the hold / "past total_steps" branches below or evaluated at offset 0
  # (cos(0) -> learning_rate_base), so clamping the divisor to 1 only removes
  # the division by zero.
  decay_steps = max(total_steps - warmup_steps - hold_base_rate_steps, 1)
  learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
      np.pi *
      (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps) /
      float(decay_steps)))
  if hold_base_rate_steps > 0:
    # Keep the base rate until warmup + hold steps have elapsed.
    learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
                             learning_rate, learning_rate_base)
  if warmup_steps > 0:
    # Linear ramp from warmup_learning_rate towards learning_rate_base.
    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
    warmup_rate = slope * tf.cast(global_step,
                                  tf.float32) + warmup_learning_rate
    learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                             learning_rate)
  # Training past total_steps uses a learning rate of 0.
  return tf.where(
      global_step > total_steps, 0.0, learning_rate, name='learning_rate')
127
+
128
+
129
def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a sequence (list, tuple, ...) of global steps at which to
      switch learning rates. Must consist of strictly increasing positive
      Python ints.
    rates: a sequence of (float) learning rates corresponding to intervals
      between the boundaries. Its length must be exactly
      len(boundaries) + 1.
    warmup: Whether to linearly interpolate from rates[0] to rates[1] for
      steps in [0, boundaries[0]) instead of holding rates[0].

  Returns:
    a (scalar) float tensor representing learning rate

  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. every entry of rates is a float
      3. len(rates) == len(boundaries) + 1
      4. boundaries[0] != 0
  """
  # Accept any sequence; the list concatenations below need real lists.
  boundaries = list(boundaries)
  rates = list(rates)

  # Check the type before the sign so non-integer entries raise the
  # documented ValueError instead of a TypeError from comparing with 0.
  if any(not isinstance(b, int) or b < 0 for b in boundaries):
    raise ValueError('boundaries must be a list of positive integers')
  if any(bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any(not isinstance(r, float) for r in rates):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')

  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')

  if warmup and boundaries:
    # Replace the first constant interval with one boundary per step,
    # interpolating linearly from rates[0] (step 0) towards rates[1].
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)
  # Boundaries not yet reached contribute index 0, so the max is the index of
  # the last boundary <= global_step.
  rate_index = tf.reduce_max(
      tf.where(
          tf.greater_equal(global_step, boundaries),
          list(range(num_boundaries)), [0] * num_boundaries))
  # Select rates[rate_index] via a one-hot dot product.
  return tf.reduce_sum(
      rates * tf.one_hot(rate_index, depth=num_boundaries),
      name='learning_rate')
187
+
188
+
189
def transformer_policy(global_step,
                       learning_rate,
                       d_model,
                       warmup_steps,
                       step_scaling_rate=1.0,
                       max_lr=None,
                       coefficient=1.0,
                       dtype=tf.float32):
  """Noam learning rate schedule from "Attention Is All You Need".

  See https://arxiv.org/pdf/1706.03762.pdf. With
    s = global_step * step_scaling_rate
    w = warmup_steps * step_scaling_rate
  the value computed is
    learning_rate * coefficient * d_model**-0.5 *
        min((s + 1) * w**-1.5, (s + 1)**-0.5)
  optionally capped from above by max_lr (the "hat").

  Args:
    global_step: global step tensor; cast to `dtype` and used to compute
      the schedule.
    learning_rate (float): multiplier applied to the schedule value.
    d_model (int): model dimensionality.
    warmup_steps (int): number of warm-up steps.
    step_scaling_rate (float): factor applied to both the step and the
      warm-up step count.
    max_lr (float): maximal learning rate, i.e. hat. None disables the cap.
    coefficient (float): optimizer adjustment.
      Recommended 0.002 if using "Adam" else 1.0.
    dtype: dtype used for the step computations.

  Returns:
    learning rate tensor at ``global_step``.
  """
  scaled_step = tf.cast(global_step, dtype=dtype) * step_scaling_rate
  scaled_warmup = tf.cast(warmup_steps, dtype=dtype) * step_scaling_rate

  # Offset by one so that step 0 does not evaluate 0 ** -0.5.
  shifted_step = scaled_step + 1
  warmup_branch = shifted_step * scaled_warmup**-1.5
  inv_sqrt_branch = shifted_step**-0.5
  decay = coefficient * d_model**-0.5 * tf.minimum(warmup_branch,
                                                   inv_sqrt_branch)

  scheduled_lr = decay * learning_rate
  if max_lr is None:
    return scheduled_lr
  return tf.minimum(max_lr, scheduled_lr)
@@ -0,0 +1,402 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+ import logging
5
+ import os
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from sklearn import metrics as sklearn_metrics
11
+ from tensorflow.python.ops import array_ops
12
+ from tensorflow.python.ops import math_ops
13
+ from tensorflow.python.ops import state_ops
14
+ from tensorflow.python.ops import variable_scope
15
+
16
+ from easy_rec.python.utils.estimator_utils import get_task_index_and_num
17
+ from easy_rec.python.utils.io_util import read_data_from_json_path
18
+ from easy_rec.python.utils.io_util import save_data_to_json_path
19
+ from easy_rec.python.utils.shape_utils import get_shape_list
20
+
21
# Under TF2 the graph-mode metric API used in this module lives in
# tf.compat.v1. Compare the numeric major version: a plain string comparison
# (`tf.__version__ >= '2.0'`) would misorder versions such as '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
23
+
24
+
25
def max_f1(label, predictions):
  """Streaming maximum F1 score over a fixed grid of decision thresholds.

  For each of 200 thresholds (one just below 0, 198 evenly spaced points
  strictly inside (0, 1), one just above 1), a streaming precision and recall
  metric is built on ``predictions > threshold``. The reported value is the
  largest F1 score across all thresholds.

  Args:
    label: Ground truth (correct) target values.
    predictions: Estimated targets as returned by a model.

  Returns:
    A ``(f1, update_op)`` tuple: the max-F1 tensor and an op grouping every
    precision and recall update op.
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  num_thresholds = 200
  kepsilon = 1e-7
  interior = [(i + 1) * 1.0 / (num_thresholds - 1)
              for i in range(num_thresholds - 2)]
  thresholds = [0.0 - kepsilon] + interior + [1.0 + kepsilon]

  scores = []
  prec_ops = []
  rec_ops = []
  for t in thresholds:
    positive = (predictions > t)
    prec, prec_op = metrics_tf.precision(
        labels=label, predictions=positive, name='precision_%s' % t)
    rec, rec_op = metrics_tf.recall(
        labels=label, predictions=positive, name='recall_%s' % t)
    # The tiny epsilon avoids 0/0 when precision and recall are both zero.
    score = (2 * prec * rec) / (prec + rec + 1e-12)
    prec_ops.append(prec_op)
    rec_ops.append(rec_op)
    scores.append(score)

  best_f1 = tf.math.reduce_max(tf.stack(scores))
  grouped_update = tf.group(prec_ops + rec_ops)
  return best_f1, grouped_update
57
+
58
+
59
def _separated_auc_impl(labels, predictions, keys, reduction='mean'):
  """Per-key AUC, reduced into a single (optionally weighted) average.

  Records are accumulated in Python dictionaries through a ``tf.py_func``
  update op. The value op computes an sklearn ROC-AUC for every key that has
  both positive and negative labels, then averages those AUCs.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    keys: keys to be group by, A int or string `Tensor` whose shape matches `predictions`.
    reduction: reduction metric for auc of different keys
      * "mean": simple mean of different keys
      * "mean_by_sample_num": weighted mean with sample num of different keys
      * "mean_by_positive_num": weighted mean with positive sample num of different keys

  Returns:
    A ``(value_op, update_op)`` tuple. The value is 0.0 when no key has both
    classes present.
  """
  assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
      'reduction method must in mean | mean_by_sample_num | mean_by_positive_num'
  labels_by_key = defaultdict(list)
  preds_by_key = defaultdict(list)
  weight_by_key = defaultdict(int)

  def update_pyfunc(batch_labels, batch_preds, batch_keys):
    for lbl, pred, k in zip(batch_labels, batch_preds, batch_keys):
      labels_by_key[k].append(lbl)
      preds_by_key[k].append(pred)
      if reduction == 'mean':
        weight_by_key[k] = 1
      elif reduction == 'mean_by_sample_num':
        weight_by_key[k] += 1
      elif reduction == 'mean_by_positive_num':
        weight_by_key[k] += lbl

  def value_pyfunc():
    aucs = []
    auc_weights = []
    for k, key_labels in labels_by_key.items():
      y_true = np.asarray(key_labels).reshape([-1])
      y_score = np.asarray(preds_by_key[k]).reshape([-1])
      # AUC is undefined for a key that only has one class.
      single_class = np.all(y_true == 1) or np.all(y_true == 0)
      if single_class:
        continue
      aucs.append(sklearn_metrics.roc_auc_score(y_true, y_score))
      auc_weights.append(weight_by_key[k])
    if not aucs:
      return np.float32(0.0)
    return np.average(aucs, weights=auc_weights).astype(np.float32)

  update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
  value_op = tf.py_func(value_pyfunc, [], tf.float32)
  return value_op, update_op
109
+
110
+
111
def fast_auc(labels, predictions, name, num_thresholds=1e5):
  """Streaming AUC computed from a fixed-width histogram of predictions.

  Each prediction is bucketed into bin ``int(prediction * num_thresholds)``
  (``num_thresholds + 1`` bins). Per-bin negative/positive counts and the
  overall negative/positive totals are accumulated in metric variables. The
  value op derives the AUC from that histogram; a (negative, positive) pair
  in the same bin counts as half-correct.

  Args:
    labels: label tensor; cast to int32 and used as the histogram row, so it
      is expected to hold 0/1 values (presumably binary labels - confirm with
      callers).
    predictions: prediction tensor, expected to lie in ``[0, 1]``. Bins
      outside ``[0, num_thresholds]`` are clipped to the first/last bin.
    name: variable and name scope for the metric variables.
    num_thresholds: number of histogram buckets (converted to int).

  Returns:
    A ``(value_op, update_op)`` tuple. The value is NaN when no positive or no
    negative sample has been seen.
  """
  num_thresholds = int(num_thresholds)

  def value_pyfunc(pos_neg_arr, total_pos_neg):
    total_neg = total_pos_neg[0]
    total_pos = total_pos_neg[1]
    logging.info('fast_auc[%s]: total_pos=%d total_neg=%d total=%d' %
                 (name, total_pos, total_neg, total_pos + total_neg))
    # Work in float64: the pair counts (up to total_pos * total_neg * 2) can
    # overflow int64 on very large evaluation sets.
    neg_cnt = pos_neg_arr[0].astype(np.float64)
    pos_cnt = pos_neg_arr[1].astype(np.float64)
    denom = np.float64(total_pos) * np.float64(total_neg) * 2.0
    if denom == 0:
      logging.warning('fast_auc[%s]: AUC undefined without both classes', name)
      return np.float32(np.nan)
    # Positives falling strictly above bin i.
    pos_above = np.float64(total_pos) - np.cumsum(pos_cnt)
    # Correctly ordered pairs count 2, pairs tied in the same bin count 1.
    auc = np.sum(pos_above * neg_cnt * 2.0 + neg_cnt * pos_cnt)
    return np.float32(auc / denom)

  with variable_scope.variable_scope(name_or_scope=name), tf.name_scope(name):
    # Row 0 holds per-bin negative counts, row 1 per-bin positive counts.
    neg_pos_var = variable_scope.get_variable(
        name='neg_pos_cnt',
        shape=[2, num_thresholds + 1],
        trainable=False,
        collections=[tf.GraphKeys.METRIC_VARIABLES],
        initializer=tf.zeros_initializer(),
        dtype=tf.int64)
    # [total_neg, total_pos]
    total_var = variable_scope.get_variable(
        name='total_cnt',
        shape=[2],
        trainable=False,
        collections=[tf.GraphKeys.METRIC_VARIABLES],
        initializer=tf.zeros_initializer(),
        dtype=tf.int64)
    pred_bins = math_ops.cast(predictions * num_thresholds, dtype=tf.int32)
    # Predictions marginally outside [0, 1] (e.g. float rounding) would
    # otherwise produce out-of-range scatter indices.
    pred_bins = tf.clip_by_value(pred_bins, 0, num_thresholds)
    labels = math_ops.cast(labels, dtype=tf.int32)
    labels = array_ops.reshape(labels, [-1, 1])
    pred_bins = array_ops.reshape(pred_bins, [-1, 1])
    update_op0 = state_ops.scatter_nd_add(
        neg_pos_var, tf.concat([labels, pred_bins], axis=1),
        array_ops.ones(tf.shape(labels)[0], dtype=tf.int64))
    total_pos = math_ops.reduce_sum(labels)
    total_neg = array_ops.shape(labels)[0] - total_pos
    total_add = math_ops.cast(tf.stack([total_neg, total_pos]), dtype=tf.int64)
    update_op1 = state_ops.assign_add(total_var, total_add)
    return tf.py_func(value_pyfunc, [neg_pos_var, total_var],
                      tf.float32), tf.group([update_op0, update_op1])
156
+
157
+
158
def _distribute_separated_auc_impl(labels,
                                   predictions,
                                   keys,
                                   reduction='mean',
                                   metric_name='sepatated_auc'):
  """Computes the AUC group by the key separately, merged across workers.

  Each task accumulates its own per-key labels, predictions and weights in
  a ``tf.py_func`` update op. After every batch it dumps them as json files
  into the directory named by the ``eval_tmp_results_dir`` environment
  variable. The value op reads the files written by tasks
  ``job_worker__task_1`` .. ``job_worker__task_{task_num - 1}``, merges them
  with this task's records and computes the reduced AUC.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    keys: keys to be group by, A int or string `Tensor` whose shape matches `predictions`.
    reduction: reduction metric for auc of different keys
      * "mean": simple mean of different keys
      * "mean_by_sample_num": weighted mean with sample num of different keys
      * "mean_by_positive_num": weighted mean with positive sample num of different keys
    metric_name: the name of compute metric; used as the prefix of the
      temporary json file names.

  Returns:
    A ``(value_op, update_op)`` tuple. The value is 0.0 when no key has both
    classes present.
  """
  assert reduction in ['mean', 'mean_by_sample_num', 'mean_by_positive_num'], \
      'reduction method must in mean | mean_by_sample_num | mean_by_positive_num'
  separated_label = defaultdict(list)
  separated_prediction = defaultdict(list)
  separated_weights = defaultdict(int)
  record_names = [
      'separated_label', 'separated_prediction', 'separated_weights'
  ]
  tf_config = json.loads(os.environ['TF_CONFIG'])
  cur_job_name = tf_config['task']['type']
  cur_task_index, task_num = get_task_index_and_num()
  cur_work_device = 'job_' + cur_job_name + '__' + 'task_' + str(cur_task_index)
  eval_tmp_results_dir = os.environ['eval_tmp_results_dir']
  assert tf.gfile.IsDirectory(
      eval_tmp_results_dir), 'eval_tmp_results_dir not exists'

  def _json_path(work_device, name):
    json_name = metric_name + '__' + work_device + '__' + name + '.json'
    return os.path.join(eval_tmp_results_dir, json_name)

  def update_pyfunc(labels, predictions, keys):
    for label, prediction, key in zip(labels, predictions, keys):
      # String keys survive the json round trip used to share records.
      key = str(key)
      separated_label[key].append(label.item())
      separated_prediction[key].append(prediction.item())
      if reduction == 'mean':
        separated_weights[key] = 1
      elif reduction == 'mean_by_sample_num':
        separated_weights[key] += 1
      elif reduction == 'mean_by_positive_num':
        separated_weights[key] += label.item()
    # NOTE(review): the whole accumulated state is rewritten after every
    # batch, so the file I/O grows with the number of evaluated samples.
    for name, data in zip(
        record_names,
        [separated_label, separated_prediction, separated_weights]):
      save_data_to_json_path(_json_path(cur_work_device, name), data)

  def _extend_by_key(dst, src):
    # `+` builds new lists, so the lists held by the closure state are never
    # modified in place.
    for key, values in src.items():
      dst[key] = dst.get(key, []) + values

  def value_pyfunc():
    # Merge into fresh dicts. Mutating the closure state here would append the
    # other workers' records again on every evaluation of the value op.
    all_labels = dict(separated_label)
    all_predictions = dict(separated_prediction)
    all_weights = dict(separated_weights)
    for task_i in range(1, task_num):
      work_device_i = 'job_worker__task_' + str(task_i)
      for name in record_names:
        data_i = read_data_from_json_path(_json_path(work_device_i, name))
        if name == 'separated_label':
          _extend_by_key(all_labels, data_i)
        elif name == 'separated_prediction':
          _extend_by_key(all_predictions, data_i)
        elif name == 'separated_weights':
          if reduction == 'mean':
            all_weights.update(data_i)
          else:
            for key, weight in data_i.items():
              all_weights[key] = all_weights.get(key, 0) + weight
        else:
          assert False, 'Not supported name {}'.format(name)
    metrics = []
    weights = []
    for key, key_labels in all_labels.items():
      per_label = np.asarray(key_labels).reshape([-1])
      per_prediction = np.asarray(all_predictions.get(key, [])).reshape([-1])
      # AUC is undefined for a key that only has one class.
      if np.all(per_label == 1) or np.all(per_label == 0):
        continue
      metric = sklearn_metrics.roc_auc_score(per_label, per_prediction)
      metrics.append(metric)
      weights.append(all_weights.get(key, 0))
    if len(metrics) > 0:
      return np.average(metrics, weights=weights).astype(np.float32)
    else:
      return np.float32(0.0)

  update_op = tf.py_func(update_pyfunc, [labels, predictions, keys], [])
  value_op = tf.py_func(value_pyfunc, [], tf.float32)
  return value_op, update_op
258
+
259
+
260
def gauc(labels, predictions, uids, reduction='mean'):
  """Group AUC: the AUC computed per user, then reduced across users.

  When the environment variable ``distribute_eval`` equals ``'True'``, the
  multi-worker implementation is used (its temporary files are prefixed with
  ``gauc``); otherwise records are kept locally in-process.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    uids: user ids, A int or string `Tensor` whose shape matches `predictions`.
    reduction: reduction method for auc of different users
      * "mean": simple mean of different users
      * "mean_by_sample_num": weighted mean with sample num of different users
      * "mean_by_positive_num": weighted mean with positive sample num of different users

  Returns:
    A ``(value_op, update_op)`` tuple.
  """
  distributed = os.environ.get('distribute_eval') == 'True'
  if not distributed:
    return _separated_auc_impl(labels, predictions, uids, reduction)
  return _distribute_separated_auc_impl(
      labels, predictions, uids, reduction, metric_name='gauc')
278
+
279
+
280
def session_auc(labels, predictions, session_ids, reduction='mean'):
  """AUC computed per session, then reduced across sessions.

  When the environment variable ``distribute_eval`` equals ``'True'``, the
  multi-worker implementation is used (its temporary files are prefixed with
  ``session_auc``); otherwise records are kept locally in-process.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    session_ids: session ids, A int or string `Tensor` whose shape matches `predictions`.
    reduction: reduction method for auc of different sessions
      * "mean": simple mean of different sessions
      * "mean_by_sample_num": weighted mean with sample num of different sessions
      * "mean_by_positive_num": weighted mean with positive sample num of different sessions

  Returns:
    A ``(value_op, update_op)`` tuple.
  """
  if os.environ.get('distribute_eval') != 'True':
    return _separated_auc_impl(labels, predictions, session_ids, reduction)
  return _distribute_separated_auc_impl(
      labels, predictions, session_ids, reduction, metric_name='session_auc')
298
+
299
+
300
def metric_learning_recall_at_k(k,
                                embeddings,
                                labels,
                                session_ids=None,
                                embed_normed=False):
  """Recall@k for metric learning over the in-batch similarity matrix.

  Every embedding is used as a query against the rest of the batch. The
  positives of a query are the other batch entries with the same label (and
  the same session, when ``session_ids`` is given and is not ``labels``).

  Args:
    k: a scalar of int, or a tuple of ints
    embeddings: the output of last hidden layer, a tf.float32 `Tensor` with shape [batch_size, embedding_size]
    labels: a `Tensor` with shape [batch_size]
    session_ids: session ids, a `Tensor` with shape [batch_size]
    embed_normed: indicator of whether the input embeddings are l2_normalized

  Returns:
    For an int ``k``: the result of ``metrics_tf.recall_at_k`` (presumably a
    ``(value, update_op)`` pair). For a list/tuple/set: a dict mapping
    ``'recall@<kk>'`` to such results, skipping entries with ``kk < 1``.

  Raises:
    ValueError: if ``k`` is neither an int nor a list/tuple/set.
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  if not embed_normed:
    embeddings = tf.nn.l2_normalize(embeddings, axis=1)
  n = get_shape_list(embeddings)[0]
  similarity = tf.matmul(embeddings, embeddings, transpose_b=True)
  # Lower each self-similarity by 2 so an embedding does not rank itself first.
  similarity = similarity - tf.eye(n) * 2.0
  off_diagonal = tf.logical_not(tf.eye(n, dtype=tf.bool))
  # Broadcast (1, n) against (n, 1) to compare every pair of labels.
  same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
  if session_ids is not None and session_ids is not labels:
    same_session = tf.equal(
        tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
    same_label = tf.logical_and(same_session, same_label)
  positive_mask = tf.logical_and(off_diagonal, same_label)
  # Non-positives get -1 so top_k prefers positives; shape (n, n).
  positive_sim = tf.where(positive_mask, similarity,
                          -array_ops.ones_like(similarity))

  def _recall_for(top):
    _, top_idx = tf.nn.top_k(positive_sim, top)  # shape: (n, top)
    return metrics_tf.recall_at_k(
        labels=tf.to_int64(top_idx), predictions=similarity, k=top)

  if isinstance(k, int):
    return _recall_for(k)
  if not isinstance(k, (list, tuple, set)):
    raise ValueError('k should be a `int` or a list/tuple/set of int.')
  return {'recall@' + str(kk): _recall_for(kk) for kk in k if kk >= 1}
348
+
349
+
350
def metric_learning_average_precision_at_k(k,
                                           embeddings,
                                           labels,
                                           session_ids=None,
                                           embed_normed=False):
  """Computes the average_precision_at_k (MAP@k) metric for metric learning.

  Every embedding in the batch is used as a query against the whole batch.
  Candidates are ranked by the dot product of the (L2-normalized)
  embeddings. The relevant items of a query are the batch entries sharing its
  label (and its session, when ``session_ids`` is given and is not
  ``labels``).

  Args:
    k: a scalar of int, or a list/tuple/set of ints
    embeddings: the output of last hidden layer, a tf.float32 `Tensor` with
      shape [batch_size, embedding_size]
    labels: a `Tensor` with shape [batch_size]
    session_ids: session ids, a `Tensor` with shape [batch_size]
    embed_normed: indicator of whether the input embeddings are l2_normalized

  Returns:
    For an int ``k``: the result of ``metrics_tf.average_precision_at_k``
    (presumably a ``(value, update_op)`` pair). For a list/tuple/set: a dict
    mapping ``'MAP@<kk>'`` to such results, skipping entries with ``kk < 1``.

  Raises:
    ValueError: if ``k`` is neither an int nor a list/tuple/set.
  """
  from easy_rec.python.core.easyrec_metrics import metrics_tf
  # make sure embedding should be l2-normalized
  if not embed_normed:
    embeddings = tf.nn.l2_normalize(embeddings, axis=1)
  embed_shape = get_shape_list(embeddings)
  batch_size = embed_shape[0]
  sim_mat = tf.matmul(embeddings, embeddings, transpose_b=True)
  # Lower each self-similarity by 2 so a query does not rank itself first.
  sim_mat = sim_mat - tf.eye(batch_size) * 2.0
  mask = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
  if session_ids is not None and session_ids is not labels:
    sessions_equal = tf.equal(
        tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
    mask = tf.logical_and(sessions_equal, mask)
  # NOTE(review): unlike metric_learning_recall_at_k, this mask keeps the
  # diagonal, so each query counts itself as a relevant item - confirm this
  # is intended.
  label_indices = _get_matrix_mask_indices(mask)
  if isinstance(k, int):
    return metrics_tf.average_precision_at_k(label_indices, sim_mat, k)
  if isinstance(k, (list, tuple, set)):
    return {
        'MAP@' + str(kk):
        metrics_tf.average_precision_at_k(label_indices, sim_mat, kk)
        for kk in k
        if kk >= 1
    }
  raise ValueError('k should be a `int` or a list/tuple/set of int.')
381
+
382
+
383
def _get_matrix_mask_indices(matrix, num_rows=None):
  """Lists, per row of a boolean matrix, the column indices that are True.

  Args:
    matrix: a boolean `Tensor`; treated as 2-D (rows are taken from
      `indices[:, 0]` and columns from `indices[:, 1]` of `tf.where(matrix)`).
    num_rows: number of rows of `matrix`; defaults to its first dimension.

  Returns:
    A `Tensor` of column indices with shape [num_rows, max_elem_per_row],
    where max_elem_per_row is the largest number of True entries in any row.
    A row with fewer True entries is padded by repeating its largest column
    index; a row with no True entry is filled with -1.
  """
  if num_rows is None:
    num_rows = get_shape_list(matrix)[0]
  # Coordinates of True entries, in row-major order.
  indices = tf.where(matrix)
  num_indices = tf.shape(indices)[0]
  # Number of True entries in each row.
  elem_per_row = tf.bincount(
      tf.cast(indices[:, 0], tf.int32), minlength=num_rows)
  max_elem_per_row = tf.reduce_max(elem_per_row)
  # Offset of each row's first entry inside `indices` (exclusive cumsum).
  row_start = tf.concat([[0], tf.cumsum(elem_per_row[:-1])], axis=0)
  r = tf.range(max_elem_per_row)
  idx = tf.expand_dims(row_start, 1) + r
  # Keep the gather in range; positions past a row's entries are masked below.
  idx = tf.minimum(idx, num_indices - 1)
  result = tf.gather(indices[:, 1], idx)
  # replace invalid elements with -1
  result = tf.where(
      tf.expand_dims(elem_per_row, 1) > r, result, -array_ops.ones_like(result))
  # Pad with the row's largest valid index instead of -1 (rows with no valid
  # entry keep -1 because their maximum is -1).
  max_index_per_row = tf.reduce_max(result, axis=1, keepdims=True)
  max_index_per_row = tf.tile(max_index_per_row, [1, max_elem_per_row])
  result = tf.where(result >= 0, result, max_index_per_row)
  return result